@@ -266,3 +266,71 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
     tt.return
   }
 }
+
+// -----
+
+// Check that the stream pipeliner updates the atomic op in the k-loop correctly.
+// CHECK-LABEL: _triton_gemm_kernel_atomic_rmw
+// CHECK: scf.for
+// CHECK: tt.atomic_rmw fadd, acq_rel, gpu
+// CHECK: tt.dot
+// CHECK: scf.yield
+
+#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
+#mma = #triton_gpu.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [32, 32], isTransposed = true}>
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "hip:gfx942", "triton_gpu.threads-per-warp" = 64 : i32} {
+  tt.func public @_triton_gemm_kernel_atomic_rmw(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32} loc(unknown), %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32} loc(unknown), %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32} loc(unknown), %arg3: i32 {tt.divisibility = 16 : i32} loc(unknown), %arg4: i32 {tt.divisibility = 16 : i32} loc(unknown)) attributes {noinline = false} {
+    %cst = arith.constant dense<32> : tensor<32x32xi32, #blocked>
+    %c0_i32 = arith.constant 0 : i32
+    %c1_i32 = arith.constant 1 : i32
+    %c31_i32 = arith.constant 31 : i32
+    %c32_i32 = arith.constant 32 : i32
+    %cst_0 = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma>
+    %0 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
+    %1 = tt.expand_dims %0 {axis = 1 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<32x1xi32, #blocked>
+    %2 = tt.splat %arg4 : i32 -> tensor<32x1xi32, #blocked>
+    %3 = arith.muli %1, %2 : tensor<32x1xi32, #blocked>
+    %4 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>
+    %5 = tt.expand_dims %4 {axis = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x32xi32, #blocked>
+    %6 = tt.broadcast %3 : tensor<32x1xi32, #blocked> -> tensor<32x32xi32, #blocked>
+    %7 = tt.broadcast %5 : tensor<1x32xi32, #blocked> -> tensor<32x32xi32, #blocked>
+    %8 = arith.addi %6, %7 : tensor<32x32xi32, #blocked>
+    %9 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<32x32x!tt.ptr<f16>, #blocked>
+    %10 = tt.addptr %9, %8 : tensor<32x32x!tt.ptr<f16>, #blocked>, tensor<32x32xi32, #blocked>
+    %11 = tt.splat %arg1 : !tt.ptr<f16> -> tensor<32x32x!tt.ptr<f16>, #blocked>
+    %12 = tt.addptr %11, %8 : tensor<32x32x!tt.ptr<f16>, #blocked>, tensor<32x32xi32, #blocked>
+    %13 = tt.splat %arg2 : !tt.ptr<f16> -> tensor<32x1x!tt.ptr<f16>, #blocked>
+    %14 = tt.addptr %13, %3 : tensor<32x1x!tt.ptr<f16>, #blocked>, tensor<32x1xi32, #blocked>
+    %15 = tt.broadcast %14 : tensor<32x1x!tt.ptr<f16>, #blocked> -> tensor<32x32x!tt.ptr<f16>, #blocked>
+    %16 = tt.addptr %15, %7 : tensor<32x32x!tt.ptr<f16>, #blocked>, tensor<32x32xi32, #blocked>
+    %17 = tt.splat %arg3 : i32 -> tensor<32x1xi32, #blocked>
+    %18 = arith.cmpi slt, %1, %17 : tensor<32x1xi32, #blocked>
+    %19 = tt.splat %arg3 : i32 -> tensor<1x32xi32, #blocked>
+    %20 = arith.cmpi slt, %5, %19 : tensor<1x32xi32, #blocked>
+    %21 = tt.broadcast %18 : tensor<32x1xi1, #blocked> -> tensor<32x32xi1, #blocked>
+    %22 = tt.broadcast %20 : tensor<1x32xi1, #blocked> -> tensor<32x32xi1, #blocked>
+    %23 = arith.andi %21, %22 : tensor<32x32xi1, #blocked>
+    %24 = arith.addi %arg3, %c31_i32 : i32
+    %25 = arith.divsi %24, %c32_i32 : i32
+    %26 = arith.muli %arg4, %c32_i32 : i32
+    %27 = tt.splat %26 : i32 -> tensor<32x32xi32, #blocked>
+    %28:3 = scf.for %arg5 = %c0_i32 to %25 step %c1_i32 iter_args(%arg6 = %cst_0, %arg7 = %10, %arg8 = %12) -> (tensor<32x32xf32, #mma>, tensor<32x32x!tt.ptr<f16>, #blocked>, tensor<32x32x!tt.ptr<f16>, #blocked>) : i32 {
+      %32 = tt.load %arg7 : tensor<32x32x!tt.ptr<f16>, #blocked>
+      %33 = tt.load %arg8 : tensor<32x32x!tt.ptr<f16>, #blocked>
+      %34 = triton_gpu.convert_layout %32 : tensor<32x32xf16, #blocked> -> tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>
+      %35 = triton_gpu.convert_layout %33 : tensor<32x32xf16, #blocked> -> tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>>
+      %36 = tt.dot %34, %35, %arg6 : tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<32x32xf32, #mma>
+      %37 = tt.addptr %arg7, %cst : tensor<32x32x!tt.ptr<f16>, #blocked>, tensor<32x32xi32, #blocked>
+      %38 = tt.addptr %arg8, %27 : tensor<32x32x!tt.ptr<f16>, #blocked>, tensor<32x32xi32, #blocked>
+      %39 = arith.truncf %36 : tensor<32x32xf32, #mma> to tensor<32x32xf16, #mma>
+      %40 = triton_gpu.convert_layout %39 : tensor<32x32xf16, #mma> -> tensor<32x32xf16, #blocked>
+      %41 = tt.atomic_rmw fadd, acq_rel, gpu, %16, %40, %23 : (tensor<32x32x!tt.ptr<f16>, #blocked>, tensor<32x32xf16, #blocked>, tensor<32x32xi1, #blocked>) -> tensor<32x32xf16, #blocked>
+      scf.yield %36, %37, %38 : tensor<32x32xf32, #mma>, tensor<32x32x!tt.ptr<f16>, #blocked>, tensor<32x32x!tt.ptr<f16>, #blocked>
+    }
+    %29 = arith.truncf %28#0 : tensor<32x32xf32, #mma> to tensor<32x32xf16, #mma>
+    %30 = triton_gpu.convert_layout %16 : tensor<32x32x!tt.ptr<f16>, #blocked> -> tensor<32x32x!tt.ptr<f16>, #mma>
+    %31 = triton_gpu.convert_layout %23 : tensor<32x32xi1, #blocked> -> tensor<32x32xi1, #mma>
+    tt.store %30, %29, %31 : tensor<32x32x!tt.ptr<f16>, #mma>
+    tt.return
+  }
+}
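For context, the RUN line that drives this test sits at the top of the file and is not part of this hunk. As a minimal sketch, assuming the AMD stream pipeliner is exposed as the tritonamdgpu-stream-pipeline pass (an assumption, not shown in the diff), a lit/FileCheck test like this is typically invoked as:

  // RUN: triton-opt %s -split-input-file -tritonamdgpu-stream-pipeline | FileCheck %s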