Skip to content

Commit 2892466

Browse files
authored
[Blackwell] Fix tmem_subslice lowering for packed sub-32B layouts (#7207)
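For packed (unpacked = false) layouts, the actual column offset of the subslice is smaller than N: several sub-32-bit elements are packed into each 32-bit column, so the offset is N divided by the number of elements per 32 bits (e.g. N = 64 with f16 gives a column offset of 32). If N is not a multiple of that element count, the lowering fails.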
1 parent 8b7074c commit 2892466

File tree

2 files changed

+35
-0
lines changed

2 files changed

+35
-0
lines changed

test/Conversion/tritongpu_to_llvm_blackwell.mlir

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -606,3 +606,31 @@ tt.func private @reinterpret(%arg0: !ttg.memdesc<128xf32, #tmem, #ttng.tensor_me
606606
}
607607

608608
}
609+
610+
// -----
611+
612+
#tmem = #ttng.tensor_memory_encoding<blockM = 128, blockN = 128, unpacked = false>
613+
#tmem_unpacked = #ttng.tensor_memory_encoding<blockM = 128, blockN = 128, unpacked = true>
614+
615+
module attributes {"ttg.num-warps" = 4 : i32} {
616+
617+
// CHECK-LABEL: @subslice_unpacked
618+
tt.func private @subslice_unpacked(%arg0: !ttg.memdesc<128x128xf16, #tmem_unpacked, #ttng.tensor_memory>) -> !ttg.memdesc<128x64xf16, #tmem_unpacked, #ttng.tensor_memory> {
619+
// CHECK: [[OFFSET:%.*]] = llvm.mlir.constant(64 : i32)
620+
// CHECK: [[PTR:%.*]] = llvm.ptrtoint
621+
// CHECK: llvm.add [[PTR]], [[OFFSET]]
622+
%0 = ttng.tmem_subslice %arg0 {N = 64 : i32} : !ttg.memdesc<128x128xf16, #tmem_unpacked, #ttng.tensor_memory> -> !ttg.memdesc<128x64xf16, #tmem_unpacked, #ttng.tensor_memory>
623+
tt.return %0 : !ttg.memdesc<128x64xf16, #tmem_unpacked, #ttng.tensor_memory>
624+
}
625+
626+
627+
// CHECK-LABEL: @subslice_packed
628+
tt.func private @subslice_packed(%arg0: !ttg.memdesc<128x128xf16, #tmem, #ttng.tensor_memory>) -> !ttg.memdesc<128x64xf16, #tmem, #ttng.tensor_memory> {
629+
// CHECK: [[OFFSET:%.*]] = llvm.mlir.constant(32 : i32)
630+
// CHECK: [[PTR:%.*]] = llvm.ptrtoint
631+
// CHECK: llvm.add [[PTR]], [[OFFSET]]
632+
%0 = ttng.tmem_subslice %arg0 {N = 64 : i32} : !ttg.memdesc<128x128xf16, #tmem, #ttng.tensor_memory> -> !ttg.memdesc<128x64xf16, #tmem, #ttng.tensor_memory>
633+
tt.return %0 : !ttg.memdesc<128x64xf16, #tmem, #ttng.tensor_memory>
634+
}
635+
636+
}

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TensorMemoryToLLVM.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -930,6 +930,13 @@ struct TMEMSubSliceOpConversion
930930
return failure();
931931
}
932932
offsetCol = op.getN();
933+
if (!encoding.getUnpacked()) {
934+
int numElementsPer32B = 32 / srcTy.getElementTypeBitWidth();
935+
if (offsetCol % numElementsPer32B != 0) {
936+
return failure();
937+
}
938+
offsetCol /= numElementsPer32B;
939+
}
933940
Value tmemBase = adaptor.getSrc();
934941
Value offsetVal = b.i32_val(offsetCol | offsetRow << 16);
935942
Value newBase = b.add(b.ptrtoint(i32_ty, tmemBase), offsetVal);

0 commit comments

Comments
 (0)