@@ -200,3 +200,29 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
     tt.return
   }
 }
+
+// -----
+
+#linear = #ttg.linear<{register = [[0, 1], [0, 2], [64, 0]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [0, 0]], warp = [[16, 0], [32, 0]], block = []}>
+#linear1 = #ttg.linear<{register = [[0, 1], [0, 2], [16, 0], [32, 0], [64, 0]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [0, 0]], warp = [[0, 0], [0, 0]], block = []}>
+#mma = #ttg.amd_wmma<{version = 3, isTranspose = true, warpsPerCTA = [4, 1], instrShape = [16, 16, 128]}>
+
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx1250", "ttg.threads-per-warp" = 32 : i32} {
+  // CHECK-LABEL: wmma_scaled_dot_fp8_chained
+  tt.func @wmma_scaled_dot_fp8_chained(%arg0: tensor<128x128xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>>, %arg2: tensor<128x128xf8E5M2, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>>, %arg3: tensor<128x128xf8E5M2, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>>, %out0: !tt.ptr<f32> {tt.divisibility = 16 : i32}) {
+    %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma>
+    %scale0 = arith.constant dense<127> : tensor<128x4xi8, #linear>
+    %scale1 = arith.constant dense<127> : tensor<128x4xi8, #linear1>
+    // CHECK-COUNT-16: llvm.call_intrinsic "llvm.amdgcn.wmma.scale.f32.16x16x128.f8f6f4"{{.*}} : (i32, vector<16xi32>, i32, vector<16xi32>, i16, vector<8xf32>, i32, i32, i32, i32, i32, i32, i1, i1) -> vector<8xf32>
+    %mm0 = tt.dot_scaled %arg0 scale %scale0, %arg2 scale %scale1, %cst lhs = e4m3 rhs = e4m3 {fastMath = false} : tensor<128x128xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>>, tensor<128x4xi8, #linear> * tensor<128x128xf8E5M2, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>>, tensor<128x4xi8, #linear1> -> tensor<128x128xf32, #mma>
+    // CHECK-NOT: rocdl.ds_swizzle
+    // CHECK-NOT: llvm.call_intrinsic "llvm.amdgcn.permlane16.swap"
+    %op0 = ttg.convert_layout %mm0 : tensor<128x128xf32, #mma> -> tensor<128x128xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>>
+    %op1 = tt.fp_to_fp %op0, rounding = rtne : tensor<128x128xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>> -> tensor<128x128xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>>
+    // CHECK-COUNT-16: llvm.call_intrinsic "llvm.amdgcn.wmma.scale.f32.16x16x128.f8f6f4"{{.*}} : (i32, vector<16xi32>, i32, vector<16xi32>, i16, vector<8xf32>, i32, i32, i32, i32, i32, i32, i1, i1) -> vector<8xf32>
+    %mm1 = tt.dot_scaled %op1 scale %scale0, %arg3 scale %scale1, %cst lhs = e4m3 rhs = e4m3 {fastMath = false} : tensor<128x128xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>>, tensor<128x4xi8, #linear> * tensor<128x128xf8E5M2, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>>, tensor<128x4xi8, #linear1> -> tensor<128x128xf32, #mma>
+    %ptr0 = tt.splat %out0 : !tt.ptr<f32> -> tensor<128x128x!tt.ptr<f32>, #mma>
+    tt.store %ptr0, %mm1 : tensor<128x128x!tt.ptr<f32>, #mma>
+    tt.return
+  }
+}