Skip to content

Commit 76c054e

Browse files
Merge commit '251ec88dea0f0c7ab45592c6287fda4b8585440f'
2 parents: 016ee00 + 251ec88; commit: 76c054e

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -3206,10 +3206,6 @@ std::string getDistributedLayoutStr(RankedTensorType tensorType,
32063206
if (!layout)
32073207
return "";
32083208

3209-
unsigned threadsPerWarp = getWarpSize(layout);
3210-
unsigned numWarpsPerCTA = getNumWarpsPerCTA(layout);
3211-
unsigned numBlocks = getNumCTAs(layout);
3212-
int numElementsPerThreads = getTotalElemsPerThread(tensorType);
32133209
StringAttr kRegister = StringAttr::get(tensorType.getContext(), "register");
32143210
StringAttr kLane = StringAttr::get(tensorType.getContext(), "lane");
32153211
StringAttr kWarp = StringAttr::get(tensorType.getContext(), "warp");
@@ -3222,6 +3218,10 @@ std::string getDistributedLayoutStr(RankedTensorType tensorType,
32223218
int64_t tensorSize = product(tensorType.getShape());
32233219
std::vector<std::string> elementMapping(tensorSize);
32243220
std::vector<std::string> threadMapping;
3221+
unsigned threadsPerWarp = ll->getInDimSize(kLane);
3222+
unsigned numWarpsPerCTA = ll->getInDimSize(kWarp);
3223+
unsigned numBlocks = ll->getInDimSize(kBlock);
3224+
int numElementsPerThreads = ll->getInDimSize(kRegister);
32253225
for (int blockId = 0; blockId < numBlocks; ++blockId) {
32263226
for (int warpId = 0; warpId < numWarpsPerCTA; warpId++) {
32273227
for (int tid = 0; tid < threadsPerWarp; ++tid) {

0 commit comments

Comments
 (0)