Skip to content

Commit 45a7da6

Browse files
authored
[TritonGPU] Split MemDescSubview into MemDescIndex and MemDescSubslice (#7622)
The first one will be used just for pipelining and is equivalent to `x[i]`; the second one takes a full slice of constant shape — for example, `x[:i1, :i2]`.
1 parent 2bc0672 commit 45a7da6

File tree

65 files changed

+874
-912
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

65 files changed

+874
-912
lines changed

include/triton/Conversion/TritonGPUToLLVM/Utility.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -356,12 +356,12 @@ class SharedMemoryObject {
356356
RewriterBase &rewriter) const;
357357

358358
// Returns a mask representing all the bits of the memdesc offsets that
359-
// may be modified by an affine offset coming from a memdesc_subview.
359+
// may be modified by an affine offset coming from a memdesc_subslice.
360360
// The offsets are considered to be in the type of the memdesc.
361361
// For padded layouts, we return the offsets without padding.
362362
static uint64_t getMaskSpanOffsets(triton::gpu::MemDescType srcTy);
363363

364-
// Returns whether the shared memory access had a memdesc_subview
364+
// Returns whether the shared memory access had a memdesc_subslice
365365
// that is rank-preserving (soon to be called memdesc_slice)
366366
static bool isAffineSharedMemoryAccess(triton::gpu::MemDescType srcTy) {
367367
return getMaskSpanOffsets(srcTy) != 0;

include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td

Lines changed: 38 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -200,38 +200,57 @@ def TTG_LocalDeallocOp : TTG_Op<"local_dealloc"> {
200200
// Use qualified() otherwise "!ttg.memdesc<X>" is printed as "<X>".
201201
let assemblyFormat = [{$src attr-dict `:` qualified(type($src))}];
202202
}
203-
204-
def TTG_MemDescSubviewOp : TTG_Op<"memdesc_subview", [Pure, MemDescViewTrait]> {
203+
def TTG_MemDescIndexOp : TTG_Op<"memdesc_index", [Pure, MemDescViewTrait]> {
205204
let summary = "take a subview of the descriptor.";
206205

207206
let description = [{
208-
This operation returns a new descriptor representing a subview of the buffer.
207+
This operation returns a new descriptor pointing to the `i`-th element of the
208+
input descriptor along the 0-th dimension.
209+
209210
It doesn't affect the underlying memory.
210211

211212
For example, suppose that
212213
- the input shape is 2x4x16xf16,
213214
- the output shape is 4x16xf16, and
214-
- offsets = [1, 0, 0].
215+
- index = 1.
216+
Then the output descriptor is equivalent to input[1], where input is the logical tensor.
215217

216-
Then in Python syntax, the subview covers input[1].
218+
When the input is of rank 1 (i.e, shape=[k]), the output will have shape=[1].
219+
}];
217220

218-
Just one dimension may be split (at most one non-zero offset).
221+
let arguments = (ins TTG_MemDescType:$src, I32:$index);
219222

220-
When the input shape and the output shape have different rank:
221-
Or the output shape is a tensor of 1D tensor of 1 element:
222-
- The rank of the output must be 1D smaller than the input.
223-
- We assume the input is split along the 0th dimension.
224-
- The offset along the 0th dimension may be a runtime value.
225-
When the input and the output have the same rank:
226-
- The offset must be a compile-time constant
227-
- Larger or equal to the tile of the tensor (or zero)
228-
- That does not split the input along the swizzling pattern (if any)
229-
}];
230-
let arguments = (
231-
ins TTG_MemDescType:$src, Variadic<I32>:$offsets);
223+
let results = (outs TTG_MemDescType:$result);
224+
225+
let assemblyFormat = [{$src `,` $index attr-dict `:` qualified(type($src)) `->` qualified(type($result))}];
226+
227+
let hasVerifier = 1;
228+
}
232229

230+
def TTG_MemDescSubsliceOp : TTG_Op<"memdesc_subslice", [Pure, MemDescViewTrait]> {
231+
let summary = "take a subview of the descriptor.";
232+
233+
let description = [{
234+
This operation returns a new descriptor representing a subview of the logical tensor.
235+
It doesn't affect the underlying memory.
236+
237+
For example, suppose that
238+
- the input shape is 32x16xf16,
239+
- the output shape is 8x16xf16, and
240+
- offsets = [2, 1].
241+
Then in Python syntax, the subview covers input[2:8+2, 1:16+1] where input is
242+
the logical tensor.
243+
244+
The offsets must be larger or equal to the tile of the tensor (or zero).
245+
}];
246+
let arguments = (ins TTG_MemDescType:$src, DenseI32ArrayAttr:$offsets);
233247
// Use qualified() otherwise "!ttg.memdesc<X>" is printed as "<X>".
234-
let assemblyFormat = [{$src `[` $offsets `]` attr-dict `:` qualified(type($src)) `->` qualified(type($result))}];
248+
// Render offsets inline as %src[0, 0] via a custom directive, but keep
249+
// the overall parse/print generated from this assemblyFormat.
250+
let assemblyFormat = [{
251+
$src `[` custom<Offsets>($offsets) `]` attr-dict `:` qualified(type($src))
252+
`->` qualified(type($result))
253+
}];
235254

236255
let results = (outs TTG_MemDescType:$result);
237256

include/triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -142,11 +142,11 @@ gpu::SharedEncodingTrait getSharedEncoding(Operation *loadOp);
142142
// specified.
143143
int getNumStagesOrDefault(scf::ForOp forOp, int defaultNumStages);
144144

145-
// Given a result of MemDescSubview, or Alloca, create a MemDescSubview with a
145+
// Given a result of MemDescIndex, or Alloca, create a MemDescIndex with a
146146
// single buffer slice (leading dimension equal to 1), at the given index.
147147
TypedValue<triton::gpu::MemDescType>
148148
createSingleBufferView(OpBuilder &builder, Value alloc, Value idx);
149-
// Given a result of MemDescSubview, or Alloca, create a MemDescSubview with a
149+
// Given a result of MemDescIndex, or Alloca, create a MemDescIndex with a
150150
// single buffer slice (leading dimension equal to 1), at the given index.
151151
TypedValue<triton::gpu::MemDescType>
152152
createSingleBufferView(OpBuilder &builder, Value alloc, int idx);

include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -674,8 +674,8 @@ def TTNG_TMEMSubSliceOp : TTNG_Op<"tmem_subslice", [Pure]> {
674674
let description = [{
675675
This operation takes a subslice of a tensor memory allocation and returns a new descriptor
676676
containing the address and a view of the subslice.
677-
This is similar to ttg.memdesc_subview except the offset needs to be static and we can only
678-
slice alog the inner dimension of a 2D memdesc as this is the only one we can do for TMem.
677+
This is similar to ttg.memdesc_subslice except we can only slice along the inner dimension
678+
of a 2D memdesc as this is the only one we can do for TMem.
679679
}];
680680
let arguments = (ins TTG_MemDescType:$src, I32Attr:$N);
681681

lib/Conversion/TritonGPUToLLVM/Utility.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -705,7 +705,7 @@ bool emitTransferBetweenRegistersAndShared(
705705
{kBlock, blockId}},
706706
regIds);
707707

708-
// Compute affine offset given by memdesc_subview
708+
// Compute affine offset given by memdesc_subslice
709709
auto offset = smemObj.getShmemOffset(loc, rewriter, sharedTy);
710710
SmallVector<Value> vecAddrVec;
711711
for (auto &indices : indicesVec) {
@@ -1153,7 +1153,7 @@ Value SharedMemoryObject::getShmemOffset(Location loc, RewriterBase &rewriter,
11531153
auto ctx = srcTy.getContext();
11541154
auto b = TritonLLVMOpBuilder(loc, rewriter);
11551155

1156-
// If it did not have a memdesc_subview, we don't need to compute the offset
1156+
// If it did not have a memdesc_subslice we don't need to compute the offset
11571157
// as it is zero
11581158
if (!isAffineSharedMemoryAccess(srcTy)) {
11591159
return b.i32_val(0);

lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp

Lines changed: 46 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -465,13 +465,46 @@ struct BroadcastOpConversion
465465
}
466466
};
467467

468-
struct MemDescSubviewOpConversion
469-
: public ConvertOpToLLVMPattern<triton::gpu::MemDescSubviewOp> {
468+
struct MemDescIndexOpConversion
469+
: public ConvertOpToLLVMPattern<triton::gpu::MemDescIndexOp> {
470470
using ConvertOpToLLVMPattern<
471-
triton::gpu::MemDescSubviewOp>::ConvertOpToLLVMPattern;
471+
triton::gpu::MemDescIndexOp>::ConvertOpToLLVMPattern;
472472

473473
LogicalResult
474-
matchAndRewrite(triton::gpu::MemDescSubviewOp op, OpAdaptor adaptor,
474+
matchAndRewrite(triton::gpu::MemDescIndexOp op, OpAdaptor adaptor,
475+
ConversionPatternRewriter &rewriter) const override {
476+
Location loc = op->getLoc();
477+
auto *ctx = op->getContext();
478+
auto b = TritonLLVMOpBuilder(loc, rewriter);
479+
auto srcTy = op.getSrc().getType();
480+
auto destTy = op.getResult().getType();
481+
auto llvmElemTy = getTypeConverter()->convertType(srcTy.getElementType());
482+
483+
auto smemObj = getSharedMemoryObjectFromStruct(loc, adaptor.getSrc(),
484+
llvmElemTy, rewriter);
485+
auto base = smemObj.getBase();
486+
auto elemPtrTy = base.getType();
487+
Value stride = smemObj.getStrides(srcTy, loc, rewriter).front();
488+
Value offset = b.mul(op.getIndex(), stride);
489+
auto prevOffsets = smemObj.getOffsets();
490+
SmallVector<Value> offsetVals(prevOffsets.end() - destTy.getRank(),
491+
prevOffsets.end());
492+
// Advance the pointer and keep the opOffsets as the new shape
493+
smemObj = SharedMemoryObject(b.gep(elemPtrTy, llvmElemTy, base, offset),
494+
llvmElemTy, offsetVals);
495+
auto retVal = getStructFromSharedMemoryObject(loc, smemObj, rewriter);
496+
rewriter.replaceOp(op, retVal);
497+
return success();
498+
}
499+
};
500+
501+
struct MemDescSubsliceOpConversion
502+
: public ConvertOpToLLVMPattern<triton::gpu::MemDescSubsliceOp> {
503+
using ConvertOpToLLVMPattern<
504+
triton::gpu::MemDescSubsliceOp>::ConvertOpToLLVMPattern;
505+
506+
LogicalResult
507+
matchAndRewrite(triton::gpu::MemDescSubsliceOp op, OpAdaptor adaptor,
475508
ConversionPatternRewriter &rewriter) const override {
476509
Location loc = op->getLoc();
477510
auto *ctx = op->getContext();
@@ -484,40 +517,17 @@ struct MemDescSubviewOpConversion
484517

485518
auto smemObj = getSharedMemoryObjectFromStruct(loc, adaptor.getSrc(),
486519
llvmElemTy, rewriter);
487-
SmallVector<Value> opOffsetVals = op.getOffsets();
488-
// We assume we always create a subview of the last dimensions
489-
// Compute total offset
490-
auto rankReduced = srcTy.getRank() - destTy.getRank();
520+
auto opOffsetVals = op.getOffsets();
491521

492522
auto base = smemObj.getBase();
493523
auto elemPtrTy = base.getType();
494-
auto is1d = srcTy.getRank() == 1 && destTy.getRank() == 1 &&
495-
destTy.getDimSize(0) == 1;
496-
if (rankReduced || is1d) {
497-
auto smemStrides = smemObj.getStrides(srcTy, loc, rewriter);
498-
SmallVector<Value> opSmemStrides(smemStrides.end() - opOffsetVals.size(),
499-
smemStrides.end());
500-
// We are splitting the pipelining dimension which may not be a power of 2
501-
// so we can't use LinearLayouts
502-
auto offset = dot(rewriter, loc, opOffsetVals, opSmemStrides);
503-
// Remove the first offsets
504-
SmallVector<Value> offsetVals;
505-
for (int i = rankReduced; i < opOffsetVals.size(); i++) {
506-
offsetVals.push_back(b.add(opOffsetVals[i], smemObj.getOffsets()[i]));
507-
}
508-
// Advance the pointer and keep the opOffsets as the new shape
509-
smemObj = SharedMemoryObject(b.gep(elemPtrTy, llvmElemTy, base, offset),
510-
llvmElemTy, offsetVals);
511-
} else {
512-
// Accumulate the logical offsets
513-
SmallVector<Value> offsetVals;
514-
for (auto [oldOff, newOff] :
515-
llvm::zip(smemObj.getOffsets(), opOffsetVals)) {
516-
offsetVals.push_back(b.add(oldOff, newOff));
517-
}
518-
smemObj = SharedMemoryObject(base, llvmElemTy, offsetVals);
524+
// Accumulate the logical offsets
525+
SmallVector<Value> offsetVals;
526+
for (auto [oldOffVal, opOff] :
527+
llvm::zip(smemObj.getOffsets(), opOffsetVals)) {
528+
offsetVals.push_back(b.add(oldOffVal, b.i32_val(opOff)));
519529
}
520-
530+
smemObj = SharedMemoryObject(base, llvmElemTy, offsetVals);
521531
auto retVal = getStructFromSharedMemoryObject(loc, smemObj, rewriter);
522532
rewriter.replaceOp(op, retVal);
523533
return success();
@@ -563,6 +573,7 @@ void mlir::triton::populateViewOpToLLVMPatterns(
563573
typeConverter, benefit);
564574
patterns.add<TransOpConversion>(typeConverter, benefit);
565575
patterns.add<BroadcastOpConversion>(typeConverter, benefit);
566-
patterns.add<MemDescSubviewOpConversion>(typeConverter, benefit);
576+
patterns.add<MemDescSubsliceOpConversion, MemDescIndexOpConversion>(
577+
typeConverter, benefit);
567578
patterns.add<MemDescReinterpretOpConversion>(typeConverter, benefit);
568579
}

0 commit comments

Comments
 (0)