@@ -229,6 +229,44 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
 
 // -----
 
+#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 64, transposed = false, elementBitWidth = 8}>
+#shared1 = #ttg.nvmma_shared<{swizzlingByteWidth = 64, transposed = false, elementBitWidth = 8}>
+#shared2 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}>
+#tmem = #ttng.tensor_memory_encoding<blockM = 128, blockN = 128, unpacked = true>
+#tmem_scales = #ttng.tensor_memory_scales_encoding<>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
+  // CHECK-LABEL: @tc_gen5_mma_block_scale_fp4_a
+  // CHECK: %[[DESC0:.+]] = llvm.mlir.constant(144769664 : i32) : i32
+  // CHECK: @$7 tcgen05.mma.cta_group::1.kind::mxf8f6f4.block_scale.scale_vec::1X [ $0 + 0 ], $1, $2, $3, [ $4 + 0 ], [ $5 + 0 ], $6;", "r,l,l,r,r,r,b,b" %{{.+}}, %{{.+}}, %{{.+}}, %[[DESC0]]
+  // CHECK: %[[DESC1:.+]] = llvm.mlir.constant(681640592 : i32) : i32
+  // CHECK: @$7 tcgen05.mma.cta_group::1.kind::mxf8f6f4.block_scale.scale_vec::1X [ $0 + 0 ], $1, $2, $3, [ $4 + 0 ], [ $5 + 0 ], $6;", "r,l,l,r,r,r,b,b" %{{.+}}, %{{.+}}, %{{.+}}, %[[DESC1]]
+  // CHECK: %[[DESC2:.+]] = llvm.mlir.constant(1218511520 : i32) : i32
+  // CHECK: @$7 tcgen05.mma.cta_group::1.kind::mxf8f6f4.block_scale.scale_vec::1X [ $0 + 0 ], $1, $2, $3, [ $4 + 0 ], [ $5 + 0 ], $6;", "r,l,l,r,r,r,b,b" %{{.+}}, %{{.+}}, %{{.+}}, %[[DESC2]]
+  // CHECK: %[[DESC3:.+]] = llvm.mlir.constant(1755382448 : i32) : i32
+  // CHECK: @$7 tcgen05.mma.cta_group::1.kind::mxf8f6f4.block_scale.scale_vec::1X [ $0 + 0 ], $1, $2, $3, [ $4 + 0 ], [ $5 + 0 ], $6;", "r,l,l,r,r,r,b,b" %{{.+}}, %{{.+}}, %{{.+}}, %[[DESC3]]
+  tt.func @tc_gen5_mma_block_scale_fp4_a(%a: !ttg.memdesc<128x64xi8, #shared1, #ttg.shared_memory>,
+                                         %b: !ttg.memdesc<128x128xi8, #shared, #ttg.shared_memory>,
+                                         %c: !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>,
+                                         %scale_a: !ttg.memdesc<128x2xi8, #tmem_scales, #ttng.tensor_memory>,
+                                         %scale_b: !ttg.memdesc<128x2xi8, #tmem_scales, #ttng.tensor_memory>,
+                                         %useAcc: i1,
+                                         %pred: i1,
+                                         %barrier: !ttg.memdesc<1xi64, #shared2, #ttg.shared_memory, mutable>) {
+    ttng.tc_gen5_mma_scaled %a, %b, %c, %scale_a, %scale_b, %useAcc, %pred lhs = e2m1 rhs = e4m3, %barrier :
+      (!ttg.memdesc<128x64xi8, #shared1, #ttg.shared_memory>,
+       !ttg.memdesc<128x128xi8, #shared, #ttg.shared_memory>,
+       !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>,
+       !ttg.memdesc<128x2xi8, #tmem_scales, #ttng.tensor_memory>,
+       !ttg.memdesc<128x2xi8, #tmem_scales, #ttng.tensor_memory>,
+       i1,
+       i1,
+       !ttg.memdesc<1xi64, #shared2, #ttg.shared_memory, mutable>) -> ()
+    tt.return
+  }
+}
+
+// -----
+
 #shared = #ttg.nvmma_shared<{swizzlingByteWidth = 64, transposed = false, elementBitWidth = 16, CTAsPerCGA = [2, 1], CTASplitNum = [2, 1], CTAOrder = [1, 0]}>
 #shared1 = #ttg.nvmma_shared<{swizzlingByteWidth = 64, transposed = false, elementBitWidth = 16, CTAsPerCGA = [1, 2], CTASplitNum = [1, 2], CTAOrder = [1, 0]}>
 #shared2 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CTAsPerCGA = [2], CTASplitNum = [1], CTAOrder = [0]}>