Skip to content

Commit 6dd7d6a

Browse files
authored
[BACKEND] Make sure to propagate the new type when updating memdesc (#7134)
Since the new memdesc will be mutable, we need to propagate the type.
1 parent 4e0281e commit 6dd7d6a

File tree

2 files changed

+20
-18
lines changed

2 files changed

+20
-18
lines changed

lib/Dialect/TritonGPU/Transforms/Pipeliner/LowerLoops.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -829,7 +829,7 @@ Operation *hoistBufferOutOfLoop(scf::ForOp forOp, Operation *op,
829829
newStore = builder.create<ttg::LocalStoreOp>(op->getOperand(0),
830830
localAlloc.getResult());
831831
}
832-
op->replaceAllUsesWith(newAlloc);
832+
replaceUsesAndPropagateType(builder, op, newAlloc->getResult(0));
833833
op->erase();
834834
return newStore;
835835
}

test/TritonGPU/loop-pipeline-blackwell.mlir

Lines changed: 19 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,6 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
9595
}
9696

9797
// -----
98-
9998
// 4 warps
10099
// matmul: 128x32 @ 32x128 -> 128x128
101100
#AL = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
@@ -107,6 +106,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
107106
#A = #ttg.dot_op<{opIdx = 0, parent = #C, kWidth=2}>
108107
#B = #ttg.dot_op<{opIdx = 1, parent = #C, kWidth=2}>
109108
#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}>
109+
#shared1 = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = true, elementBitWidth = 16}>
110110
#smem = #ttg.shared_memory
111111
#tmem = #ttng.tensor_memory_encoding<blockM = 128, blockN = 128, unpacked = true>
112112
module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} {
@@ -119,6 +119,7 @@ module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32, ttg.targ
119119
// CHECK: tt.fp_to_fp
120120
// CHECK: ttng.wait_barrier
121121
// CHECK: ttg.local_store
122+
// CHECK: ttg.memdesc_trans
122123
// CHECK: ttng.tc_gen5_mma {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}
123124
// CHECK: ttg.async_copy_global_to_local
124125
%a_ptr_splat = tt.splat %A : !tt.ptr<f8E4M3FN> -> tensor<128x32x!tt.ptr<f8E4M3FN>, #AL>
@@ -127,37 +128,38 @@ module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32, ttg.targ
127128
%a_offs = tt.broadcast %a_tmp1 : tensor<1x32xi32, #AL> -> tensor<128x32xi32, #AL>
128129
%a_ptr_init = tt.addptr %a_ptr_splat, %a_offs : tensor<128x32x!tt.ptr<f8E4M3FN>, #AL>, tensor<128x32xi32, #AL>
129130

130-
%b_ptr_splat = tt.splat %B : !tt.ptr<f8E4M3FN> -> tensor<32x128x!tt.ptr<f8E4M3FN>, #BL>
131-
%b_tmp0 = tt.make_range {end = 128: i32, start = 0: i32} : tensor<128xi32, #BLs0>
132-
%b_tmp1 = tt.expand_dims %b_tmp0 {axis = 0 : i32} : tensor<128xi32, #BLs0> -> tensor<1x128xi32, #BL>
133-
%b_offs = tt.broadcast %b_tmp1 : tensor<1x128xi32, #BL> -> tensor<32x128xi32, #BL>
134-
%b_ptr_init = tt.addptr %b_ptr_splat, %b_offs : tensor<32x128x!tt.ptr<f8E4M3FN>, #BL>, tensor<32x128xi32, #BL>
131+
%b_ptr_splat = tt.splat %B : !tt.ptr<f8E4M3FN> -> tensor<128x32x!tt.ptr<f8E4M3FN>, #BL>
132+
%b_tmp0 = tt.make_range {end = 32: i32, start = 0: i32} : tensor<32xi32, #BLs0>
133+
%b_tmp1 = tt.expand_dims %b_tmp0 {axis = 0 : i32} : tensor<32xi32, #BLs0> -> tensor<1x32xi32, #BL>
134+
%b_offs = tt.broadcast %b_tmp1 : tensor<1x32xi32, #BL> -> tensor<128x32xi32, #BL>
135+
%b_ptr_init = tt.addptr %b_ptr_splat, %b_offs : tensor<128x32x!tt.ptr<f8E4M3FN>, #BL>, tensor<128x32xi32, #BL>
135136

136137
%true = arith.constant true
137-
%b_mask = arith.constant dense<true> : tensor<32x128xi1, #BL>
138-
%b_other = arith.constant dense<0.00e+00> : tensor<32x128xf8E4M3FN, #BL>
138+
%b_mask = arith.constant dense<true> : tensor<128x32xi1, #BL>
139+
%b_other = arith.constant dense<0.00e+00> : tensor<128x32xf8E4M3FN, #BL>
139140
%c_init = arith.constant dense<0.00e+00> : tensor<128x128xf32, #C>
140141

141142
%a_off = arith.constant dense<4> : tensor<128x32xi32, #AL>
142-
%b_off = arith.constant dense<4> : tensor<32x128xi32, #BL>
143+
%b_off = arith.constant dense<4> : tensor<128x32xi32, #BL>
143144

144-
%loop:3 = scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr<f8E4M3FN>, #AL>, tensor<32x128x!tt.ptr<f8E4M3FN>, #BL>, tensor<128x128xf32, #C>) {
145+
%loop:3 = scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr<f8E4M3FN>, #AL>, tensor<128x32x!tt.ptr<f8E4M3FN>, #BL>, tensor<128x128xf32, #C>) {
145146
%a___ = tt.load %a_ptr : tensor<128x32x!tt.ptr<f8E4M3FN>, #AL>
146147
%a__ = tt.fp_to_fp %a___ : tensor<128x32xf8E4M3FN, #AL> -> tensor<128x32xf16, #AL>
147148
%a_ = ttg.convert_layout %a__ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A>
148-
%b___ = tt.load %b_ptr, %b_mask, %b_other : tensor<32x128x!tt.ptr<f8E4M3FN>, #BL>
149-
%b__ = tt.fp_to_fp %b___ : tensor<32x128xf8E4M3FN, #BL> -> tensor<32x128xf16, #BL>
150-
%b_ = ttg.convert_layout %b__ : tensor<32x128xf16, #BL> -> tensor<32x128xf16, #B>
149+
%b___ = tt.load %b_ptr, %b_mask, %b_other : tensor<128x32x!tt.ptr<f8E4M3FN>, #BL>
150+
%b__ = tt.fp_to_fp %b___ : tensor<128x32xf8E4M3FN, #BL> -> tensor<128x32xf16, #BL>
151+
%b_ = ttg.convert_layout %b__ : tensor<128x32xf16, #BL> -> tensor<128x32xf16, #B>
151152

152153
%a = ttg.local_alloc %a_ {loop.cluster = 0 : i32, loop.stage = 2 : i32} : (tensor<128x32xf16, #A>) -> !ttg.memdesc<128x32xf16, #shared, #smem>
153-
%b = ttg.local_alloc %b_ {loop.cluster = 0 : i32, loop.stage = 2 : i32} : (tensor<32x128xf16, #B>) -> !ttg.memdesc<32x128xf16, #shared, #smem>
154+
%b = ttg.local_alloc %b_ {loop.cluster = 0 : i32, loop.stage = 2 : i32} : (tensor<128x32xf16, #B>) -> !ttg.memdesc<128x32xf16, #shared, #smem>
155+
%bt = ttg.memdesc_trans %b {loop.cluster = 1 : i32, loop.stage = 2 : i32, order = array<i32: 1, 0>} : !ttg.memdesc<128x32xf16, #shared, #smem> -> !ttg.memdesc<32x128xf16, #shared1, #smem>
154156
%acc_tm, %acc_tok = ttng.tmem_alloc %prev_c : (tensor<128x128xf32, #C>) -> (!ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>, !ttg.async.token)
155-
%mma_tok = ttng.tc_gen5_mma %a, %b, %acc_tm[%acc_tok], %true, %true : !ttg.memdesc<128x32xf16, #shared, #smem>, !ttg.memdesc<32x128xf16, #shared, #smem>, !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>
157+
%mma_tok = ttng.tc_gen5_mma %a, %bt, %acc_tm[%acc_tok], %true, %true : !ttg.memdesc<128x32xf16, #shared, #smem>, !ttg.memdesc<32x128xf16, #shared1, #smem>, !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>
156158
%c, %load_tok = ttng.tmem_load %acc_tm[%mma_tok] : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x128xf32, #C>
157159

158160
%next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr<f8E4M3FN>, #AL>, tensor<128x32xi32, #AL>
159-
%next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr<f8E4M3FN>, #BL>, tensor<32x128xi32, #BL>
160-
scf.yield %next_a_ptr, %next_b_ptr, %c : tensor<128x32x!tt.ptr<f8E4M3FN>, #AL>, tensor<32x128x!tt.ptr<f8E4M3FN>, #BL>, tensor<128x128xf32, #C>
161+
%next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<128x32x!tt.ptr<f8E4M3FN>, #BL>, tensor<128x32xi32, #BL>
162+
scf.yield %next_a_ptr, %next_b_ptr, %c : tensor<128x32x!tt.ptr<f8E4M3FN>, #AL>, tensor<128x32x!tt.ptr<f8E4M3FN>, #BL>, tensor<128x128xf32, #C>
161163
}
162164
tt.return %loop#2: tensor<128x128xf32, #C>
163165
}

0 commit comments

Comments (0)