Skip to content

Commit eb66546

Browse files
authored
[TTGIR] Disallow memdesc_index 1D -> 1D (#7673)
We make memdesc_index always be rank-reducing.
1 parent c2d0d5e commit eb66546

File tree

18 files changed

+140
-139
lines changed

18 files changed

+140
-139
lines changed

include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -214,8 +214,6 @@ def TTG_MemDescIndexOp : TTG_Op<"memdesc_index", [Pure, MemDescViewTrait]> {
214214
- the output shape is 4x16xf16, and
215215
- index = 1.
216216
Then the output descriptor is equivalent to input[1], where input is the logical tensor.
217-
218-
When the input is of rank 1 (i.e, shape=[k]), the output will have shape=[1].
219217
}];
220218

221219
let arguments = (ins TTG_MemDescType:$src, I32:$index);

lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -477,17 +477,21 @@ struct MemDescIndexOpConversion
477477
auto *ctx = op->getContext();
478478
auto b = TritonLLVMOpBuilder(loc, rewriter);
479479
auto srcTy = op.getSrc().getType();
480-
auto destTy = op.getResult().getType();
480+
auto dstTy = op.getResult().getType();
481481
auto llvmElemTy = getTypeConverter()->convertType(srcTy.getElementType());
482482

483+
// getAllocationShapePerCTA returns the correct number fp4 elements that we
484+
// need to skip when we have fp4Padded=True. getShapePerCTA does not account
485+
// for this
486+
auto stride = product(
487+
getAllocationShapePerCTA(dstTy.getEncoding(), dstTy.getShape()));
488+
Value offset = b.mul(op.getIndex(), b.i32_val(stride));
483489
auto smemObj = getSharedMemoryObjectFromStruct(loc, adaptor.getSrc(),
484490
llvmElemTy, rewriter);
485491
auto base = smemObj.getBase();
486492
auto elemPtrTy = base.getType();
487-
Value stride = smemObj.getStrides(srcTy, loc, rewriter).front();
488-
Value offset = b.mul(op.getIndex(), stride);
489493
auto prevOffsets = smemObj.getOffsets();
490-
SmallVector<Value> offsetVals(prevOffsets.end() - destTy.getRank(),
494+
SmallVector<Value> offsetVals(prevOffsets.end() - dstTy.getRank(),
491495
prevOffsets.end());
492496
// Advance the pointer and keep the opOffsets as the new shape
493497
smemObj = SharedMemoryObject(b.gep(elemPtrTy, llvmElemTy, base, offset),

lib/Dialect/TritonGPU/IR/Ops.cpp

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -690,19 +690,17 @@ LogicalResult MemDescIndexOp::verify() {
690690
if (srcTy.getElementType() != dstTy.getElementType()) {
691691
return emitError("result element type must match desc element type");
692692
}
693-
bool is1D =
694-
srcTy.getRank() == 1 && dstTy.getRank() == 1 && dstTy.getDimSize(0) == 1;
695-
bool correctRank = srcTy.getRank() == dstTy.getRank() + 1 || is1D;
693+
// memdesc_index reduces rank by 1 and preserves the trailing shape.
694+
bool correctRank = srcTy.getRank() == dstTy.getRank() + 1;
696695
if (!correctRank) {
697-
return emitError(
698-
"result rank must be less than or equal to input rank or 1D -> 1D");
696+
return emitError("result rank must be input rank - 1");
699697
}
700698
if (srcTy.getAllocShape().size() != srcTy.getRank()) {
701699
return emitError("We don't allow taking memdesc_index of a memdesc_index");
702700
}
703701

704-
if (!is1D && ArrayRef(srcTy.getShape()).take_back(dstTy.getRank()) !=
705-
dstTy.getShape()) {
702+
if (ArrayRef(srcTy.getShape()).take_back(dstTy.getRank()) !=
703+
dstTy.getShape()) {
706704
return emitError("result shape must equal to srcShape[1:]");
707705
}
708706

lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -459,7 +459,7 @@ Value mlir::triton::createScalarAlloc(ImplicitLocOpBuilder &rewriter, Type type,
459459
auto barrierEncoding =
460460
ttg::SwizzledSharedEncodingAttr::get(ctx, 1, 1, 1, {0}, barrierCTALayout);
461461
ttg::MemDescType memDescType = ttg::MemDescType::get(
462-
{numBuffers}, type, barrierEncoding, sharedMemorySpace,
462+
{numBuffers, 1}, type, barrierEncoding, sharedMemorySpace,
463463
/*mutableMemory=*/true);
464464
return rewriter.create<ttg::LocalAllocOp>(memDescType, Value());
465465
}
@@ -653,12 +653,10 @@ triton::createSingleBufferView(OpBuilder &builder, Value alloc, Value idx) {
653653
assert(isa<ttg::MemDescType>(alloc.getType()) && "Expected MemDescType");
654654
auto allocDescType = cast<ttg::MemDescType>(alloc.getType());
655655
SmallVector<int64_t> shape;
656-
if (allocDescType.getShape().size() > 1) {
657-
shape.insert(shape.end(), allocDescType.getShape().begin() + 1,
658-
allocDescType.getShape().end());
659-
} else {
660-
shape.push_back(1);
661-
}
656+
assert(allocDescType.getShape().size() > 1 &&
657+
"Expected multi-dimensional memdesc (e.g., Nx...) for subview");
658+
shape.insert(shape.end(), allocDescType.getShape().begin() + 1,
659+
allocDescType.getShape().end());
662660
auto viewDescType = ttg::MemDescType::get(
663661
shape, allocDescType.getElementType(), allocDescType.getEncoding(),
664662
allocDescType.getMemorySpace(), allocDescType.getMutableMemory(),

lib/Dialect/TritonGPU/Transforms/WarpSpecialization/LoadMMASpecialization.cpp

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -785,25 +785,32 @@ static LogicalResult pipelineMMA(scf::ForOp &loop, PipelinedMMA &mma,
785785
Value emptyBar = createBarrierAlloc(loop, /*numBarriers=*/1);
786786
Value readyBar = createBarrierAlloc(loop, /*numBarriers=*/1);
787787
PartitionBuilder b(defs.front()->getLoc(), loop);
788-
b.create<ttng::ArriveBarrierOp>(emptyBar, /*arriveCount=*/1);
788+
// For Nx1 barrier allocations, pass a 1D view into barrier ops.
789+
Value emptyView0 = createSingleBufferView(b, emptyBar, b.intCst(0));
790+
b.create<ttng::ArriveBarrierOp>(emptyView0, /*arriveCount=*/1);
789791

790792
Operation *domOp = findNearestCommonDominator(defs, domInfo);
791793
Operation *lastOp = findNearestCommonPostDominator(defs, postDomInfo);
792794

793795
auto [index, phase] = addIndexAndPhase(b, loop, /*numStages=*/1);
794796
StageCluster srcStageCluster = getStageCluster(domOp);
795797
b.setInsertionPoint(domOp);
796-
b.createInto<ttng::WaitBarrierOp>(*partition, srcStageCluster, emptyBar,
798+
Value emptyView = createSingleBufferView(b, emptyBar, index);
799+
b.createInto<ttng::WaitBarrierOp>(*partition, srcStageCluster, emptyView,
797800
phase);
798801

799802
b.setInsertionPointAfter(lastOp);
800-
b.createInto<ttng::ArriveBarrierOp>(*partition, srcStageCluster, readyBar,
803+
Value readyView = createSingleBufferView(b, readyBar, index);
804+
b.createInto<ttng::ArriveBarrierOp>(*partition, srcStageCluster, readyView,
801805
1);
802806

803807
b.setInsertionPoint(mmaOp);
808+
Value readyView2 = createSingleBufferView(b, readyBar, index);
804809
b.createInto<ttng::WaitBarrierOp>(*schedule.getPartition(mmaOp),
805-
getStageCluster(mmaOp), readyBar, phase);
806-
mmaOp.addCompletionBarrier(emptyBar, b.boolCst(true));
810+
getStageCluster(mmaOp), readyView2,
811+
phase);
812+
Value emptyView2 = createSingleBufferView(b, emptyBar, index);
813+
mmaOp.addCompletionBarrier(emptyView2, b.boolCst(true));
807814
mmaOp.setIsAsync(true);
808815
}
809816

test/Analysis/test-allocation.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -832,9 +832,9 @@ tt.func @aliasing_in_partition() {
832832
}
833833
partition0() num_warps(4) {
834834
// expected-remark @below {{offset = 0, size = 16}}
835-
%0 = ttg.local_alloc : () -> !ttg.memdesc<2xi64, #A_SHARED, #smem, mutable>
835+
%0 = ttg.local_alloc : () -> !ttg.memdesc<2x1xi64, #A_SHARED, #smem, mutable>
836836
%c0_i32 = arith.constant 0 : i32
837-
%1 = ttg.memdesc_index %0, %c0_i32 : !ttg.memdesc<2xi64, #A_SHARED, #smem, mutable> -> !ttg.memdesc<1xi64, #A_SHARED, #smem, mutable>
837+
%1 = ttg.memdesc_index %0, %c0_i32 : !ttg.memdesc<2x1xi64, #A_SHARED, #smem, mutable> -> !ttg.memdesc<1xi64, #A_SHARED, #smem, mutable>
838838
// expected-remark @below {{offset = 16, size = 16}}
839839
%2 = ttg.local_alloc : () -> !ttg.memdesc<2xi64, #A_SHARED, #smem, mutable>
840840
"use"(%1) : (!ttg.memdesc<1xi64, #A_SHARED, #smem, mutable>) -> ()

test/Conversion/tritongpu_to_llvm.mlir

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -552,14 +552,12 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
552552
// CHECK-LABEL: rank_reducing_subview
553553
tt.func @rank_reducing_subview() {
554554
// CHECK: llvm.mlir.addressof @global_smem
555-
// CHECK: llvm.extractvalue
555+
// CHECK: llvm.mlir.constant(512 : i32) : i32
556+
// CHECK-NEXT: llvm.mul
557+
// CHECK-NEXT: llvm.extractvalue
556558
// CHECK-NEXT: llvm.extractvalue
557559
// CHECK-NEXT: llvm.extractvalue
558560
// CHECK-NEXT: llvm.extractvalue
559-
// CHECK-NEXT: llvm.mlir.constant(1 : i32) : i32
560-
// CHECK-NEXT: llvm.mlir.constant(32 : i32) : i32
561-
// CHECK-NEXT: llvm.mlir.constant(512 : i32) : i32
562-
// CHECK-NEXT: llvm.mul
563561
// CHECK-NEXT: llvm.getelementptr
564562
%index = arith.constant 1 : i32
565563
%zero = arith.constant 0 : i32
@@ -2111,8 +2109,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
21112109
// CHECK: llvm.store
21122110
tt.func public @test_local_store_subview(%arg0: tensor<1xf32, #blocked>) {
21132111
%c0_i32 = arith.constant 0 : i32
2114-
%0 = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<1xf32, #shared, #smem, mutable>
2115-
%sv = ttg.memdesc_index %0, %c0_i32 : !ttg.memdesc<1xf32, #shared, #smem, mutable> -> !ttg.memdesc<1xf32, #shared, #smem, mutable>
2112+
%0 = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<1x1xf32, #shared, #smem, mutable>
2113+
%sv = ttg.memdesc_index %0, %c0_i32 : !ttg.memdesc<1x1xf32, #shared, #smem, mutable> -> !ttg.memdesc<1xf32, #shared, #smem, mutable>
21162114
ttg.local_store %arg0, %sv : tensor<1xf32, #blocked> -> !ttg.memdesc<1xf32, #shared, #smem, mutable>
21172115
tt.return
21182116
}

test/TritonGPU/invalid.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ tt.func public @result_rank_too_large(%arg0: !ttg.memdesc<3x8x16xf32, #shared, #
7979
#smem = #ttg.shared_memory
8080
tt.func public @result_1d_to_1d(%arg0: !ttg.memdesc<8xf32, #shared, #smem>) {
8181
%zero = arith.constant 0 : i32
82-
// expected-error @+1 {{1D -> 1D}}
82+
// expected-error @+1 {{result rank}}
8383
%a = ttg.memdesc_index %arg0, %zero : !ttg.memdesc<8xf32, #shared, #smem> -> !ttg.memdesc<2xf32, #shared, #smem>
8484
tt.return
8585
}

0 commit comments

Comments
 (0)