
Commit d59e426

[BACKEND] Also support rank reducing case in updateEncodingForShape (triton-lang#6375)
Fix the helper to support both the rank-reducing and the rank-expanding case.
1 parent: 17966f4 · commit: d59e426
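
To make the change easier to follow, here is a minimal standalone sketch of the adjusted order-remapping logic. It is not the Triton source: the function name remapOrderForRank and the use of std::vector in place of MLIR's ArrayRef/SmallVector are illustrative only. The pre-existing loop pads leading dimensions when the destination tensor has a higher rank than the encoding's order; the new loop drops any dimension index that no longer exists when the rank is reduced, instead of asserting.

// Standalone sketch (illustrative; not the Triton code): remap a shared-memory
// dimension order onto a destination tensor of the given rank.
#include <cassert>
#include <cstdio>
#include <vector>

static std::vector<unsigned>
remapOrderForRank(const std::vector<unsigned> &oldOrder, unsigned rank) {
  std::vector<unsigned> order;
  // Rank-expanding case: prepend the extra leading dimensions (unchanged logic).
  for (unsigned i = 0; i + oldOrder.size() < rank; ++i)
    order.push_back(rank - i - 1);
  // Rank-reducing case (the fix): drop dimensions that fall outside the new rank.
  for (unsigned d : oldOrder) {
    if (d >= rank)
      continue;
    order.push_back(d);
  }
  return order;
}

int main() {
  // Rank-reducing load from the new test: order [1, 2, 0] onto a rank-2 tensor.
  assert((remapOrderForRank({1, 2, 0}, 2) == std::vector<unsigned>{1, 0}));
  // Rank-expanding case: order [1, 0] onto a rank-3 tensor.
  assert((remapOrderForRank({1, 0}, 3) == std::vector<unsigned>{2, 1, 0}));
  std::puts("order remapping behaves as expected");
  return 0;
}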

File tree: 2 files changed (+22, -2 lines)

include/triton/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.h

Lines changed: 6 additions & 2 deletions
@@ -82,11 +82,15 @@ updateEncodingForShape(Operation *op, gpu::SharedEncodingTrait encoding,
 
     auto rank = tensorType.getRank();
     auto oldOrder = swizEnc.getOrder();
-    assert(oldOrder.size() <= rank);
     SmallVector<unsigned> order;
     for (int i = 0; i + oldOrder.size() < rank; ++i)
       order.push_back(rank - i - 1);
-    order.append(oldOrder.begin(), oldOrder.end());
+    for (int i = 0; i < oldOrder.size(); ++i) {
+      // If it is a rank-reducing load, we need to drop the last dimensions.
+      if (oldOrder[i] >= rank)
+        continue;
+      order.push_back(oldOrder[i]);
+    }
     auto newCtaEnc = updateCTALayoutForShape(ctaLayout, tensorType.getShape());
     return gpu::SwizzledSharedEncodingAttr::get(
         ctx, swizEnc.getVec(), swizEnc.getPerPhase(), swizEnc.getMaxPhase(),
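
Concretely, for the rank-reducing load covered by the new test, the shared encoding's order [1, 2, 0] is remapped onto a rank-2 destination tensor: index 2 is out of range and is dropped, leaving [1, 0]. In the rank-expanding direction (for illustration, an old order of [1, 0] onto a rank-3 tensor), the unchanged leading loop still prepends the missing dimension, producing [2, 1, 0]; previously that was the only case the helper handled, and the removed assert rejected the reducing case.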

test/TritonNvidiaGPU/tma_lowering.mlir

Lines changed: 16 additions & 0 deletions
@@ -90,3 +90,19 @@ tt.func @tma_scatter(%arg0: !tt.tensordesc<tensor<1x128xbf16, #nvmma_128>>, %arg
   }
 
 }
+
+// -----
+
+#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [8, 1], order = [1, 0]}>
+#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 2, 0]}>
+// CHECK: #[[$SHARED:.+]] = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} {
+// CHECK-LABEL: @rank_reducing_load
+  tt.func public @rank_reducing_load(%arg0: !tt.tensordesc<tensor<1x256x32xf32, #shared>>) -> tensor<256x32xf32, #blocked> {
+    %c32_i32 = arith.constant 32 : i32
+    // CHECK: %[[A:.+]] = ttg.local_alloc : () -> !ttg.memdesc<256x32xf32, #[[$SHARED]], #smem, mutable>
+    // CHECK: ttng.async_tma_copy_global_to_local %{{.+}}[%{{.+}}, %{{.+}}, %{{.+}}] %[[A]],
+    %l = tt.descriptor_load %arg0[%c32_i32, %c32_i32, %c32_i32] : !tt.tensordesc<tensor<1x256x32xf32, #shared>> -> tensor<256x32xf32, #blocked>
+    tt.return %l : tensor<256x32xf32, #blocked>
+  }
+}
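
The added test exercises the rank-reducing path end to end: a TMA descriptor over tensor<1x256x32xf32> with shared order [1, 2, 0] is loaded into a rank-2 tensor<256x32xf32, #blocked>, and the CHECK lines verify that the lowering allocates shared memory with the reduced encoding (order = [1, 0]) and issues the async TMA copy into it.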
