44
// Layout and memory-encoding aliases shared by the test functions below.
#acc_layout = #ttg.blocked<{sizePerThread = [1, 128], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
#oper_layout = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [1, 0]}>
#oper_layout_trans = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [2, 2], order = [0, 1]}>
// CHECK-DAG: [[SHARED:#.*]] = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}>
#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}>
#shared_trans = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = true, elementBitWidth = 16}>
#nvmma_smem = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 8}>
#smem = #ttg.shared_memory
// CHECK-DAG: [[ACC_TMEM:#.*]] = #ttng.tensor_memory_encoding<blockM = 128, blockN = 128, unpacked = true>
#acc_tmem = #ttng.tensor_memory_encoding<blockM = 128, blockN = 128, unpacked = true>
@@ -791,8 +793,8 @@ tt.func @matmul_scaled_rhs_scales_tma(
791793 %k_tiles: i32 ,
792794 %off_m: i32 ,
793795 %off_n: i32 ,
794- %a_desc: !tt.tensordesc <tensor <128 x64 xf8 E4 M3 FN, #ttg.nvmma_shared <{ swizzlingByteWidth = 128 , transposed = false , elementBitWidth = 8 }> >>,
795- %b_desc: !tt.tensordesc <tensor <128 x64 xf8 E4 M3 FN, #ttg.nvmma_shared <{ swizzlingByteWidth = 128 , transposed = false , elementBitWidth = 8 }> >>,
796+ %a_desc: !tt.tensordesc <tensor <128 x64 xf8 E4 M3 FN, #nvmma_smem >>,
797+ %b_desc: !tt.tensordesc <tensor <128 x64 xf8 E4 M3 FN, #nvmma_smem >>,
796798 %b_scale_desc: !tt.tensordesc <tensor <128 x8 xi8 , #ttg.swizzled_shared <{vec = 1 , perPhase = 1 , maxPhase = 1 , order = [4 , 3 , 2 , 1 , 0 ]}>>>
797799) {
798800 %true = arith.constant true
@@ -814,14 +816,14 @@ tt.func @matmul_scaled_rhs_scales_tma(
814816
815817 // CHECK: ttng.wait_barrier
816818 // CHECK-COUNT-3: async_tma_copy_global_to_local {{.*}} {ttg.partition = 2 : i32}
817- %a_reg = tt.descriptor_load %a_desc [%off_m , %off_k ] : !tt.tensordesc <tensor <128 x64 xf8 E4 M3 FN, #ttg.nvmma_shared <{ swizzlingByteWidth = 128 , transposed = false , elementBitWidth = 8 }> >> -> tensor <128 x64 xf8 E4 M3 FN, #oper_layout >
818- %b_reg = tt.descriptor_load %b_desc [%off_n , %off_k ] : !tt.tensordesc <tensor <128 x64 xf8 E4 M3 FN, #ttg.nvmma_shared <{ swizzlingByteWidth = 128 , transposed = false , elementBitWidth = 8 }> >> -> tensor <128 x64 xf8 E4 M3 FN, #oper_layout >
819+ %a_reg = tt.descriptor_load %a_desc [%off_m , %off_k ] : !tt.tensordesc <tensor <128 x64 xf8 E4 M3 FN, #nvmma_smem >> -> tensor <128 x64 xf8 E4 M3 FN, #oper_layout >
820+ %b_reg = tt.descriptor_load %b_desc [%off_n , %off_k ] : !tt.tensordesc <tensor <128 x64 xf8 E4 M3 FN, #nvmma_smem >> -> tensor <128 x64 xf8 E4 M3 FN, #oper_layout >
819821 %b_scales_reg = tt.descriptor_load %b_scale_desc [%off_m , %c0_i32 ] : !tt.tensordesc <tensor <128 x8 xi8 , #ttg.swizzled_shared <{vec = 1 , perPhase = 1 , maxPhase = 1 , order = [4 , 3 , 2 , 1 , 0 ]}>>> -> tensor <128 x8 xi8 , #oper_layout >
820822
821- %a_sh = ttg.local_alloc %a_reg : (tensor <128 x64 xf8 E4 M3 FN, #oper_layout >) -> !ttg.memdesc <128 x64 xf8 E4 M3 FN, #ttg.nvmma_shared <{ swizzlingByteWidth = 128 , transposed = false , elementBitWidth = 8 }> , #smem >
822- %b_sh_raw = ttg.local_alloc %b_reg : (tensor <128 x64 xf8 E4 M3 FN, #oper_layout >) -> !ttg.memdesc <128 x64 xf8 E4 M3 FN, #ttg.nvmma_shared <{ swizzlingByteWidth = 128 , transposed = false , elementBitWidth = 8 }> , #smem >
823+ %a_sh = ttg.local_alloc %a_reg : (tensor <128 x64 xf8 E4 M3 FN, #oper_layout >) -> !ttg.memdesc <128 x64 xf8 E4 M3 FN, #nvmma_smem , #smem >
824+ %b_sh_raw = ttg.local_alloc %b_reg : (tensor <128 x64 xf8 E4 M3 FN, #oper_layout >) -> !ttg.memdesc <128 x64 xf8 E4 M3 FN, #nvmma_smem , #smem >
823825 // CHECK-NEXT: memdesc_trans {{.*}} ttg.partition = 1 : i32
824- %b_sh = ttg.memdesc_trans %b_sh_raw {order = array<i32 : 1 , 0 >} : !ttg.memdesc <128 x64 xf8 E4 M3 FN, #ttg.nvmma_shared <{ swizzlingByteWidth = 128 , transposed = false , elementBitWidth = 8 }> , #smem > -> !ttg.memdesc <64 x128 xf8 E4 M3 FN, #ttg.nvmma_shared <{swizzlingByteWidth = 128 , transposed = true , elementBitWidth = 8 }>, #smem >
826+ %b_sh = ttg.memdesc_trans %b_sh_raw {order = array<i32 : 1 , 0 >} : !ttg.memdesc <128 x64 xf8 E4 M3 FN, #nvmma_smem , #smem > -> !ttg.memdesc <64 x128 xf8 E4 M3 FN, #ttg.nvmma_shared <{swizzlingByteWidth = 128 , transposed = true , elementBitWidth = 8 }>, #smem >
825827
826828 // CHECK-NEXT: wait_barrier {{.*}} {ttg.partition = 1 : i32}
827829
@@ -831,7 +833,7 @@ tt.func @matmul_scaled_rhs_scales_tma(
831833
832834 // CHECK-NEXT: [[IS_LAST:%.*]] = arith.cmpi eq, %arg6, [[LAST_ITER]]
833835 // CHECK-NEXT: tc_gen5_mma_scaled {{.*}} {ttg.partition = 1 : i32}
834- %mma_tok = ttng.tc_gen5_mma_scaled %a_sh , %b_sh , %c_tmem [%c_tok ], %a_scales_tmem , %b_scales_tmem , %true , %true lhs = e4m3 rhs = e4m3 : !ttg.memdesc <128 x64 xf8 E4 M3 FN, #ttg.nvmma_shared <{ swizzlingByteWidth = 128 , transposed = false , elementBitWidth = 8 }> , #smem >, !ttg.memdesc <64 x128 xf8 E4 M3 FN, #ttg.nvmma_shared <{swizzlingByteWidth = 128 , transposed = true , elementBitWidth = 8 }>, #smem >, !ttg.memdesc <128 x128 xf32 , #acc_tmem , #ttng.tensor_memory , mutable >, !ttg.memdesc <128 x8 xi8 , #ttng.tensor_memory_scales_encoding <>, #ttng.tensor_memory >, !ttg.memdesc <128 x8 xi8 , #ttng.tensor_memory_scales_encoding <>, #ttng.tensor_memory >
836+ %mma_tok = ttng.tc_gen5_mma_scaled %a_sh , %b_sh , %c_tmem [%c_tok ], %a_scales_tmem , %b_scales_tmem , %true , %true lhs = e4m3 rhs = e4m3 : !ttg.memdesc <128 x64 xf8 E4 M3 FN, #nvmma_smem , #smem >, !ttg.memdesc <64 x128 xf8 E4 M3 FN, #ttg.nvmma_shared <{swizzlingByteWidth = 128 , transposed = true , elementBitWidth = 8 }>, #smem >, !ttg.memdesc <128 x128 xf32 , #acc_tmem , #ttng.tensor_memory , mutable >, !ttg.memdesc <128 x8 xi8 , #ttng.tensor_memory_scales_encoding <>, #ttng.tensor_memory >, !ttg.memdesc <128 x8 xi8 , #ttng.tensor_memory_scales_encoding <>, #ttng.tensor_memory >
835837
836838 %c , %load_tok = ttng.tmem_load %c_tmem [%mma_tok ] : !ttg.memdesc <128 x128 xf32 , #acc_tmem , #ttng.tensor_memory , mutable > -> tensor <128 x128 xf32 , #acc_layout >
837839 scf.yield %c : tensor <128 x128 xf32 , #acc_layout >
@@ -1125,6 +1127,56 @@ tt.func @specialize_mma_only(%rhs_desc: !tt.tensordesc<tensor<64x128xf16, #share
11251127 tt.return
11261128}
11271129
// Exercises warp-specialization partition assignment for a loop in which a
// TMA-loaded scale tensor is loaded to registers, transposed, and stored to
// TMEM before feeding a scaled MMA; the CHECK lines pin which partition each
// op lands in and the barrier handshakes between the load (partition 2),
// register/TMEM staging and user code (partition 0), and MMA (partition 1).
// CHECK-LABEL: @load_scale_mma_user
tt.func @load_scale_mma_user(
  %lhs: !ttg.memdesc<128x64xf16, #shared, #smem>,
  %rhs: !ttg.memdesc<64x128xf16, #shared, #smem>,
  %scales_desc: !tt.tensordesc<tensor<8x128xi8, #shared>>,
  %b_scales: !ttg.memdesc<128x8xi8, #ttng.tensor_memory_scales_encoding<>, #ttng.tensor_memory>,
  %ub: i32
) {
  %c0_i32 = arith.constant 0 : i32
  %c1_i32 = arith.constant 1 : i32
  %true = arith.constant true
  %zero = arith.constant dense<0.0> : tensor<128x128xf32, #acc_layout>

  // CHECK: scf.for
  %out = scf.for %i = %c0_i32 to %ub step %c1_i32 iter_args(%acc = %zero) -> tensor<128x128xf32, #acc_layout> : i32 {
    // CHECK: wait_barrier [[EMPTY_BAR:%.*]], %{{.*}}partition = 2
    // CHECK: barrier_expect [[SCALES_BAR:%.*]], 1024 {{.*}}partition = 2
    // CHECK: async_tma_copy_global_to_local {{.*}}partition = 2
    %scales_result = tt.descriptor_load %scales_desc[%i, %i] : !tt.tensordesc<tensor<8x128xi8, #shared>> -> tensor<8x128xi8, #oper_layout>
    %scales_shared = ttg.local_alloc %scales_result : (tensor<8x128xi8, #oper_layout>) -> !ttg.memdesc<8x128xi8, #shared, #smem>
    // CHECK: wait_barrier [[SCALES_BAR]]{{.*}}partition = 0
    // CHECK-NEXT: [[SCALES_REG:%.*]] = ttg.local_load {{.*}}partition = 0
    // CHECK-NEXT: arrive_barrier [[EMPTY_BAR]]{{.*}}partition = 0
    %scales_reg = ttg.local_load %scales_shared : !ttg.memdesc<8x128xi8, #shared, #smem> -> tensor<8x128xi8, #oper_layout>
    // CHECK-NEXT: [[SCALES_TRANS:%.*]] = tt.trans [[SCALES_REG]] {{.*}}partition = 0
    %scales_T = tt.trans %scales_reg {order = array<i32: 1, 0>} : tensor<8x128xi8, #oper_layout> -> tensor<128x8xi8, #oper_layout_trans>
    // CHECK-NEXT: wait_barrier [[SCALES_TMEM_BAR:%.*]], %arg{{[0-9]+}} {{.*}}partition = 0
    // CHECK-NEXT: tmem_store [[SCALES_TRANS]], [[SCALES_TMEM:%.*]], %true {{.*}}partition = 0
    %scales_tmem = ttng.tmem_alloc %scales_T : (tensor<128x8xi8, #oper_layout_trans>) -> !ttg.memdesc<128x8xi8, #ttng.tensor_memory_scales_encoding<>, #ttng.tensor_memory>
    // CHECK-NEXT: arrive_barrier [[SCALES_READY_BAR:%.*]], 1 {{.*}}partition = 0

    // CHECK: wait_barrier [[SCALES_READY_BAR]]{{.*}}partition = 1
    %acc_tmem, %acc_tok = ttng.tmem_alloc %acc : (tensor<128x128xf32, #acc_layout>) -> (!ttg.memdesc<128x128xf32, #acc_tmem, #ttng.tensor_memory, mutable>, !ttg.async.token)
    // CHECK-NEXT: tc_gen5_mma_scaled {{.*}} [[SCALES_TMEM]]{{.*}} [[USER_BAR:%.*]][%true], [[SCALES_TMEM_BAR]][%true] {{.*}}partition = 1
    %mma_tok = ttng.tc_gen5_mma_scaled %lhs, %rhs, %acc_tmem[%acc_tok], %scales_tmem, %b_scales, %true, %true lhs = e4m3 rhs = e4m3 : !ttg.memdesc<128x64xf16, #shared, #smem>, !ttg.memdesc<64x128xf16, #shared, #smem>, !ttg.memdesc<128x128xf32, #acc_tmem, #ttng.tensor_memory, mutable>, !ttg.memdesc<128x8xi8, #ttng.tensor_memory_scales_encoding<>, #ttng.tensor_memory>, !ttg.memdesc<128x8xi8, #ttng.tensor_memory_scales_encoding<>, #ttng.tensor_memory>

    // CHECK: wait_barrier [[USER_BAR]]{{.*}}partition = 0
    // CHECK-NEXT: tmem_load
    %c, %load_tok = ttng.tmem_load %acc_tmem[%mma_tok] : !ttg.memdesc<128x128xf32, #acc_tmem, #ttng.tensor_memory, mutable> -> tensor<128x128xf32, #acc_layout>
    // CHECK: arrive_barrier [[USER_DONE:%.*]], 1 {{.*}}partition = 0
    // CHECK: wait_barrier [[USER_DONE]]{{.*}}partition = 1

    "user"(%c) : (tensor<128x128xf32, #acc_layout>) -> ()

    scf.yield %c : tensor<128x128xf32, #acc_layout>
  } {tt.warp_specialize, tt.num_stages = 3 : i32}
  "use"(%out) : (tensor<128x128xf32, #acc_layout>) -> ()
  tt.return
}
1179+
11281180}
11291181
11301182// -----
0 commit comments