@@ -499,3 +499,43 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-war
     tt.return
   }
 }
+
+
+// -----
+
+#mfma = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [32, 32], isTransposed = true}>
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, triton_gpu.target = "hip:gfx90a", "triton_gpu.threads-per-warp" = 64 : i32} {
+  // CHECK-LABEL: dont_hoist_scf_ops
+  // Make sure we don't hoist scf ops above their dependencies.
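+  // The pointer consumed by the first tt.load below is produced by an scf.if
+  // inside the loop body, so the scf.if and the load must stay in order.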
+  tt.func public @dont_hoist_scf_ops(%init: tensor<256x128xf32, #mfma>,
+      %base: tensor<256x128x!tt.ptr<f16>, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>>,
+      %p1: tensor<128x128x!tt.ptr<f16>, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>>, %i1: i1) -> (tensor<256x128xf32, #mfma>) {
+    %c0_i32 = arith.constant 0 : i32
+    %c1_i32 = arith.constant 1 : i32
+    %c4_i32 = arith.constant 4 : i32
+    %cst = arith.constant 1.44269502 : f32
+    %c128_i32 = arith.constant 128 : i32
+    // CHECK: scf.for
+    %54 = scf.for %arg21 = %c0_i32 to %c4_i32 step %c1_i32 iter_args(%arg = %init) -> (tensor<256x128xf32, #mfma>) : i32 {
+      // CHECK: arith.addi
+      %f = arith.addi %arg21, %c128_i32 : i32
+      // CHECK: scf.if
+      // CHECK: tt.load
+      %p0 = scf.if %i1 -> tensor<256x128x!tt.ptr<f16>, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>> {
+        %t = tt.splat %f : i32 -> tensor<256x128xi32, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>>
+        %padd = tt.addptr %base, %t : tensor<256x128x!tt.ptr<f16>, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>>, tensor<256x128xi32, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>>
+        scf.yield %padd : tensor<256x128x!tt.ptr<f16>, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>>
+      } else {
+        scf.yield %base : tensor<256x128x!tt.ptr<f16>, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>>
+      }
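+      // This load consumes the scf.if result, so it cannot be hoisted above it.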
+      %l = tt.load %p0 : tensor<256x128x!tt.ptr<f16>, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>>
+      %r = tt.load %p1 : tensor<128x128x!tt.ptr<f16>, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>>
+      %acc = tt.dot %l, %r, %arg : tensor<256x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>> * tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>> -> tensor<256x128xf32, #mfma>
+      scf.yield %acc : tensor<256x128xf32, #mfma>
+    }
+    tt.return %54 : tensor<256x128xf32, #mfma>
+  }
+}