@@ -495,6 +495,7 @@ tt.func public @tmem_message_maxnreg_80(%desc: !ttg.memdesc<128x64xf32, #tmem, #
495
495
tt.return
496
496
}
497
497
498
+ // CHECK-LABEL: @module_constraint_supercedes_local
498
499
tt.func public @module_constraint_supercedes_local (%desc: !ttg.memdesc <128 x64 xf32 , #tmem , #ttng.tensor_memory >) {
499
500
ttg.warp_specialize (%desc ) attributes {actualRegisters = array<i32 : 256 , 256 >}
500
501
default {
@@ -611,6 +612,10 @@ tt.func private @reinterpret(%arg0: !ttg.memdesc<128xf32, #tmem, #ttng.tensor_me
611
612
612
613
#tmem = #ttng.tensor_memory_encoding <blockM = 128 , blockN = 128 , unpacked = false >
613
614
#tmem_unpacked = #ttng.tensor_memory_encoding <blockM = 128 , blockN = 128 , unpacked = true >
615
// Attribute aliases used by @load_store_x1 and @load_store_x1_unpacked.
// #tmem_x1: tensor-memory encoding with blockM = 128, blockN = 1, unpacked = false.
+ #tmem_x1 = #ttng.tensor_memory_encoding <blockM = 128 , blockN = 1 , unpacked = false >
616
// #tmem_x1_unpacked: same blockM, but blockN = 2 and unpacked = true.
// NOTE(review): the alias says "x1" while blockN is 2; presumably "x1" names
// the single 32-bit value per thread that the test expects -- confirm.
+ #tmem_x1_unpacked = #ttng.tensor_memory_encoding <blockM = 128 , blockN = 2 , unpacked = true >
617
+
618
// #blocked_x1: blocked register layout, one element per thread, 32 threads
// per warp and 4 warps along dim 0, dim 0 listed first in `order`.
+ #blocked_x1 = #ttg.blocked <{sizePerThread = [1 , 1 ], threadsPerWarp = [32 , 1 ], warpsPerCTA = [4 , 1 ], order = [0 , 1 ]}>
614
619
615
620
module attributes {" ttg.num-warps" = 4 : i32 } {
616
621
@@ -633,4 +638,29 @@ tt.func private @subslice_packed(%arg0: !ttg.memdesc<128x128xf16, #tmem, #ttng.t
633
638
tt.return %0 : !ttg.memdesc <128 x64 xf16 , #tmem , #ttng.tensor_memory >
634
639
}
635
640
641
// Round-trips a mutable 128x1xf32 tensor-memory descriptor (#tmem_x1):
// ttng.tmem_load reads it into a #blocked_x1 tensor and ttng.tmem_store writes
// that same tensor back to the same descriptor.
+ // CHECK-LABEL: @load_store_x1
642
+ tt.func @load_store_x1 (%arg0: !ttg.memdesc <128 x1 xf32 , #tmem_x1 , #ttng.tensor_memory , mutable >) {
643
+ %true = arith.constant true
644
// Expected lowering of the load: one inline-asm tcgen05.ld.sync that takes an
// i32 and yields a single i32, which is bitcast to f32 and inserted into a
// one-element f32 struct.
+ // CHECK: [[V:%.*]] = llvm.inline_asm {{.*}}tcgen05.ld.sync{{.*}} (i32) -> i32
645
+ // CHECK: [[F:%.*]] = llvm.bitcast [[V]] : i32 to f32
646
+ // CHECK: insertvalue [[F]], {{.*}} : !llvm.struct<(f32)>
647
+ %0 = ttng.tmem_load %arg0 : !ttg.memdesc <128 x1 xf32 , #tmem_x1 , #ttng.tensor_memory , mutable > -> tensor <128 x1 xf32 , #blocked_x1 >
648
// %true is passed as the third operand of the store (presumably its
// predicate -- confirm against the op definition).
// NOTE(review): no directive in this function matches the lowering of the
// store; only the load is pinned. Confirm that is intentional.
+ ttng.tmem_store %0 , %arg0 , %true : tensor <128 x1 xf32 , #blocked_x1 > -> !ttg.memdesc <128 x1 xf32 , #tmem_x1 , #ttng.tensor_memory , mutable >
649
+ tt.return
650
+ }
651
+
652
// Round-trips a mutable 128x2xf16 tensor-memory descriptor (#tmem_x1_unpacked):
// ttng.tmem_load reads it into a #blocked_x1 tensor and ttng.tmem_store writes
// that same tensor back to the same descriptor.
+ // CHECK-LABEL: @load_store_x1_unpacked
653
+ tt.func @load_store_x1_unpacked (%arg0: !ttg.memdesc <128 x2 xf16 , #tmem_x1_unpacked , #ttng.tensor_memory , mutable >) {
654
+ %true = arith.constant true
655
// Expected lowering of the load: i32 constants 0 and 1 are materialized, one
// inline-asm tcgen05.ld.sync takes an i32 and yields a single i32, that i32 is
// bitcast to vector<2xf16>, and both lanes are extracted using the two
// constants as indices (0 first, then 1).
+ // CHECK: [[C0:%.*]] = llvm.mlir.constant(0 : i32)
656
+ // CHECK: [[C1:%.*]] = llvm.mlir.constant(1 : i32)
657
+ // CHECK: [[V:%.*]] = llvm.inline_asm {{.*}}tcgen05.ld.sync{{.*}} (i32) -> i32
658
+ // CHECK: [[F:%.*]] = llvm.bitcast [[V]] : i32 to vector<2xf16>
659
+ // CHECK: extractelement [[F]][[[C0]] : i32]
660
+ // CHECK: extractelement [[F]][[[C1]] : i32]
661
+ %0 = ttng.tmem_load %arg0 : !ttg.memdesc <128 x2 xf16 , #tmem_x1_unpacked , #ttng.tensor_memory , mutable > -> tensor <128 x2 xf16 , #blocked_x1 >
662
// %true is passed as the third operand of the store (presumably its
// predicate -- confirm against the op definition).
// NOTE(review): no directive in this function matches the lowering of the
// store; only the load is pinned. Confirm that is intentional.
+ ttng.tmem_store %0 , %arg0 , %true : tensor <128 x2 xf16 , #blocked_x1 > -> !ttg.memdesc <128 x2 xf16 , #tmem_x1_unpacked , #ttng.tensor_memory , mutable >
663
+ tt.return
664
+ }
665
+
636
666
}
0 commit comments