@@ -3679,3 +3679,44 @@ module attributes {"ttg.num-warps" = 1 : i32, ttg.target = "cuda:80"} {
     tt.return %1 : tensor<2x16x2xf32, #blocked>
   }
 }
+
+// -----
+
+#linear = #ttg.linear<{register = [[0, 2], [64, 0]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [16, 0], [0, 1]], warp = [[0, 0], [32, 0]], block = []}>
+#linear1 = #ttg.linear<{register = [[0, 2], [64, 0]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [16, 0], [0, 1]], warp = [[32, 0], [0, 0]], block = []}>
+#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 64], warpsPerCTA = [2, 2], order = [1, 0]}>
+#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 64], warpsPerCTA = [1, 4], order = [1, 0]}>
+#mma = #ttg.amd_mfma<{versionMajor = 4, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [32, 32], isTransposed = true}>
+#dot_op_a = #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>
+#dot_op_b = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>
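+// Verifies that the ttg.convert_layout inside the scf.for body is removed: the loop carries the
+// tt.dot_scaled result in the #mma layout, and only a single conversion to the store's blocked
+// layout remains after the loop.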
+// CHECK: [[$BLOCK:.+]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 64], warpsPerCTA = [2, 2], order = [1, 0]}>
+// CHECK-LABEL: mfma_dot_scaled_no_redundant_convert_layout
+module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} {
+  tt.func public @mfma_dot_scaled_no_redundant_convert_layout(
+      %arg0: tensor<128x128xf8E4M3FN, #dot_op_a>,
+      %arg1: tensor<128x128xf8E4M3FN, #dot_op_b>,
+      %arg2: tensor<128x4xi8, #linear>,
+      %arg3: tensor<128x4xi8, #linear1>,
+      %arg4: tensor<128x128x!tt.ptr<f32>, #blocked>
+  ) {
+    %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma>
+    %cst0 = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked1>
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %c32 = arith.constant 32 : index
+    // CHECK: %[[RET:.+]] = scf.for
+    // CHECK-NEXT: %[[DOT_RET:.+]] = tt.dot_scaled %arg0 scale %arg2, %arg1 scale %arg3, %cst lhs = e4m3 rhs = e4m3 {fastMath = false}
+    // CHECK-NEXT: scf.yield %[[DOT_RET]]
+    // CHECK-NEXT: }
+    // CHECK-NEXT: ttg.convert_layout %[[RET]] : tensor<128x128xf32, #mma> -> tensor<128x128xf32, [[$BLOCK]]>
+    // CHECK-NEXT: tt.store
+    %1 = scf.for %arg5 = %c0 to %c32 step %c1 iter_args(%arg6 = %cst0) -> (tensor<128x128xf32, #blocked1>) {
+      %4 = tt.dot_scaled %arg0 scale %arg2, %arg1 scale %arg3, %cst lhs = e4m3 rhs = e4m3 {fastMath = false} : tensor<128x128xf8E4M3FN, #dot_op_a>, tensor<128x4xi8, #linear> * tensor<128x128xf8E4M3FN, #dot_op_b>, tensor<128x4xi8, #linear1> -> tensor<128x128xf32, #mma>
+      %5 = ttg.convert_layout %4 : tensor<128x128xf32, #mma> -> tensor<128x128xf32, #blocked1>
+      scf.yield %5 : tensor<128x128xf32, #blocked1>
+    }
+    %7 = ttg.convert_layout %1 : tensor<128x128xf32, #blocked1> -> tensor<128x128xf32, #blocked>
+    tt.store %arg4, %7 : tensor<128x128x!tt.ptr<f32>, #blocked>
+    tt.return
+  }
+}