@@ -260,3 +260,48 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
     tt.return
   }
 }
+
+
+// -----
+
+// CHECK-DAG: #[[$SHARED:.*]] = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}>
+// CHECK-DAG: #[[$SHARED1:.*]] = #ttg.nvmma_shared<{swizzlingByteWidth = 64, transposed = false, elementBitWidth = 8}>
+// CHECK-LABEL: @_fbgemm_grouped_gemm_fp8_rowwise_ws
+// CHECK: ttg.local_alloc : () -> !ttg.memdesc<1x64x64xf8E4M3FN, #[[$SHARED1]], #smem, mutable>
+// CHECK: ttg.local_alloc : () -> !ttg.memdesc<1x128x64xf8E4M3FN, #[[$SHARED1]], #smem, mutable>
+// CHECK: ttg.local_alloc : () -> !ttg.memdesc<1x128xf32, #[[$SHARED]], #smem, mutable>
+
+#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [1, 0]}>
+#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
+#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 128, 32]}>
+#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 64, transposed = false, elementBitWidth = 8}>
+#shared1 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}>
+#shared2 = #ttg.nvmma_shared<{swizzlingByteWidth = 64, transposed = true, elementBitWidth = 8}>
+#smem = #ttg.shared_memory
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} {
+  tt.func public @_fbgemm_grouped_gemm_fp8_rowwise_ws(%arg0: !tt.ptr<i8, 0> {tt.nv_tma_desc = 1 : i32}, %arg1: i32, %arg2: !tt.ptr<i8, 0> {tt.nv_tma_desc = 1 : i32}, %arg3: !tt.ptr<i8, 0> {tt.nv_tma_desc = 1 : i32}) attributes {noinline = false} {
+    %c0_i32 = arith.constant {async_task_id = array<i32: 0, 1, 2>} 0 : i32
+    %c2048_i32 = arith.constant {async_task_id = array<i32: 0, 1, 2>} 2048 : i32
+    %c64_i32 = arith.constant {async_task_id = array<i32: 0, 1, 2>} 64 : i32
+    %cst = arith.constant {async_task_id = array<i32: 0, 1, 2>} dense<0.000000e+00> : tensor<64x128xf32, #mma>
+    %0 = tt.get_program_id x {async_task_id = array<i32: 0, 1, 2>} : i32
+    %1 = ttng.reinterpret_tensor_descriptor %arg0 {async_task_id = array<i32: 0>} : !tt.ptr<i8, 0> to !tt.tensordesc<tensor<64x64xf8E4M3FN, #shared>>
+    %2 = ttng.reinterpret_tensor_descriptor %arg2 {async_task_id = array<i32: 0>} : !tt.ptr<i8, 0> to !tt.tensordesc<tensor<128x64xf8E4M3FN, #shared>>
+    %3 = ttng.reinterpret_tensor_descriptor %arg3 {async_task_id = array<i32: 0>} : !tt.ptr<i8, 0> to !tt.tensordesc<tensor<128xf32, #shared1>>
+    scf.for %arg4 = %0 to %arg1 step %c64_i32 : i32 {
+      %4 = arith.muli %arg4, %c2048_i32 {async_task_id = array<i32: 0>} : i32
+      %5 = scf.for %arg5 = %c0_i32 to %c2048_i32 step %c64_i32 iter_args(%arg6 = %cst) -> (tensor<64x128xf32, #mma>) : i32 {
+        %8 = tt.descriptor_load %1[%4, %arg5] {async_task_id = array<i32: 0>} : !tt.tensordesc<tensor<64x64xf8E4M3FN, #shared>> -> tensor<64x64xf8E4M3FN, #blocked>
+        %9 = ttg.local_alloc %8 {async_task_id = array<i32: 1>} : (tensor<64x64xf8E4M3FN, #blocked>) -> !ttg.memdesc<64x64xf8E4M3FN, #shared, #smem>
+        %10 = tt.descriptor_load %2[%4, %arg5] {async_task_id = array<i32: 0>} : !tt.tensordesc<tensor<128x64xf8E4M3FN, #shared>> -> tensor<128x64xf8E4M3FN, #blocked>
+        %11 = ttg.local_alloc %10 {async_task_id = array<i32: 1, 2>} : (tensor<128x64xf8E4M3FN, #blocked>) -> !ttg.memdesc<128x64xf8E4M3FN, #shared, #smem>
+        %12 = ttg.memdesc_trans %11 {async_task_id = array<i32: 1, 2>, order = array<i32: 1, 0>} : !ttg.memdesc<128x64xf8E4M3FN, #shared, #smem> -> !ttg.memdesc<64x128xf8E4M3FN, #shared2, #smem>
+        %13 = ttng.warp_group_dot %9, %12, %arg6 {async_task_id = array<i32: 1>, inputPrecision = 0 : i32, maxNumImpreciseAcc = 1073741824 : i32} : !ttg.memdesc<64x64xf8E4M3FN, #shared, #smem> * !ttg.memdesc<64x128xf8E4M3FN, #shared2, #smem> -> tensor<64x128xf32, #mma>
+        scf.yield {async_task_id = array<i32: 1, 2>} %13 : tensor<64x128xf32, #mma>
+      } {async_task_id = array<i32: 0, 1, 2>}
+      %6 = tt.descriptor_load %3[%4] {async_task_id = array<i32: 0>} : !tt.tensordesc<tensor<128xf32, #shared1>> -> tensor<128xf32, #blocked1>
+      %7 = ttg.convert_layout %6 {async_task_id = array<i32: 1, 2>} : tensor<128xf32, #blocked1> -> tensor<128xf32, #ttg.slice<{dim = 0, parent = #blocked}>>
+    } {async_task_id = array<i32: 1, 2>}
+    tt.return
+  }
+}