
Commit 6b06242

[BACKEND] Define the semantics of memdesc_subview (#6886)
We strictly define the semantics of `memdesc_subview`: when the subview is rank-reducing, it allows arbitrary (even runtime) skips along the 0-th dimension; otherwise, offsets must be compile-time constants that don't touch the swizzling pattern. We implement a generic lowering that handles arbitrary layouts under these conditions. The conditions can be relaxed and generalised in the future if needed.

1 parent: 08973b1

File tree: 9 files changed, +195 −80 lines

include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td

Lines changed: 17 additions & 5 deletions
@@ -208,14 +208,26 @@ def TTG_MemDescSubviewOp : TTG_Op<"memdesc_subview", [Pure, MemDescViewTrait]> {

   let description = [{
     This operation returns a new descriptor representing a subview of the buffer.
-    It doesn't affect the underlying memory. The subview can be rank-reduced.
+    It doesn't affect the underlying memory.

     For example, suppose that
     - the input shape is 2x4x16xf16,
-    - the output shape is 4x4xf16, and
-    - offsets = [1, 0, 4].
-
-    Then in Python syntax, the subview covers input[1][0:4][4:8].
+    - the output shape is 4x16xf16, and
+    - offsets = [1, 0, 0].
+
+    Then in Python syntax, the subview covers input[1].
+
+    Just one dimension may be split (at most one non-zero offset).
+
+    When the input shape and the output shape have different rank,
+    or the output shape is a 1D tensor of 1 element:
+    - The rank of the output must be one smaller than that of the input.
+    - We assume the input is split along the 0th dimension.
+    - The offset along the 0th dimension may be a runtime value.
+    When the input and the output have the same rank:
+    - The offset must be a compile-time constant
+    - Larger or equal to the tile of the tensor (or zero)
+    - That does not split the input along the swizzling pattern (if any)
   }];
   let arguments = (
       ins TTG_MemDescType:$src, Variadic<I32>:$offsets);
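To make the two cases concrete, here is a rough sketch in ttg IR (the #shared/#smem attributes and SSA names are placeholders in the spirit of the tests below, not part of this diff):

    %c0 = arith.constant 0 : i32
    %c16 = arith.constant 16 : i32
    // Rank-reducing: the runtime index %i picks one 4x16 slice, i.e. input[%i].
    %slice = ttg.memdesc_subview %buf[%i, %c0, %c0] : !ttg.memdesc<2x4x16xf16, #shared, #smem, mutable> -> !ttg.memdesc<4x16xf16, #shared, #smem, mutable, 2x4x16>
    // Same rank: one dimension split by a compile-time constant that is a
    // multiple of the tile and doesn't cut through the swizzling pattern.
    %half = ttg.memdesc_subview %buf2[%c0, %c16] : !ttg.memdesc<64x64xf16, #shared, #smem, mutable> -> !ttg.memdesc<64x16xf16, #shared, #smem, mutable, 64x64>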

lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp

Lines changed: 30 additions & 1 deletion
@@ -3,6 +3,7 @@
 #include "triton/Conversion/TritonGPUToLLVM/Utility.h"
 #include "triton/Dialect/TritonGPU/IR/Attributes.h"
 #include "triton/Dialect/TritonGPU/IR/Types.h"
+#include "triton/Tools/LayoutUtils.h"

 using namespace mlir;
 using namespace mlir::triton;
@@ -421,6 +422,7 @@ struct MemDescSubviewOpConversion
   matchAndRewrite(triton::gpu::MemDescSubviewOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Location loc = op->getLoc();
+    auto *ctx = op->getContext();
     auto b = TritonLLVMOpBuilder(loc, rewriter);
     auto srcTy = op.getSrc().getType();
     auto destTy = op.getResult().getType();
@@ -433,15 +435,42 @@ struct MemDescSubviewOpConversion
                                 llvmElemTy, rewriter);
     auto smemStrides = smemObj.getStrides(srcTy, loc, rewriter);
     SmallVector<Value> opOffsetVals = op.getOffsets();
+    // We assume we always create a subview of the last dimensions
     SmallVector<Value> opSmemStrides(smemStrides.end() - opOffsetVals.size(),
                                      smemStrides.end());
+    // Compute total offset
     SmallVector<Value> offsetVals;
    auto destRank = op.getResult().getType().getRank();
    auto rankReduced = srcTy.getRank() - destRank;
    for (int i = rankReduced; i < opOffsetVals.size(); i++) {
      offsetVals.push_back(b.add(opOffsetVals[i], smemObj.getOffsets()[i]));
    }
-    Value offset = dot(rewriter, loc, opOffsetVals, opSmemStrides);
+
+    Value offset;
+    if (rankReduced || (destTy.getRank() == 1 && destTy.getDimSize(0) == 1)) {
+      // We are splitting the pipelining dimension which may not be a power
+      // of 2 so we can't use LinearLayouts
+      offset = dot(rewriter, loc, opOffsetVals, opSmemStrides);
+    } else {
+      auto dimNames = standardOutDimNames(ctx, opOffsetVals.size());
+      SmallVector<std::pair<StringAttr, Value>> logicalOffsets;
+      // This assumes the subviews are additive, in the sense that we can
+      // compute the offset of one and add it to the offset of the previous
+      // one we computed. We check for this in the verifier.
+      for (int i = 0; i < rankReduced; i++) {
+        logicalOffsets.push_back({dimNames[i], b.i32_val(0)});
+      }
+      for (int i = rankReduced; i < opOffsetVals.size(); i++) {
+        logicalOffsets.push_back({dimNames[i], offsetVals[i - rankReduced]});
+      }
+      // The order gives us the honest-to-goodness layout rank
+      auto srcAllocShape =
+          srcTy.getAllocShape().take_back(getOrder(srcTy).size());
+      auto llInv = toLinearLayout(srcAllocShape, srcTy.getEncoding()).invert();
+      offset =
+          applyLinearLayout(loc, rewriter, llInv, logicalOffsets)[0].second;
+    }
+
     auto base = smemObj.getBase();
     auto elemPtrTy = base.getType();
     smemObj = SharedMemoryObject(b.gep(elemPtrTy, llvmElemTy, base, offset),
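A standalone sketch (not the Triton API) of what the rank-reducing branch computes: the element offset is the dot product of the subview offsets and the shared-memory strides, which the `dot(rewriter, loc, opOffsetVals, opSmemStrides)` call above emits as LLVM mul/add instructions over i32 Values. Plain integers stand in for the SSA values:

    #include <cassert>
    #include <cstdint>
    #include <vector>

    // Dot product of logical offsets and strides: the linear element offset.
    int64_t subviewOffset(const std::vector<int64_t> &offsets,
                          const std::vector<int64_t> &strides) {
      assert(offsets.size() == strides.size());
      int64_t off = 0;
      for (size_t i = 0; i < offsets.size(); ++i)
        off += offsets[i] * strides[i];
      return off;
    }

    int main() {
      // A 2x4x16 buffer with row-major strides {64, 16, 1}: offsets [1, 0, 0]
      // select input[1], i.e. element offset 64. The front offset may be a
      // runtime value, which is why this path avoids LinearLayouts.
      assert(subviewOffset({1, 0, 0}, {64, 16, 1}) == 64);
      return 0;
    }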

lib/Dialect/TritonGPU/IR/Ops.cpp

Lines changed: 99 additions & 6 deletions
@@ -9,6 +9,7 @@
 #include "triton/Dialect/TritonGPU/IR/Types.h"
 #include "triton/Dialect/TritonGPU/Transforms/Utility.h"
 #include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
+#include "triton/Tools/LayoutUtils.h"
 #include "llvm/Support/Casting.h"
 #include "llvm/Support/LogicalResult.h"

@@ -600,15 +601,107 @@ LogicalResult MemDescSubviewOp::verify() {
             "offsets other than the first one must be constant zeros");
       }
     }
+    return success();
   }

-  // TODO(jlebar): Currently we generate illegal encodings, so we can't add a
-  // verifier for them. In particular, we use the same encoding for the src and
-  // dst of a subview op, when the subview removes a dimension. That generates
-  // an illegal shared encoding (because the size of `order` doesn't match the
-  // rank of the tensor), but it's not checked anywhere, and we believe the
-  // resulting code ultimately works.
+  assert(isa<SharedEncodingTrait>(srcEnc));

+  // corner case: 1D -> 1D into a 1 element tensor (we don't have 0D tensors)
+  if (srcTy.getRank() == 1 && dstTy.getRank() == 1 &&
+      dstTy.getDimSize(0) == 1) {
+    return success();
+  }
+
+  // There are two cases:
+  // 1. The subview is rank-reducing
+  //    - We split along the first dimension. It can be with non-constant
+  //      offsets
+  if (srcTy.getRank() != dstTy.getRank()) {
+    if (srcTy.getRank() - dstTy.getRank() != 1) {
+      return emitError(
+          "only nD -> (n-1)D rank-reducing subviews are supported");
+    }
+    for (auto offset : getOffsets().take_back(dstTy.getRank())) {
+      if (auto constOp = offset.getDefiningOp<arith::ConstantOp>()) {
+        if (auto offsetInt = dyn_cast<IntegerAttr>(constOp.getValue())) {
+          if (offsetInt.getInt() != 0) {
+            return emitError("only first offset can be non-zero for a "
+                             "rank-reducing subview");
+          }
+        } else {
+          return emitError(
+              "only integer constant values are allowed for the split");
+        }
+      } else {
+        return emitError("only constant values are allowed outside the front "
+                         "dimension in a rank-reducing subview");
+      }
+    }
+    return success();
+  }
+  assert(srcTy.getRank() == dstTy.getRank());
+  // 2. The subview is non-rank-reducing
+  //    - We split along at most one dim, but just with constant values
+  //    - The values where the split happens must not be within the swizzling
+  //      pattern
+  // Check which dimension we are splitting along
+  int dim = -1;
+  for (int i = 0; i < srcTy.getRank(); i++) {
+    if (srcTy.getDimSize(i) != dstTy.getDimSize(i)) {
+      if (dim != -1) {
+        return emitError(
+            "We don't allow subviews that split along multiple dimensions");
+      }
+      dim = i;
+    }
+  }
+  SmallVector<int64_t> offsets;
+  for (auto offset : getOffsets()) {
+    if (auto constOp = offset.getDefiningOp<arith::ConstantOp>()) {
+      if (auto offsetInt = dyn_cast<IntegerAttr>(constOp.getValue())) {
+        offsets.push_back(offsetInt.getInt());
+      } else {
+        return emitError(
+            "only integer constant values are allowed for the split");
+      }
+    } else {
+      return emitError("only constant values are allowed for the split");
+    }
+  }
+  // Identity subview
+  if (dim == -1) {
+    return success();
+  }
+
+  for (auto [i, offset] : llvm::enumerate(offsets)) {
+    if (i != dim) {
+      if (offset != 0) {
+        return emitError("A non-zero offset found in a dimension that is "
+                         "not being split");
+      }
+    } else {
+      if (offset & (dstTy.getDimSize(dim) - 1)) {
+        return emitError("The split offset may not touch the tile");
+      }
+    }
+  }
+  auto ctx = getContext();
+  // The order gives us the honest-to-goodness layout rank
+  auto srcAllocShape = srcTy.getAllocShape().take_back(getOrder(srcTy).size());
+  auto llInv =
+      triton::gpu::toLinearLayout(srcAllocShape, srcTy.getEncoding()).invert();
+  auto kDim = mlir::StringAttr::get(ctx, "dim" + llvm::Twine(dim));
+  llvm::SmallVector<std::pair<mlir::StringAttr, int32_t>> namedOffsets;
+  for (auto d : standardOutDimNames(ctx, srcTy.getRank())) {
+    namedOffsets.push_back({d, 0});
+  }
+  for (int dimSize = dstTy.getDimSize(dim); dimSize < srcTy.getDimSize(dim);
+       dimSize *= 2) {
+    namedOffsets[dim] = {kDim, dimSize};
+    if (!llvm::isPowerOf2_32(llInv.apply(namedOffsets)[0].second)) {
+      return emitError(
+          "We don't support splitting along the swizzling pattern");
+    }
+  }
   return success();
 }
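A sketch of the tile-alignment test the verifier applies to the split dimension, with plain integers instead of MLIR attributes. It assumes dstDimSize is a power of two (tile sizes in these layouts are), so `offset & (dstDimSize - 1)` equals `offset % dstDimSize`: a non-zero remainder means the subview would start in the middle of a tile.

    #include <cassert>
    #include <cstdint>

    // Mirrors the `offset & (dstTy.getDimSize(dim) - 1)` check above.
    bool splitOffsetOk(int64_t offset, int64_t dstDimSize) {
      return (offset & (dstDimSize - 1)) == 0;
    }

    int main() {
      assert(splitOffsetOk(0, 32));   // zero offset: identity-like, always fine
      assert(splitOffsetOk(64, 32));  // multiple of the tile: allowed
      assert(!splitOffsetOk(16, 32)); // lands mid-tile: rejected by the verifier
      return 0;
    }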

lib/Dialect/TritonGPU/Transforms/Pipeliner/WGMMAPipeline.cpp

Lines changed: 6 additions & 34 deletions
@@ -15,8 +15,6 @@
 #include "triton/Dialect/TritonGPU/Transforms/Utility.h"
 #include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
 #include "triton/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.h"
-#include "triton/Tools/LayoutUtils.h"
-#include "triton/Tools/LinearLayout.h"
 #include "llvm/ADT/MapVector.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SetVector.h"
@@ -287,50 +285,24 @@ SmallVector<Value> splitLhs(OpBuilder &builder,
 SmallVector<Value> splitRhs(OpBuilder &builder,
                             TypedValue<ttg::MemDescType> rhs, int64_t newK) {
   auto loc = rhs.getLoc();
-  auto *ctx = builder.getContext();
   auto type = rhs.getType();
   auto rank = type.getRank();
   auto kDim = rank - 2;
   auto nSplits = type.getShape()[kDim] / newK;
-  // offset -> matrix
-  auto ll = ttg::toLinearLayout(type.getShape(), type.getEncoding());
-  auto llInv = ll.invert();
-
-  // Split into
-  auto kOffset = StringAttr::get(ctx, "offset");
-  assert(llInv.getOutDimSize(kOffset) == product(type.getShape()));
-  auto dimNames = tt::standardOutDimNames(ctx, rank);
-  SmallVector<std::pair<StringAttr, int32_t>> newOutDims;
-  for (auto d : getOrder(type)) {
-    newOutDims.push_back({dimNames[d], type.getShape()[d]});
-  }
-  // Split into shmem shape and invert
-  llInv = llInv.reshapeOuts(newOutDims);
-  llInv = llInv.transposeOuts(dimNames);
-  auto toOffsets = [&](const SmallVector<std::pair<StringAttr, int32_t>>
-                           &shape) {
-    return llvm::to_vector(
-        llvm::map_range(llvm::make_second_range(shape), [&](int32_t v) {
-          return builder.create<arith::ConstantIntOp>(loc, v, 32).getResult();
-        }));
-  };
-  // New Shape
   auto shape = llvm::to_vector(type.getShape());
   shape[kDim] = newK;
+  SmallVector<Value> offsetsVal;
+  for (int i = 0; i < rank; i++) {
+    offsetsVal.push_back(builder.create<arith::ConstantIntOp>(loc, 0, 32));
+  }
   auto newType = ttg::MemDescType::get(
       shape, type.getElementType(), type.getEncoding(), type.getMemorySpace(),
       /*isMutable=*/false, type.getAllocShape());
   SmallVector<Value> ret;
-  SmallVector<std::pair<StringAttr, int32_t>> logicalOffsets;
-  for (int i = 0; i < rank; i++) {
-    logicalOffsets.push_back({StringAttr::get(ctx, "dim" + Twine(i)), 0});
-  }
   for (int i = 0; i < nSplits; i++) {
-    logicalOffsets[kDim].second = i * newK;
-    auto shmemOffsets = toOffsets(llInv.apply(logicalOffsets));
-
+    offsetsVal[kDim] = builder.create<arith::ConstantIntOp>(loc, i * newK, 32);
     Value newSmem = builder.create<triton::gpu::MemDescSubviewOp>(
-        loc, newType, rhs, shmemOffsets);
+        loc, newType, rhs, offsetsVal);
     ret.push_back(newSmem);
   }
   return ret;
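After this simplification, splitRhs builds constant offsets directly and relies on the generic subview lowering to invert the swizzled layout. Roughly, for a hypothetical 128x64 rhs split into two K=64 halves (shapes and attribute names are placeholders, not from this diff), it now emits:

    %c0 = arith.constant 0 : i32
    %half0 = ttg.memdesc_subview %rhs[%c0, %c0] : !ttg.memdesc<128x64xf16, #shared, #smem, 128x64> -> !ttg.memdesc<64x64xf16, #shared, #smem, 128x64>
    %c64 = arith.constant 64 : i32
    %half1 = ttg.memdesc_subview %rhs[%c64, %c0] : !ttg.memdesc<128x64xf16, #shared, #smem, 128x64> -> !ttg.memdesc<64x64xf16, #shared, #smem, 128x64>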

test/Analysis/test-alias.mlir

Lines changed: 2 additions & 1 deletion
@@ -116,9 +116,10 @@ tt.func @for_if(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B :
   %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) ->
     (!ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable>) {
     scf.if %i1 {
+      %zero = arith.constant 0 : i32
       %index = arith.constant 8 : i32
       // expected-remark @below {{%4 -> %0,%1}}
-      %cst0 = ttg.memdesc_subview %a_shared[%index, %index] : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> -> !ttg.memdesc<32xf16, #A_SHARED, #ttg.shared_memory, mutable>
+      %cst0 = ttg.memdesc_subview %a_shared[%index, %zero] : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> -> !ttg.memdesc<32xf16, #A_SHARED, #ttg.shared_memory, mutable>
       scf.yield
     }
     scf.yield %b_shared, %a_shared, %a_shared : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable>

test/Analysis/test-allocation.mlir

Lines changed: 2 additions & 1 deletion
@@ -440,8 +440,9 @@ tt.func @for_if_slice(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>
   %c_shared_init = ttg.local_alloc : () -> !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable>
   %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (!ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable>) {
   scf.if %i1 {
+    %zero = arith.constant 0 : i32
     %index = arith.constant 8 : i32
-    %cst0 = ttg.memdesc_subview %a_shared[%index, %index] : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> -> !ttg.memdesc<32xf16, #A_SHARED, #ttg.shared_memory, mutable>
+    %cst0 = ttg.memdesc_subview %a_shared[%index, %zero] : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable> -> !ttg.memdesc<32xf16, #A_SHARED, #ttg.shared_memory, mutable>
     scf.yield
   }
   scf.yield %b_shared, %a_shared, %a_shared : !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable>, !ttg.memdesc<128x32xf16, #A_SHARED, #ttg.shared_memory, mutable>

test/Conversion/amd/tritongpu_to_llvm.mlir

Lines changed: 3 additions & 3 deletions
@@ -373,11 +373,11 @@ module attributes {"ttg.target" = "hip:gfx942", "ttg.num-ctas" = 1 : i32, "ttg.n
   %c16_i32 = arith.constant 16 : i32
   // CHECK-COUNT-16: llvm.store {{.*}} : vector<1xf16>, !llvm.ptr<3>
   %0 = ttg.local_alloc %arg0 : (tensor<64x64xf16, #blocked>) -> !ttg.memdesc<64x64xf16, #shared, #smem, mutable>
-  %1 = ttg.memdesc_subview %0[%c16_i32, %c0_i32] : !ttg.memdesc<64x64xf16, #shared, #smem, mutable> -> !ttg.memdesc<16x64xf16, #shared, #smem, mutable, 64x64>
+  %1 = ttg.memdesc_subview %0[%c0_i32, %c16_i32] : !ttg.memdesc<64x64xf16, #shared, #smem, mutable> -> !ttg.memdesc<64x16xf16, #shared, #smem, mutable, 64x64>
   // CHECK-COUNT-4: llvm.load {{.*}} : !llvm.ptr<3> -> vector<1xf16>
-  %2 = ttg.local_load %1 : !ttg.memdesc<16x64xf16, #shared, #smem, mutable, 64x64> -> tensor<16x64xf16, #blocked>
+  %2 = ttg.local_load %1 : !ttg.memdesc<64x16xf16, #shared, #smem, mutable, 64x64> -> tensor<64x16xf16, #blocked>
   // CHECK-COUNT-4: llvm.store {{.*}} : vector<1xf16>, !llvm.ptr<3>
-  ttg.local_store %2, %1 : tensor<16x64xf16, #blocked> -> !ttg.memdesc<16x64xf16, #shared, #smem, mutable, 64x64>
+  ttg.local_store %2, %1 : tensor<64x16xf16, #blocked> -> !ttg.memdesc<64x16xf16, #shared, #smem, mutable, 64x64>
   tt.return
   }
 }

test/Conversion/tritongpu_to_llvm.mlir

Lines changed: 2 additions & 29 deletions
@@ -549,8 +549,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
 #smem = #ttg.shared_memory
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
   // CHECK: llvm.mlir.global external @global_smem
-  // CHECK-LABEL: basic_subview
-  tt.func @basic_subview() {
+  // CHECK-LABEL: rank_reducing_subview
+  tt.func @rank_reducing_subview() {
     // CHECK: llvm.mlir.addressof @global_smem
     // CHECK: llvm.extractvalue
     // CHECK-NEXT: llvm.extractvalue
@@ -579,33 +579,6 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {

 // -----

-#shared0 = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 32}>
-#smem = #ttg.shared_memory
-module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
-  // CHECK: llvm.mlir.global external @global_smem
-  // CHECK-LABEL: nvmma_subview
-  tt.func @nvmma_subview() {
-    // CHECK: llvm.mlir.addressof @global_smem
-    // CHECK: llvm.mlir.constant(1 : i32) : i32
-    // CHECK-NEXT: llvm.mlir.constant(128 : i32) : i32
-    // CHECK-NEXT: llvm.add
-    // CHECK-NEXT: llvm.add
-    // CHECK-NEXT: llvm.mlir.constant(0 : i32) : i32
-    // CHECK-NEXT: llvm.mul
-    // CHECK-NEXT: llvm.add
-    // CHECK-NEXT: llvm.mul
-    // CHECK-NEXT: llvm.add
-    // CHECK-NEXT: llvm.getelementptr
-    %index = arith.constant 1 : i32
-    %zero = arith.constant 0 : i32
-    %0 = ttg.local_alloc : () -> !ttg.memdesc<16x128xf32, #shared0, #smem, mutable>
-    %1 = ttg.memdesc_subview %0[%zero, %zero] : !ttg.memdesc<16x128xf32, #shared0, #smem, mutable> -> !ttg.memdesc<16x32xf32, #shared0, #smem, mutable>
-    tt.return
-  }
-}
-
-// -----
-
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
   // CHECK-LABEL: basic_async_wait
   tt.func @basic_async_wait() {
