diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index 8dc7829c38..7cde755873 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -3206,10 +3206,6 @@ std::string getDistributedLayoutStr(RankedTensorType tensorType, if (!layout) return ""; - unsigned threadsPerWarp = getWarpSize(layout); - unsigned numWarpsPerCTA = getNumWarpsPerCTA(layout); - unsigned numBlocks = getNumCTAs(layout); - int numElementsPerThreads = getTotalElemsPerThread(tensorType); StringAttr kRegister = StringAttr::get(tensorType.getContext(), "register"); StringAttr kLane = StringAttr::get(tensorType.getContext(), "lane"); StringAttr kWarp = StringAttr::get(tensorType.getContext(), "warp"); @@ -3222,6 +3218,10 @@ std::string getDistributedLayoutStr(RankedTensorType tensorType, int64_t tensorSize = product(tensorType.getShape()); std::vector elementMapping(tensorSize); std::vector threadMapping; + unsigned threadsPerWarp = ll->getInDimSize(kLane); + unsigned numWarpsPerCTA = ll->getInDimSize(kWarp); + unsigned numBlocks = ll->getInDimSize(kBlock); + int numElementsPerThreads = ll->getInDimSize(kRegister); for (int blockId = 0; blockId < numBlocks; ++blockId) { for (int warpId = 0; warpId < numWarpsPerCTA; warpId++) { for (int tid = 0; tid < threadsPerWarp; ++tid) {