@@ -301,7 +301,7 @@ module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-
// CHECK-LABEL: linear_to_swizzled_st_matrix_local_store
module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} {
  tt.func @linear_to_swizzled_st_matrix_local_store(%a: tensor<64x32xf16, #linear>) {
-    // CHECK-COUNT-2: nvgpu.stmatrix
+    // CHECK-COUNT-2: nvgpu.stmatrix %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}
    // CHECK: llvm.return
    %b = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<64x32xf16, #shared, #smem, mutable>
    ttg.local_store %a, %b : tensor<64x32xf16, #linear> -> !ttg.memdesc<64x32xf16, #shared, #smem, mutable>
@@ -323,7 +323,7 @@ module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-
// CHECK-LABEL: linear_to_swizzled_st_matrix_local_store
module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} {
  tt.func @linear_to_swizzled_st_matrix_local_store(%a: tensor<32x32xf16, #linear>) {
-    // CHECK-COUNT-2: nvgpu.stmatrix
+    // CHECK-COUNT-2: nvgpu.stmatrix %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}
    // CHECK: llvm.return
    %b = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<32x32xf16, #shared, #smem, mutable>
    ttg.local_store %a, %b : tensor<32x32xf16, #linear> -> !ttg.memdesc<32x32xf16, #shared, #smem, mutable>
@@ -333,6 +333,38 @@ module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-

// -----

+#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
+#linear = #ttg.linear<{register = [[0, 1], [0, 2], [8, 0]], lane = [[0, 4], [0, 8], [1, 0], [2, 0], [4, 0]], warp = [[16, 0], [32, 0]], block = []}>
+#smem = #ttg.shared_memory
+// CHECK-LABEL: linear_to_swizzled_st_matrix_x2_local_store_fp8
+module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} {
+  tt.func @linear_to_swizzled_st_matrix_x2_local_store_fp8(%a: tensor<64x16xf8E4M3FNUZ, #linear>) {
+    // CHECK-COUNT-1: nvgpu.stmatrix %{{.*}}, %{{.*}}, %{{.*}} :
+    // CHECK: llvm.return
+    %b = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<64x16xf8E4M3FNUZ, #shared, #smem, mutable>
+    ttg.local_store %a, %b : tensor<64x16xf8E4M3FNUZ, #linear> -> !ttg.memdesc<64x16xf8E4M3FNUZ, #shared, #smem, mutable>
+    tt.return
+  }
+}
+
+// -----
+
+#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
+#linear = #ttg.linear<{register = [[8, 0], [0, 4], [0, 8]], lane = [[0, 1], [0, 2], [1, 0], [2, 0], [4, 0]], warp = [[16, 0], [32, 0]], block = []}>
+#smem = #ttg.shared_memory
+// CHECK-LABEL: linear_to_swizzled_st_matrix_local_store_fp32
+module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} {
+  tt.func @linear_to_swizzled_st_matrix_local_store_fp32(%a: tensor<64x16xf32, #linear>) {
+    // CHECK-COUNT-2: nvgpu.stmatrix %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}
+    // CHECK: llvm.return
+    %b = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<64x16xf32, #shared, #smem, mutable>
+    ttg.local_store %a, %b : tensor<64x16xf32, #linear> -> !ttg.memdesc<64x16xf32, #shared, #smem, mutable>
+    tt.return
+  }
+}
+
+// -----
+
#blocked = #ttg.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} {
  tt.func @fp8_const(%arg0: tensor<1024xi1, #blocked>, %arg1: tensor<1024xf8E4M3FNUZ, #blocked>) attributes {noinline = false} {