Skip to content

Commit 84f0906

Browse files
authored
[BACKEND] Support memory subview for NVMMASharedEncodingAttr (#6241)
1 parent 3887b80 commit 84f0906

File tree

2 files changed

+83
-7
lines changed

2 files changed

+83
-7
lines changed

lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp

Lines changed: 45 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -392,8 +392,10 @@ struct MemDescSubviewOpConversion
392392
Location loc = op->getLoc();
393393
auto b = TritonLLVMOpBuilder(loc, rewriter);
394394
auto srcTy = op.getSrc().getType();
395+
auto destTy = op.getResult().getType();
395396
auto llvmElemTy = getTypeConverter()->convertType(srcTy.getElementType());
396397
auto layoutOrder = getOrder(srcTy);
398+
auto enc = srcTy.getEncoding();
397399

398400
// newBase = base + offset
399401
auto smemObj = getSharedMemoryObjectFromStruct(loc, adaptor.getSrc(),
@@ -408,13 +410,49 @@ struct MemDescSubviewOpConversion
408410
for (int i = rankReduced; i < opOffsetVals.size(); i++) {
409411
offsetVals.push_back(b.add(opOffsetVals[i], smemObj.getOffsets()[i]));
410412
}
411-
// Compute the offset based on the original strides of the shared memory
412-
// object
413-
auto offset = dot(rewriter, loc, opOffsetVals, opSmemStrides);
414-
auto elemPtrTy = smemObj.getBase().getType();
415-
smemObj = SharedMemoryObject(
416-
b.gep(elemPtrTy, llvmElemTy, smemObj.getBase(), offset), llvmElemTy,
417-
offsetVals);
413+
Value offset = b.undef(i32_ty);
414+
auto allocShape = srcTy.getAllocShape();
415+
bool isSimpleSubview =
416+
allocShape.take_back(destRank) == destTy.getShape() ||
417+
!isa<NVMMASharedEncodingAttr>(enc);
418+
if (!isSimpleSubview) {
419+
auto nvmmaEnc = cast<NVMMASharedEncodingAttr>(enc);
420+
assert(destRank >= 2 &&
421+
"Shape size should be >= 2 when using NVMMAShared encoding");
422+
auto swizzleStride = b.i32_val((nvmmaEnc.getSwizzlingByteWidth() * 8) /
423+
llvmElemTy.getIntOrFloatBitWidth());
424+
offset = b.i32_val(0);
425+
for (auto i = 0; i < opOffsetVals.size() - 2; ++i) {
426+
offset = b.add(offset, b.mul(opOffsetVals[i], opSmemStrides[i]));
427+
}
428+
// newOffset = offset - (stridedOff * swizzleStride + contigOff /
429+
// swizzleStride * tileSize + contigOff % swizzleStride)
430+
// + stridedInc * swizzleStride + contigInc / swizzleStride *
431+
// tileSize + contigInc % swizzleStride
432+
auto stridedDim = destRank - 1 - layoutOrder[0];
433+
auto contigDim = destRank - 1 - layoutOrder[1];
434+
auto stridedOff = smemObj.getOffsets()[stridedDim];
435+
auto contigOff = smemObj.getOffsets()[contigDim];
436+
auto stridedInc = offsetVals[stridedDim];
437+
auto contigInc = offsetVals[contigDim];
438+
int allocStridedDim = allocShape.size() - 1 - layoutOrder[0];
439+
auto tileSize =
440+
b.mul(b.i32_val(allocShape[allocStridedDim]), swizzleStride);
441+
offset = b.sub(offset, b.mul(stridedOff, swizzleStride));
442+
offset = b.sub(offset, b.mul(b.udiv(contigOff, swizzleStride), tileSize));
443+
offset = b.sub(offset, b.urem(contigOff, swizzleStride));
444+
offset = b.add(offset, b.mul(stridedInc, swizzleStride));
445+
offset = b.add(offset, b.mul(b.udiv(contigInc, swizzleStride), tileSize));
446+
offset = b.add(offset, b.urem(contigInc, swizzleStride));
447+
} else {
448+
// Compute the offset based on the original strides of the shared memory
449+
// object
450+
offset = dot(rewriter, loc, opOffsetVals, opSmemStrides);
451+
}
452+
auto base = smemObj.getBase();
453+
auto elemPtrTy = base.getType();
454+
smemObj = SharedMemoryObject(b.gep(elemPtrTy, llvmElemTy, base, offset),
455+
llvmElemTy, offsetVals);
418456
auto retVal = getStructFromSharedMemoryObject(loc, smemObj, rewriter);
419457
rewriter.replaceOp(op, retVal);
420458
return success();

test/Conversion/tritongpu_to_llvm.mlir

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -532,6 +532,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
532532
// CHECK-NEXT: llvm.mlir.constant(512 : i32) : i32
533533
// CHECK-NEXT: llvm.add
534534
// CHECK-NEXT: llvm.add
535+
// CHECK-NEXT: llvm.mlir.undef
535536
// CHECK-NEXT: llvm.mlir.constant(0 : i32) : i32
536537
// CHECK-NEXT: llvm.mul
537538
// CHECK-NEXT: llvm.add
@@ -550,6 +551,43 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
550551

551552
// -----
552553

554+
#shared0 = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 32}>
555+
#smem = #ttg.shared_memory
556+
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
557+
// CHECK: llvm.mlir.global external @global_smem
558+
// CHECK-LABEL: nvmma_subview
559+
tt.func @nvmma_subview() {
560+
// CHECK: llvm.mlir.addressof @global_smem
561+
// CHECK: llvm.mlir.undef : i32
562+
// CHECK-NEXT: llvm.mlir.constant(32 : i32) : i32
563+
// CHECK-NEXT: llvm.mlir.constant(0 : i32) : i32
564+
// CHECK-NEXT: llvm.mlir.constant(16 : i32) : i32
565+
// CHECK-NEXT: llvm.mul
566+
// CHECK-NEXT: llvm.mul
567+
// CHECK-NEXT: llvm.sub
568+
// CHECK-NEXT: llvm.udiv
569+
// CHECK-NEXT: llvm.mul
570+
// CHECK-NEXT: llvm.sub
571+
// CHECK-NEXT: llvm.urem
572+
// CHECK-NEXT: llvm.sub
573+
// CHECK-NEXT: llvm.mul
574+
// CHECK-NEXT: llvm.add
575+
// CHECK-NEXT: llvm.udiv
576+
// CHECK-NEXT: llvm.mul
577+
// CHECK-NEXT: llvm.add
578+
// CHECK-NEXT: llvm.urem
579+
// CHECK-NEXT: llvm.add
580+
// CHECK-NEXT: llvm.getelementptr
581+
%index = arith.constant 1 : i32
582+
%zero = arith.constant 0 : i32
583+
%0 = ttg.local_alloc : () -> !ttg.memdesc<16x128xf32, #shared0, #smem, mutable>
584+
%1 = ttg.memdesc_subview %0[%zero, %zero] : !ttg.memdesc<16x128xf32, #shared0, #smem, mutable> -> !ttg.memdesc<16x32xf32, #shared0, #smem, mutable>
585+
tt.return
586+
}
587+
}
588+
589+
// -----
590+
553591
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
554592
// CHECK-LABEL: basic_async_wait
555593
tt.func @basic_async_wait() {

0 commit comments

Comments
 (0)