@@ -260,3 +260,48 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
     tt.return
   }
 }
+
+
+// -----
+
+// CHECK-DAG: #[[$SHARED:.*]] = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}>
+// CHECK-DAG: #[[$SHARED1:.*]] = #ttg.nvmma_shared<{swizzlingByteWidth = 64, transposed = false, elementBitWidth = 8}>
+// CHECK-LABEL: @_fbgemm_grouped_gemm_fp8_rowwise_ws
+// CHECK: ttg.local_alloc : () -> !ttg.memdesc<1x64x64xf8E4M3FN, #[[$SHARED1]], #smem, mutable>
+// CHECK: ttg.local_alloc : () -> !ttg.memdesc<1x128x64xf8E4M3FN, #[[$SHARED1]], #smem, mutable>
+// CHECK: ttg.local_alloc : () -> !ttg.memdesc<1x128xf32, #[[$SHARED]], #smem, mutable>
+
+#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [1, 0]}>
+#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
+#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 128, 32]}>
+#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 64, transposed = false, elementBitWidth = 8}>
+#shared1 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}>
+#shared2 = #ttg.nvmma_shared<{swizzlingByteWidth = 64, transposed = true, elementBitWidth = 8}>
+#smem = #ttg.shared_memory
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} {
+  tt.func public @_fbgemm_grouped_gemm_fp8_rowwise_ws(%arg0: !tt.ptr<i8, 0> {tt.nv_tma_desc = 1 : i32}, %arg1: i32, %arg2: !tt.ptr<i8, 0> {tt.nv_tma_desc = 1 : i32}, %arg3: !tt.ptr<i8, 0> {tt.nv_tma_desc = 1 : i32}) attributes {noinline = false} {
+    %c0_i32 = arith.constant {async_task_id = array<i32: 0, 1, 2>} 0 : i32
+    %c2048_i32 = arith.constant {async_task_id = array<i32: 0, 1, 2>} 2048 : i32
+    %c64_i32 = arith.constant {async_task_id = array<i32: 0, 1, 2>} 64 : i32
+    %cst = arith.constant {async_task_id = array<i32: 0, 1, 2>} dense<0.000000e+00> : tensor<64x128xf32, #mma>
+    %0 = tt.get_program_id x {async_task_id = array<i32: 0, 1, 2>} : i32
+    %1 = ttng.reinterpret_tensor_descriptor %arg0 {async_task_id = array<i32: 0>} : !tt.ptr<i8, 0> to !tt.tensordesc<tensor<64x64xf8E4M3FN, #shared>>
+    %2 = ttng.reinterpret_tensor_descriptor %arg2 {async_task_id = array<i32: 0>} : !tt.ptr<i8, 0> to !tt.tensordesc<tensor<128x64xf8E4M3FN, #shared>>
+    %3 = ttng.reinterpret_tensor_descriptor %arg3 {async_task_id = array<i32: 0>} : !tt.ptr<i8, 0> to !tt.tensordesc<tensor<128xf32, #shared1>>
+    scf.for %arg4 = %0 to %arg1 step %c64_i32 : i32 {
+      %4 = arith.muli %arg4, %c2048_i32 {async_task_id = array<i32: 0>} : i32
+      %5 = scf.for %arg5 = %c0_i32 to %c2048_i32 step %c64_i32 iter_args(%arg6 = %cst) -> (tensor<64x128xf32, #mma>) : i32 {
+        %8 = tt.descriptor_load %1[%4, %arg5] {async_task_id = array<i32: 0>} : !tt.tensordesc<tensor<64x64xf8E4M3FN, #shared>> -> tensor<64x64xf8E4M3FN, #blocked>
+        %9 = ttg.local_alloc %8 {async_task_id = array<i32: 1>} : (tensor<64x64xf8E4M3FN, #blocked>) -> !ttg.memdesc<64x64xf8E4M3FN, #shared, #smem>
+        %10 = tt.descriptor_load %2[%4, %arg5] {async_task_id = array<i32: 0>} : !tt.tensordesc<tensor<128x64xf8E4M3FN, #shared>> -> tensor<128x64xf8E4M3FN, #blocked>
+        %11 = ttg.local_alloc %10 {async_task_id = array<i32: 1, 2>} : (tensor<128x64xf8E4M3FN, #blocked>) -> !ttg.memdesc<128x64xf8E4M3FN, #shared, #smem>
+        %12 = ttg.memdesc_trans %11 {async_task_id = array<i32: 1, 2>, order = array<i32: 1, 0>} : !ttg.memdesc<128x64xf8E4M3FN, #shared, #smem> -> !ttg.memdesc<64x128xf8E4M3FN, #shared2, #smem>
+        %13 = ttng.warp_group_dot %9, %12, %arg6 {async_task_id = array<i32: 1>, inputPrecision = 0 : i32, maxNumImpreciseAcc = 1073741824 : i32} : !ttg.memdesc<64x64xf8E4M3FN, #shared, #smem> * !ttg.memdesc<64x128xf8E4M3FN, #shared2, #smem> -> tensor<64x128xf32, #mma>
+        scf.yield {async_task_id = array<i32: 1, 2>} %13 : tensor<64x128xf32, #mma>
+      } {async_task_id = array<i32: 0, 1, 2>}
+      %6 = tt.descriptor_load %3[%4] {async_task_id = array<i32: 0>} : !tt.tensordesc<tensor<128xf32, #shared1>> -> tensor<128xf32, #blocked1>
+      %7 = ttg.convert_layout %6 {async_task_id = array<i32: 1, 2>} : tensor<128xf32, #blocked1> -> tensor<128xf32, #ttg.slice<{dim = 0, parent = #blocked}>>
+    } {async_task_id = array<i32: 1, 2>}
+    tt.return
+  }
+}