1- // RUN: triton-opt %s -split-input-file -tritonamdgpu-stream-pipeline=num_stages=2 -canonicalize | FileCheck %s
1+ // RUN: triton-opt %s -split-input-file -tritonamdgpu-stream-pipeline=num_stages=2 -canonicalize | FileCheck %s --check-prefixes=COMMON,SYNC
2+ // RUN: triton-opt %s -split-input-file -tritonamdgpu-stream-pipeline="num_stages=2 use_async_copy=1" -canonicalize | FileCheck %s --check-prefixes=COMMON,ASYNC
23
34#blocked = #ttg.blocked <{sizePerThread = [8 , 1 ], threadsPerWarp = [8 , 8 ], warpsPerCTA = [1 , 4 ], order = [0 , 1 ]}>
45#blocked1 = #ttg.blocked <{sizePerThread = [1 , 8 ], threadsPerWarp = [8 , 8 ], warpsPerCTA = [4 , 1 ], order = [1 , 0 ]}>
78#shared1 = #ttg.swizzled_shared <{vec = 8 , perPhase = 1 , maxPhase = 8 , order = [1 , 0 ]}>
89#smem = #ttg.shared_memory
910module attributes {" ttg.target" = " hip:gfx942" , " ttg.num-ctas" = 1 : i32 , " ttg.num-warps" = 4 : i32 , " ttg.threads-per-warp" = 64 : i32 } {
10- // CHECK -LABEL: tt.func @load_two_users
11+ // COMMON-LABEL: tt.func @load_two_users
1112 tt.func @load_two_users (%arg0: !tt.ptr <f16 > {tt.divisibility = 16 : i32 }, %arg1: !tt.ptr <f16 > {tt.divisibility = 16 : i32 }) -> (tensor <128 x16 xf32 , #mma >, tensor <128 x64 xf32 , #mma >) {
1213 %cst = arith.constant dense <0 > : tensor <1 x16 xi32 , #blocked >
1314 %cst_0 = arith.constant dense <0 > : tensor <128 x1 xi32 , #blocked1 >
@@ -34,13 +35,13 @@ module attributes {"ttg.target" = "hip:gfx942", "ttg.num-ctas" = 1 : i32, "ttg.n
3435 %14 = tt.broadcast %11 : tensor <1 x16 x!tt.ptr <f16 >, #blocked > -> tensor <64 x16 x!tt.ptr <f16 >, #blocked >
3536 %15 = tt.broadcast %13 : tensor <64 x1 xi32 , #blocked > -> tensor <64 x16 xi32 , #blocked >
3637 %16 = tt.addptr %14 , %15 : tensor <64 x16 x!tt.ptr <f16 >, #blocked >, tensor <64 x16 xi32 , #blocked >
37- // CHECK : ttg.local_store
38- // CHECK : scf.for
39- // CHECK : tt.load
40- // CHECK : tt.dot
41- // CHECK : tt.dot
42- // CHECK : ttg.local_store
43- // CHECK : scf.yield
38+ // COMMON: ttg.local_store
39+ // COMMON: scf.for
40+ // COMMON: tt.load
41+ // COMMON: tt.dot
42+ // COMMON: tt.dot
43+ // COMMON: ttg.local_store
44+ // COMMON: scf.yield
4445 %17:2 = scf.for %arg2 = %c0_i32 to %c8_i32 step %c1_i32 iter_args (%arg3 = %cst_1 , %arg4 = %cst_2 ) -> (tensor <128 x16 xf32 , #mma >, tensor <128 x64 xf32 , #mma >) : i32 {
4546 %18 = tt.load %16 : tensor <64 x16 x!tt.ptr <f16 >, #blocked >
4647 %19 = ttg.convert_layout %9 : tensor <128 x64 xf16 , #blocked1 > -> tensor <128 x64 xf16 , #ttg.dot_op <{opIdx = 0 , parent = #mma , kWidth = 2 }>>
@@ -60,8 +61,8 @@ module attributes {"ttg.target" = "hip:gfx942", "ttg.num-ctas" = 1 : i32, "ttg.n
6061
6162// -----
6263
63- // CHECK -LABEL: tt.func public @_jagged_hstu_attn_fwd_0d1d2d3d4d5de
64- // CHECK -NOT: ttg.convert_layout {{.*}} : tensor<32x64xf32, #shared> -> tensor<32x64xf32, #shared1>
64+ // COMMON-LABEL: tt.func public @_jagged_hstu_attn_fwd_0d1d2d3d4d5de
65+ // COMMON-NOT: ttg.convert_layout {{.*}} : tensor<32x64xf32, #shared> -> tensor<32x64xf32, #shared1>
6566
6667#blocked = #ttg.blocked <{sizePerThread = [1 , 1 ], threadsPerWarp = [64 , 1 ], warpsPerCTA = [2 , 2 ], order = [0 , 1 ]}>
6768#blocked1 = #ttg.blocked <{sizePerThread = [1 , 1 ], threadsPerWarp = [64 , 1 ], warpsPerCTA = [1 , 4 ], order = [0 , 1 ]}>
@@ -166,13 +167,13 @@ module attributes {"ttg.target" = "hip:gfx942", "ttg.num-ctas" = 1 : i32, "ttg.n
166167
167168// Disable pipelining for loops that contain a barrier.
168169// Barriers are problematic since they are not chained to any other operation.
169- // CHECK -LABEL: tt.func public @add_barrier_kernel
170- // CHECK : scf.for
171- // CHECK : tt.load
172- // CHECK : gpu.barrier
173- // CHECK : tt.store
174- // CHECK -NOT: gpu.barrier
175- // CHECK : tt.return
170+ // COMMON-LABEL: tt.func public @add_barrier_kernel
171+ // COMMON: scf.for
172+ // COMMON: tt.load
173+ // COMMON: gpu.barrier
174+ // COMMON: tt.store
175+ // COMMON-NOT: gpu.barrier
176+ // COMMON: tt.return
176177
177178#blocked = #ttg.blocked <{sizePerThread = [4 ], threadsPerWarp = [32 ], warpsPerCTA = [4 ], order = [0 ]}>
178179module attributes {" ttg.num-ctas" = 1 : i32 , " ttg.num-warps" = 4 : i32 } {
@@ -203,11 +204,11 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
203204
204205// -----
205206
206- // CHECK -NOT: #ttg.swizzled_shared<{{.*}} order = [2, 0, 1]
207- // CHECK : #ttg.swizzled_shared<{{.*}} order = [2, 1, 0]
208- // CHECK -NOT: #ttg.swizzled_shared<{{.*}} order = [2, 0, 1]
207+ // COMMON-NOT: #ttg.swizzled_shared<{{.*}} order = [2, 0, 1]
208+ // COMMON: #ttg.swizzled_shared<{{.*}} order = [2, 1, 0]
209+ // COMMON-NOT: #ttg.swizzled_shared<{{.*}} order = [2, 0, 1]
209210
210- // CHECK -LABEL: tt.func public @slowest_dim_is_batch
211+ // COMMON-LABEL: tt.func public @slowest_dim_is_batch
211212#blocked = #ttg.blocked <{sizePerThread = [1 , 1 , 2 ], threadsPerWarp = [4 , 1 , 16 ], warpsPerCTA = [4 , 1 , 1 ], order = [2 , 1 , 0 ]}>
212213#blocked1 = #ttg.blocked <{sizePerThread = [1 , 1 , 8 ], threadsPerWarp = [16 , 1 , 4 ], warpsPerCTA = [4 , 1 , 1 ], order = [2 , 0 , 1 ]}>
213214#blocked2 = #ttg.blocked <{sizePerThread = [1 , 2 ], threadsPerWarp = [1 , 64 ], warpsPerCTA = [1 , 4 ], order = [1 , 0 ]}>
@@ -239,9 +240,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
239240// -----
240241
241242// Check that the stream pipeliner updates the resulting memory layout of transpose ops to mutable if immutable local buffers are replaced.
242- // CHECK -LABEL: loop_with_dot_and_transpose
243- // CHECK : ttg.local_alloc {{.*}}, mutable>
244- // CHECK : ttg.memdesc_trans {{.*}}, mutable> -> {{.*}}, mutable>
243+ // COMMON-LABEL: loop_with_dot_and_transpose
244+ // COMMON: ttg.local_alloc {{.*}}, mutable>
245+ // COMMON: ttg.memdesc_trans {{.*}}, mutable> -> {{.*}}, mutable>
245246
246247#blocked = #ttg.blocked <{sizePerThread = [2 , 2 ], threadsPerWarp = [2 , 16 ], warpsPerCTA = [4 , 1 ], order = [1 , 0 ]}>
247248#blocked1 = #ttg.blocked <{sizePerThread = [1 , 1 ], threadsPerWarp = [32 , 1 ], warpsPerCTA = [1 , 4 ], order = [0 , 1 ]}>
@@ -270,11 +271,11 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
270271// -----
271272
272273// Check that the stream pipeliner updates the atomic op in the k-loop correctly.
273- // CHECK -LABEL: _triton_gemm_kernel_atomic_rmw
274- // CHECK : scf.for
275- // CHECK : tt.atomic_rmw fadd, acq_rel, gpu
276- // CHECK : tt.dot
277- // CHECK : scf.yield
274+ // COMMON-LABEL: _triton_gemm_kernel_atomic_rmw
275+ // COMMON: scf.for
276+ // COMMON: tt.atomic_rmw fadd, acq_rel, gpu
277+ // COMMON: tt.dot
278+ // COMMON: scf.yield
278279
279280#blocked = #ttg.blocked <{sizePerThread = [1 , 4 ], threadsPerWarp = [8 , 8 ], warpsPerCTA = [4 , 1 ], order = [1 , 0 ]}>
280281#mma = #ttg.amd_mfma <{versionMajor = 3 , versionMinor = 0 , warpsPerCTA = [4 , 1 ], instrShape = [32 , 32 ], isTransposed = true }>
@@ -338,25 +339,25 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
338339// -----
339340
340341// Check that we can pipeline scaled dot with linear layout
341- // CHECK -LABEL: mxfp8_mxfp4_matmul
342+ // COMMON-LABEL: mxfp8_mxfp4_matmul
342343
343344// Prologue
344- // CHECK -COUNT-3: ttg.local_alloc
345- // CHECK -COUNT-3: tt.load
346- // CHECK -COUNT-3: ttg.local_store
345+ // COMMON-COUNT-3: ttg.local_alloc
346+ // COMMON-COUNT-3: tt.load
347+ // COMMON-COUNT-3: ttg.local_store
347348
348349// Main loop
349- // CHECK : scf.for
350- // CHECK -COUNT-3: ttg.local_load
351- // CHECK : tt.dot_scaled
352- // CHECK : scf.yield
350+ // COMMON: scf.for
351+ // COMMON-COUNT-3: ttg.local_load
352+ // COMMON: tt.dot_scaled
353+ // COMMON: scf.yield
353354
354355// Epilogue
355- // CHECK -COUNT-3: ttg.local_load
356- // CHECK : scf.if
357- // CHECK : tt.dot_scaled
358- // CHECK -COUNT-2: scf.yield
359- // CHECK -COUNT-3: ttg.local_dealloc
356+ // COMMON-COUNT-3: ttg.local_load
357+ // COMMON: scf.if
358+ // COMMON: tt.dot_scaled
359+ // COMMON-COUNT-2: scf.yield
360+ // COMMON-COUNT-3: ttg.local_dealloc
360361
361362#blocked = #ttg.blocked <{sizePerThread = [1 , 16 ], threadsPerWarp = [4 , 16 ], warpsPerCTA = [4 , 1 ], order = [1 , 0 ]}>
362363#blocked1 = #ttg.blocked <{sizePerThread = [1 , 8 ], threadsPerWarp = [64 , 1 ], warpsPerCTA = [4 , 1 ], order = [1 , 0 ]}>
@@ -464,3 +465,81 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
464465 tt.return
465466 }
466467}
468+
469+ // -----
470+
471+ // Check that we can pipeline a simple matmul kernel
472+ // Note: Currently AsyncCopy is only used for the second operand because we do not support (actual) swizzled shared encodings
473+
474+ // COMMON-LABEL: simple_matmul_kernel
475+
476+ // Prologue
477+ // COMMON-COUNT-2: ttg.local_alloc
478+ // SYNC-COUNT-2: tt.load
479+ // SYNC-COUNT-2: ttg.local_store
480+ //
481+ // ASYNC: tt.load
482+ // ASYNC: ttg.local_store
483+ // ASYNC: ttg.async_copy_global_to_local
484+
485+ // Main loop
486+ // COMMON: scf.for
487+ //
488+ // SYNC-COUNT-2: ttg.local_load
489+ // SYNC: tt.dot
490+ // SYNC: scf.yield
491+ //
492+ // ASYNC: ttg.local_load
493+ // ASYNC: ttg.async_wait
494+ // ASYNC: ttg.local_load
495+ // ASYNC: tt.dot
496+ // ASYNC: ttg.async_copy_global_to_local
497+
498+ // Epilogue
499+ // COMMON-COUNT-2: ttg.local_load
500+ // COMMON: scf.if
501+ // COMMON: tt.dot
502+ // COMMON-COUNT-2: scf.yield
503+ // COMMON-COUNT-2: ttg.local_dealloc
504+
505+ #blocked = #ttg.blocked <{sizePerThread = [1 , 4 ], threadsPerWarp = [8 , 8 ], warpsPerCTA = [8 , 1 ], order = [1 , 0 ]}>
506+ #blocked1 = #ttg.blocked <{sizePerThread = [1 , 4 ], threadsPerWarp = [4 , 16 ], warpsPerCTA = [8 , 1 ], order = [1 , 0 ]}>
507+ #mma = #ttg.amd_mfma <{versionMajor = 3 , versionMinor = 0 , warpsPerCTA = [4 , 2 ], instrShape = [32 , 32 ], isTransposed = true }>
508+ module attributes {" ttg.num-ctas" = 1 : i32 , " ttg.num-warps" = 8 : i32 , ttg.target = " hip:gfx942" , " ttg.threads-per-warp" = 64 : i32 } {
509+ tt.func public @simple_matmul_kernel (%test: tensor <1 x64 xi32 , #blocked1 >, %arg0: tensor <64 x64 x!tt.ptr <f16 >, #mma >, %arg1: i32 , %arg2: !tt.ptr <f16 > {tt.divisibility = 16 : i32 , tt.pointer_range = 32 : i32 }, %arg3: !tt.ptr <f16 > {tt.divisibility = 16 : i32 , tt.pointer_range = 32 : i32 }, %arg4: !tt.ptr <f16 > {tt.divisibility = 16 : i32 , tt.pointer_range = 32 : i32 }, %arg5: i32 {tt.divisibility = 16 : i32 }, %arg6: i32 {tt.divisibility = 16 : i32 }, %arg7: i32 {tt.divisibility = 16 : i32 }, %arg8: i32 {tt.divisibility = 16 : i32 }, %arg9: i32 {tt.divisibility = 16 : i32 }, %arg10: i32 {tt.divisibility = 16 : i32 }) attributes {noinline = false } {
510+ %cst = arith.constant dense <32 > : tensor <64 x32 xi32 , #blocked >
511+ %cst_0 = arith.constant dense <32 > : tensor <32 x64 xi32 , #blocked1 >
512+ %c64_i32 = arith.constant 64 : i32
513+ %c1_i32 = arith.constant 1 : i32
514+ %c0_i32 = arith.constant 0 : i32
515+ %cst_1 = arith.constant dense <0.000000e+00 > : tensor <64 x64 xf32 , #mma >
516+ %0 = tt.make_range {end = 64 : i32 , start = 0 : i32 } : tensor <64 xi32 , #ttg.slice <{dim = 0 , parent = #blocked1 }>>
517+ %1 = arith.muli %arg1 , %c64_i32 : i32
518+ %2 = tt.splat %1 : i32 -> tensor <64 xi32 , #ttg.slice <{dim = 0 , parent = #blocked1 }>>
519+ %3 = arith.addi %2 , %0 : tensor <64 xi32 , #ttg.slice <{dim = 0 , parent = #blocked1 }>>
520+ %4 = tt.splat %arg6 : i32 -> tensor <64 xi32 , #ttg.slice <{dim = 0 , parent = #blocked1 }>>
521+ %5 = arith.remsi %3 , %4 : tensor <64 xi32 , #ttg.slice <{dim = 0 , parent = #blocked1 }>>
522+ %6 = tt.make_range {end = 32 : i32 , start = 0 : i32 } : tensor <32 xi32 , #ttg.slice <{dim = 0 , parent = #blocked }>>
523+ %7 = tt.expand_dims %6 {axis = 0 : i32 } : tensor <32 xi32 , #ttg.slice <{dim = 0 , parent = #blocked }>> -> tensor <1 x32 xi32 , #blocked >
524+ %8 = tt.broadcast %7 : tensor <1 x32 xi32 , #blocked > -> tensor <64 x32 xi32 , #blocked >
525+ %9 = tt.splat %arg2 : !tt.ptr <f16 > -> tensor <64 x32 x!tt.ptr <f16 >, #blocked >
526+ %10 = tt.addptr %9 , %8 : tensor <64 x32 x!tt.ptr <f16 >, #blocked >, tensor <64 x32 xi32 , #blocked >
527+ %11 = tt.expand_dims %5 {axis = 0 : i32 } : tensor <64 xi32 , #ttg.slice <{dim = 0 , parent = #blocked1 }>> -> tensor <1 x64 xi32 , #blocked1 >
528+ %12 = tt.broadcast %11 : tensor <1 x64 xi32 , #blocked1 > -> tensor <32 x64 xi32 , #blocked1 >
529+ %13 = tt.splat %arg3 : !tt.ptr <f16 > -> tensor <32 x64 x!tt.ptr <f16 >, #blocked1 >
530+ %14 = tt.addptr %13 , %12 : tensor <32 x64 x!tt.ptr <f16 >, #blocked1 >, tensor <32 x64 xi32 , #blocked1 >
531+ %15:3 = scf.for %arg11 = %c0_i32 to %arg1 step %c1_i32 iter_args (%arg12 = %cst_1 , %arg13 = %10 , %arg14 = %14 ) -> (tensor <64 x64 xf32 , #mma >, tensor <64 x32 x!tt.ptr <f16 >, #blocked >, tensor <32 x64 x!tt.ptr <f16 >, #blocked1 >) : i32 {
532+ %17 = tt.load %arg13 : tensor <64 x32 x!tt.ptr <f16 >, #blocked >
533+ %18 = tt.load %arg14 : tensor <32 x64 x!tt.ptr <f16 >, #blocked1 >
534+ %19 = ttg.convert_layout %17 : tensor <64 x32 xf16 , #blocked > -> tensor <64 x32 xf16 , #ttg.dot_op <{opIdx = 0 , parent = #mma , kWidth = 4 }>>
535+ %20 = ttg.convert_layout %18 : tensor <32 x64 xf16 , #blocked1 > -> tensor <32 x64 xf16 , #ttg.dot_op <{opIdx = 1 , parent = #mma , kWidth = 4 }>>
536+ %21 = tt.dot %19 , %20 , %arg12 , inputPrecision = tf32 : tensor <64 x32 xf16 , #ttg.dot_op <{opIdx = 0 , parent = #mma , kWidth = 4 }>> * tensor <32 x64 xf16 , #ttg.dot_op <{opIdx = 1 , parent = #mma , kWidth = 4 }>> -> tensor <64 x64 xf32 , #mma >
537+ %22 = tt.addptr %arg13 , %cst : tensor <64 x32 x!tt.ptr <f16 >, #blocked >, tensor <64 x32 xi32 , #blocked >
538+ %23 = tt.addptr %arg14 , %cst_0 : tensor <32 x64 x!tt.ptr <f16 >, #blocked1 >, tensor <32 x64 xi32 , #blocked1 >
539+ scf.yield %21 , %22 , %23 : tensor <64 x64 xf32 , #mma >, tensor <64 x32 x!tt.ptr <f16 >, #blocked >, tensor <32 x64 x!tt.ptr <f16 >, #blocked1 >
540+ }
541+ %16 = arith.truncf %15#0 : tensor <64 x64 xf32 , #mma > to tensor <64 x64 xf16 , #mma >
542+ tt.store %arg0 , %16 : tensor <64 x64 x!tt.ptr <f16 >, #mma >
543+ tt.return
544+ }
545+ }
0 commit comments