File tree Expand file tree Collapse file tree 1 file changed +4
-4
lines changed Expand file tree Collapse file tree 1 file changed +4
-4
lines changed Original file line number Diff line number Diff line change @@ -3206,10 +3206,6 @@ std::string getDistributedLayoutStr(RankedTensorType tensorType,
32063206 if (!layout)
32073207 return " " ;
32083208
3209- unsigned threadsPerWarp = getWarpSize (layout);
3210- unsigned numWarpsPerCTA = getNumWarpsPerCTA (layout);
3211- unsigned numBlocks = getNumCTAs (layout);
3212- int numElementsPerThreads = getTotalElemsPerThread (tensorType);
32133209 StringAttr kRegister = StringAttr::get (tensorType.getContext (), " register" );
32143210 StringAttr kLane = StringAttr::get (tensorType.getContext (), " lane" );
32153211 StringAttr kWarp = StringAttr::get (tensorType.getContext (), " warp" );
@@ -3222,6 +3218,10 @@ std::string getDistributedLayoutStr(RankedTensorType tensorType,
32223218 int64_t tensorSize = product (tensorType.getShape ());
32233219 std::vector<std::string> elementMapping (tensorSize);
32243220 std::vector<std::string> threadMapping;
3221+ unsigned threadsPerWarp = ll->getInDimSize (kLane );
3222+ unsigned numWarpsPerCTA = ll->getInDimSize (kWarp );
3223+ unsigned numBlocks = ll->getInDimSize (kBlock );
3224+ int numElementsPerThreads = ll->getInDimSize (kRegister );
32253225 for (int blockId = 0 ; blockId < numBlocks; ++blockId) {
32263226 for (int warpId = 0 ; warpId < numWarpsPerCTA; warpId++) {
32273227 for (int tid = 0 ; tid < threadsPerWarp; ++tid) {
You can’t perform that action at this time.
0 commit comments