@@ -606,3 +606,31 @@ tt.func private @reinterpret(%arg0: !ttg.memdesc<128xf32, #tmem, #ttng.tensor_me
606606}
607607
608608}
609+
// -----

#tmem = #ttng.tensor_memory_encoding<blockM = 128, blockN = 128, unpacked = false>
#tmem_unpacked = #ttng.tensor_memory_encoding<blockM = 128, blockN = 128, unpacked = true>

module attributes {"ttg.num-warps" = 4 : i32} {

// Slicing an unpacked f16 TMEM allocation: each element occupies a full
// column, so taking the N=64 subslice advances the base by 64.
// CHECK-LABEL: @subslice_unpacked
tt.func private @subslice_unpacked(%arg0: !ttg.memdesc<128x128xf16, #tmem_unpacked, #ttng.tensor_memory>) -> !ttg.memdesc<128x64xf16, #tmem_unpacked, #ttng.tensor_memory> {
  // CHECK: [[OFFSET:%.*]] = llvm.mlir.constant(64 : i32)
  // CHECK: [[PTR:%.*]] = llvm.ptrtoint
  // CHECK: llvm.add [[PTR]], [[OFFSET]]
  %0 = ttng.tmem_subslice %arg0 {N = 64 : i32} : !ttg.memdesc<128x128xf16, #tmem_unpacked, #ttng.tensor_memory> -> !ttg.memdesc<128x64xf16, #tmem_unpacked, #ttng.tensor_memory>
  tt.return %0 : !ttg.memdesc<128x64xf16, #tmem_unpacked, #ttng.tensor_memory>
}

// Same slice on the packed encoding: the lowered offset is 32 rather than 64
// (same N=64, encodings differ only in `unpacked`), i.e. the element offset
// is halved for packed f16.
// CHECK-LABEL: @subslice_packed
tt.func private @subslice_packed(%arg0: !ttg.memdesc<128x128xf16, #tmem, #ttng.tensor_memory>) -> !ttg.memdesc<128x64xf16, #tmem, #ttng.tensor_memory> {
  // CHECK: [[OFFSET:%.*]] = llvm.mlir.constant(32 : i32)
  // CHECK: [[PTR:%.*]] = llvm.ptrtoint
  // CHECK: llvm.add [[PTR]], [[OFFSET]]
  %0 = ttng.tmem_subslice %arg0 {N = 64 : i32} : !ttg.memdesc<128x128xf16, #tmem, #ttng.tensor_memory> -> !ttg.memdesc<128x64xf16, #tmem, #ttng.tensor_memory>
  tt.return %0 : !ttg.memdesc<128x64xf16, #tmem, #ttng.tensor_memory>
}

}
0 commit comments