Commit 01fb036

[Pipeliner] Handle masking for atomic_rmw (#5231)
This commit adds support for atomic_rmw in predicateOp, the function that masks operations during pipeline scheduling.
Parent: 8d42d21
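
For context, predicateOp is the hook the software pipeliner uses to guard side-effecting ops on iterations that are peeled or predicated during scheduling. Below is a minimal sketch of how a scheduler might drive it; the `predicateStageOps` driver and the `stagePred` value are hypothetical, and the header path is an assumption. Only `mlir::triton::predicateOp` and its signature come from the patched PipeliningUtility.cpp shown further down.

```cpp
#include "llvm/ADT/ArrayRef.h"
#include "mlir/IR/PatternMatch.h"
// Header path assumed; predicateOp itself is the real helper patched below.
#include "triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h"

// Hypothetical driver, for illustration only.
static void predicateStageOps(mlir::RewriterBase &rewriter,
                              llvm::ArrayRef<mlir::Operation *> stageOps,
                              mlir::Value stagePred) {
  for (mlir::Operation *op : stageOps) {
    // Loads, stores, and (with this commit) atomic_rmw get the stage
    // predicate folded into their mask operand, so they become no-ops on
    // iterations where the stage must not execute.
    mlir::triton::predicateOp(rewriter, op, stagePred);
  }
}
```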

2 files changed (+75, −0 lines)

lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp

Lines changed: 7 additions & 0 deletions
```diff
@@ -80,6 +80,13 @@ Operation *mlir::triton::predicateOp(RewriterBase &rewriter, Operation *op,
     storeOp.getMaskMutable().assign(mask);
     return op;
   }
+  if (auto atomicRMWOp = dyn_cast<tt::AtomicRMWOp>(op)) {
+    rewriter.setInsertionPoint(atomicRMWOp);
+    Value mask = getPredMask(rewriter, atomicRMWOp.getPtr().getType(),
+                             atomicRMWOp.getMask(), pred);
+    atomicRMWOp.getMaskMutable().assign(mask);
+    return op;
+  }
 
   assert("don't know how to predicate this op" && false);
   return op;
```
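
The getPredMask helper called above is defined earlier in the same file and is not part of this diff. The following is a sketch of its plausible shape, reconstructed from the call site, not copied from the real definition; it assumes the file's usual `using namespace mlir;` and `namespace tt = mlir::triton;` aliases, and `tt::getI1SameShape` is an assumed helper name.

```cpp
// Sketch only, reconstructed from the call sites above: broadcast the
// scalar i1 predicate to the op's mask shape, then AND it with any
// pre-existing mask so user-level masking is preserved.
static Value getPredMask(RewriterBase &rewriter, Type typeLike,
                         Value currentMask, Value pred) {
  Location loc = pred.getLoc();
  Value mask = pred;
  // For tensor ops, splat the scalar predicate to a tensor of i1 with the
  // same shape as `typeLike` (tt::getI1SameShape is an assumed helper).
  Type maskType = tt::getI1SameShape(typeLike);
  if (isa<RankedTensorType>(maskType))
    mask = rewriter.create<tt::SplatOp>(loc, maskType, pred);
  // Combine with the op's existing mask operand, if it has one.
  if (currentMask)
    mask = rewriter.create<arith::AndIOp>(loc, mask, currentMask);
  return mask;
}
```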

test/TritonGPU/loop-pipeline-hip.mlir

Lines changed: 68 additions & 0 deletions
```diff
@@ -266,3 +266,71 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
     tt.return
   }
 }
+
+// -----
+
+// Check that the stream pipeliner updates atomic op in the k-loop correctly
+// CHECK-LABEL: _triton_gemm_kernel_atomic_rmw
+// CHECK: scf.for
+// CHECK: tt.atomic_rmw fadd, acq_rel, gpu
+// CHECK: tt.dot
+// CHECK: scf.yield
+
+#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
+#mma = #triton_gpu.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [32, 32], isTransposed = true}>
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "hip:gfx942", "triton_gpu.threads-per-warp" = 64 : i32} {
+  tt.func public @_triton_gemm_kernel_atomic_rmw(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32} loc(unknown), %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32} loc(unknown), %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32} loc(unknown), %arg3: i32 {tt.divisibility = 16 : i32} loc(unknown), %arg4: i32 {tt.divisibility = 16 : i32} loc(unknown)) attributes {noinline = false} {
+    %cst = arith.constant dense<32> : tensor<32x32xi32, #blocked>
+    %c0_i32 = arith.constant 0 : i32
+    %c1_i32 = arith.constant 1 : i32
+    %c31_i32 = arith.constant 31 : i32
+    %c32_i32 = arith.constant 32 : i32
+    %cst_0 = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma>
+    %0 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
+    %1 = tt.expand_dims %0 {axis = 1 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<32x1xi32, #blocked>
+    %2 = tt.splat %arg4 : i32 -> tensor<32x1xi32, #blocked>
+    %3 = arith.muli %1, %2 : tensor<32x1xi32, #blocked>
+    %4 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>
+    %5 = tt.expand_dims %4 {axis = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x32xi32, #blocked>
+    %6 = tt.broadcast %3 : tensor<32x1xi32, #blocked> -> tensor<32x32xi32, #blocked>
+    %7 = tt.broadcast %5 : tensor<1x32xi32, #blocked> -> tensor<32x32xi32, #blocked>
+    %8 = arith.addi %6, %7 : tensor<32x32xi32, #blocked>
+    %9 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<32x32x!tt.ptr<f16>, #blocked>
+    %10 = tt.addptr %9, %8 : tensor<32x32x!tt.ptr<f16>, #blocked>, tensor<32x32xi32, #blocked>
+    %11 = tt.splat %arg1 : !tt.ptr<f16> -> tensor<32x32x!tt.ptr<f16>, #blocked>
+    %12 = tt.addptr %11, %8 : tensor<32x32x!tt.ptr<f16>, #blocked>, tensor<32x32xi32, #blocked>
+    %13 = tt.splat %arg2 : !tt.ptr<f16> -> tensor<32x1x!tt.ptr<f16>, #blocked>
+    %14 = tt.addptr %13, %3 : tensor<32x1x!tt.ptr<f16>, #blocked>, tensor<32x1xi32, #blocked>
+    %15 = tt.broadcast %14 : tensor<32x1x!tt.ptr<f16>, #blocked> -> tensor<32x32x!tt.ptr<f16>, #blocked>
+    %16 = tt.addptr %15, %7 : tensor<32x32x!tt.ptr<f16>, #blocked>, tensor<32x32xi32, #blocked>
+    %17 = tt.splat %arg3 : i32 -> tensor<32x1xi32, #blocked>
+    %18 = arith.cmpi slt, %1, %17 : tensor<32x1xi32, #blocked>
+    %19 = tt.splat %arg3 : i32 -> tensor<1x32xi32, #blocked>
+    %20 = arith.cmpi slt, %5, %19 : tensor<1x32xi32, #blocked>
+    %21 = tt.broadcast %18 : tensor<32x1xi1, #blocked> -> tensor<32x32xi1, #blocked>
+    %22 = tt.broadcast %20 : tensor<1x32xi1, #blocked> -> tensor<32x32xi1, #blocked>
+    %23 = arith.andi %21, %22 : tensor<32x32xi1, #blocked>
+    %24 = arith.addi %arg3, %c31_i32 : i32
+    %25 = arith.divsi %24, %c32_i32 : i32
+    %26 = arith.muli %arg4, %c32_i32 : i32
+    %27 = tt.splat %26 : i32 -> tensor<32x32xi32, #blocked>
+    %28:3 = scf.for %arg5 = %c0_i32 to %25 step %c1_i32 iter_args(%arg6 = %cst_0, %arg7 = %10, %arg8 = %12) -> (tensor<32x32xf32, #mma>, tensor<32x32x!tt.ptr<f16>, #blocked>, tensor<32x32x!tt.ptr<f16>, #blocked>) : i32 {
+      %32 = tt.load %arg7 : tensor<32x32x!tt.ptr<f16>, #blocked>
+      %33 = tt.load %arg8 : tensor<32x32x!tt.ptr<f16>, #blocked>
+      %34 = triton_gpu.convert_layout %32 : tensor<32x32xf16, #blocked> -> tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>
+      %35 = triton_gpu.convert_layout %33 : tensor<32x32xf16, #blocked> -> tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>>
+      %36 = tt.dot %34, %35, %arg6 : tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<32x32xf32, #mma>
+      %37 = tt.addptr %arg7, %cst : tensor<32x32x!tt.ptr<f16>, #blocked>, tensor<32x32xi32, #blocked>
+      %38 = tt.addptr %arg8, %27 : tensor<32x32x!tt.ptr<f16>, #blocked>, tensor<32x32xi32, #blocked>
+      %39 = arith.truncf %36 : tensor<32x32xf32, #mma> to tensor<32x32xf16, #mma>
+      %40 = triton_gpu.convert_layout %39 : tensor<32x32xf16, #mma> -> tensor<32x32xf16, #blocked>
+      %41 = tt.atomic_rmw fadd, acq_rel, gpu, %16, %40, %23 : (tensor<32x32x!tt.ptr<f16>, #blocked>, tensor<32x32xf16, #blocked>, tensor<32x32xi1, #blocked>) -> tensor<32x32xf16, #blocked>
+      scf.yield %36, %37, %38 : tensor<32x32xf32, #mma>, tensor<32x32x!tt.ptr<f16>, #blocked>, tensor<32x32x!tt.ptr<f16>, #blocked>
+    }
+    %29 = arith.truncf %28#0 : tensor<32x32xf32, #mma> to tensor<32x32xf16, #mma>
+    %30 = triton_gpu.convert_layout %16 : tensor<32x32x!tt.ptr<f16>, #blocked> -> tensor<32x32x!tt.ptr<f16>, #mma>
+    %31 = triton_gpu.convert_layout %23 : tensor<32x32xi1, #blocked> -> tensor<32x32xi1, #mma>
+    tt.store %30, %29, %31 : tensor<32x32x!tt.ptr<f16>, #mma>
+    tt.return
+  }
+}
```
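
To make the CHECK expectations concrete: once the k-loop is pipelined, side-effecting ops from a later stage run shifted relative to the compute stage, so the pipeline-fill and drain iterations must mask them out. A scalar analogue of that fill/drain masking, purely illustrative and not Triton API, is sketched below.

```cpp
// Illustration only: scalar analogue of why a pipeliner must mask a
// side-effecting op (here, the stand-in for atomic_rmw) when its stage
// is shifted by one iteration.
#include <cstdio>

int main() {
  const int n = 5;
  float acc[n] = {};
  float pendingVal = 0.0f;   // value produced by the compute stage
  bool pendingValid = false; // the "mask": false until the pipeline fills
  for (int i = 0; i < n; ++i) {
    if (pendingValid)
      acc[i - 1] += pendingVal; // masked update of the previous iteration
    pendingVal = float(i) * 2.0f; // this iteration's computed value
    pendingValid = true;
  }
  if (pendingValid)      // epilogue: drain the last pending update
    acc[n - 1] += pendingVal;
  for (float v : acc)
    std::printf("%g\n", v); // prints 0 2 4 6 8
}
```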
