Commit e163113

Fix division by zero in ReduceDataDuplicationPass (#6849)
For a small contiguous shape dimension, the integer divisor in the perPhase computation could truncate to zero, causing a division by zero when converting to an MMA layout.
1 parent a2ce747 commit e163113
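
To make the failure concrete, here is a minimal standalone C++ sketch of the perPhase arithmetic from the patched code. The helper name computePerPhase and its scalar parameters are illustrative stand-ins; in the source the inputs are shapePerCTA[order[0]] and dotOpEnc.getKWidth(). With a contiguous dimension of 1 and kWidth = 8, the divisor 1 * 4 / 8 truncates to 0 under integer division, so the unpatched expression divides 128 by zero; clamping the divisor to at least 1 instead yields perPhase = 128, matching the swizzled_shared attribute checked by the new test.

#include <algorithm>
#include <cstdio>

// Illustrative sketch of the Ampere/Hopper perPhase computation; the
// helper name and scalar parameters are hypothetical stand-ins for
// shapePerCTA[order[0]] and dotOpEnc.getKWidth().
int computePerPhase(int contigDimSize, int kWidth) {
  // Integer division truncates: contigDimSize = 1, kWidth = 8 gives
  // 1 * 4 / 8 == 0, so the divisor must be clamped (this commit's fix).
  int divisor = std::max(1, contigDimSize * 4 / kWidth);
  int perPhase = 128 / divisor;
  // Pre-existing guard: clamps the quotient, not the divisor, so it
  // could not prevent the division by zero on its own.
  return std::max(perPhase, 1);
}

int main() {
  // The shape from the regression test: tensor<32x1xf16> with kWidth = 8.
  std::printf("perPhase = %d\n", computePerPhase(1, 8)); // prints 128
}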

2 files changed: 16 additions, 1 deletion


include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td

Lines changed: 1 addition & 1 deletion
@@ -322,7 +322,7 @@ When vec=2, elements are swizzled in pairs of 2. In other words, the element at

   // ---- begin Ampere & Hopper ----
   if (mmaEnc.isAmpere() || mmaEnc.isHopper()) {
-    int perPhase = 128 / (shapePerCTA[order[0]] * 4 / dotOpEnc.getKWidth());
+    int perPhase = 128 / (std::max<int>(1, shapePerCTA[order[0]] * 4 / dotOpEnc.getKWidth()));
     perPhase = std::max<int>(perPhase, 1);
     std::vector<size_t> matShape = {8, 8, 4 * dotOpEnc.getKWidth()};
     int vecWidth = 32 / typeWidthInBit;
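
Note that the std::max on the line after the division already clamped perPhase itself, but it runs only after the divide, so it could not catch a zero divisor; the fix adds a second clamp inside the divisor expression.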

test/TritonGPU/reduce-data-duplication.mlir

Lines changed: 15 additions & 0 deletions
@@ -29,6 +29,21 @@ module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 1 : i32, "ttg.num-

 // -----

+// CHECK: #[[$SHARED:.*]] = #ttg.swizzled_shared<{vec = 32, perPhase = 128, maxPhase = 1, order = [1, 0]}>
+// CHECK-LABEL: handles_small_contiguous_dim
+// CHECK: %{{.*}} = ttg.local_alloc %{{.*}} : (tensor<32x1xf16, #{{.*}}>) -> !ttg.memdesc<32x1xf16, #[[$SHARED]], #smem>
+
+#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0]}>
+#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [16, 8]}>
+module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} {
+  tt.func @handles_small_contiguous_dim(%arg0: tensor<32x1xf16, #blocked>) {
+    %0 = ttg.convert_layout %arg0 : tensor<32x1xf16, #blocked> -> tensor<32x1xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>>
+    tt.return
+  }
+}
+
+// -----
+
 // CHECK-LABEL: conversion_shortcut_blocked_dotop_warp64
 // CHECK-NOT: ttg.local_alloc
 // CHECK: ttg.convert_layout
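
Assuming the standard Triton lit setup for this test file (the RUN line is not shown in the diff), the new case would be exercised by running triton-opt with -split-input-file and the reduce-data-duplication pass enabled, piping the output into FileCheck; the exact pass flag is an assumption based on the pass name.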
