
Commit 7a156e7

Merge commit '0f1e09e308fa71544dd833f768305425c9f2c383'
2 parents: af5a09a + 0f1e09e

57 files changed (+5621 additions, -604 deletions)


.github/workflows/integration-tests-amd.yml

Lines changed: 1 addition & 1 deletion
@@ -109,7 +109,7 @@ jobs:
  echo "Could not find '${INSTRUMENTATION_LIB_DIR}'" ; exit -1
  fi
  pytest --capture=tee-sys -rfs python/tutorials/06-fused-attention.py
- pytest --capture=tee-sys -rfs third_party/amd/python/test/test_extract_slice.py
+ pytest --capture=tee-sys -rfs third_party/amd/python/test/test_extract_slice_concat_op.py
  TRITON_ALWAYS_COMPILE=1 pytest --capture=tee-sys -rfs third_party/amd/python/test/test_scalarize_packed_fops.py
  cd python/test/unit
  pytest --capture=tee-sys -rfs -n 12 language runtime \

Makefile

Lines changed: 1 addition & 1 deletion
@@ -37,7 +37,7 @@ test-unit: all
  $(PYTEST) -s -n 8 python/triton_kernels/tests/
  TRITON_DISABLE_LINE_INFO=0 $(PYTEST) -s python/test/unit/language/test_line_info.py
  # Run attention separately to avoid out of gpu memory
- TRITON_PRINT_AUTOTUNING=1 $(PYTEST) -vs python/tutorials/06-fused-attention.py
+ $(PYTEST) -vs python/tutorials/06-fused-attention.py
  TRITON_ALWAYS_COMPILE=1 TRITON_DISABLE_LINE_INFO=0 LLVM_PASS_PLUGIN_PATH=python/triton/instrumentation/libGPUInstrumentationTestLib.so \
  $(PYTEST) --capture=tee-sys -rfs -vvv python/test/unit/instrumentation/test_gpuhello.py

include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td

Lines changed: 17 additions & 5 deletions
@@ -208,14 +208,26 @@ def TTG_MemDescSubviewOp : TTG_Op<"memdesc_subview", [Pure, MemDescViewTrait]> {

  let description = [{
  This operation returns a new descriptor representing a subview of the buffer.
- It doesn't affect the underlying memory. The subview can be rank-reduced.
+ It doesn't affect the underlying memory.

  For example, suppose that
  - the input shape is 2x4x16xf16,
- - the output shape is 4x4xf16, and
- - offsets = [1, 0, 4].
-
- Then in Python syntax, the subview covers input[1][0:4][4:8].
+ - the output shape is 4x16xf16, and
+ - offsets = [1, 0, 0].
+
+ Then in Python syntax, the subview covers input[1].
+
+ Just one dimension may be split (at most one non-zero offset).
+
+ When the input shape and the output shape have different rank:
+ Or the output shape is a tensor of 1D tensor of 1 element:
+ - The rank of the output must be 1D smaller than the input.
+ - We assume the input is split along the 0th dimension.
+ - The offset along the 0th dimension may be a runtime value.
+ When the input and the output have the same rank:
+ - The offset must be a compile-time constant
+ - Larger or equal to the tile of the tensor (or zero)
+ - That does not split the input along the swizzling pattern (if any)
  }];
  let arguments = (
  ins TTG_MemDescType:$src, Variadic<I32>:$offsets);
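
The documented example is easiest to see with a host-side analogy. Below is a minimal NumPy sketch (illustration only, not the ttg.memdesc_subview op or Triton API): a rank-reducing subview of a 2x4x16 buffer at offsets [1, 0, 0] covers input[1] and only changes the descriptor, never the underlying memory.

```python
import numpy as np

# Host-side analogy of the documented example: a 2x4x16 buffer viewed at
# offsets [1, 0, 0], producing a rank-reduced 4x16 result.
buf = np.zeros((2, 4, 16), dtype=np.float16)

offsets = (1, 0, 0)
subview = buf[offsets[0]]              # "input[1]" from the description

assert subview.shape == (4, 16)
assert np.shares_memory(subview, buf)  # same memory, only the descriptor changes
```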

lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp

Lines changed: 30 additions & 1 deletion
@@ -3,6 +3,7 @@
  #include "triton/Conversion/TritonGPUToLLVM/Utility.h"
  #include "triton/Dialect/TritonGPU/IR/Attributes.h"
  #include "triton/Dialect/TritonGPU/IR/Types.h"
+ #include "triton/Tools/LayoutUtils.h"

  using namespace mlir;
  using namespace mlir::triton;
@@ -421,6 +422,7 @@ struct MemDescSubviewOpConversion
  matchAndRewrite(triton::gpu::MemDescSubviewOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
+   auto *ctx = op->getContext();
    auto b = TritonLLVMOpBuilder(loc, rewriter);
    auto srcTy = op.getSrc().getType();
    auto destTy = op.getResult().getType();
@@ -433,15 +435,42 @@
                                 llvmElemTy, rewriter);
    auto smemStrides = smemObj.getStrides(srcTy, loc, rewriter);
    SmallVector<Value> opOffsetVals = op.getOffsets();
+   // We assume we always create a subview of the last dimensions
    SmallVector<Value> opSmemStrides(smemStrides.end() - opOffsetVals.size(),
                                     smemStrides.end());
+   // Compute total offset
    SmallVector<Value> offsetVals;
    auto destRank = op.getResult().getType().getRank();
    auto rankReduced = srcTy.getRank() - destRank;
    for (int i = rankReduced; i < opOffsetVals.size(); i++) {
      offsetVals.push_back(b.add(opOffsetVals[i], smemObj.getOffsets()[i]));
    }
-   Value offset = dot(rewriter, loc, opOffsetVals, opSmemStrides);
+
+   Value offset;
+   if (rankReduced || (destTy.getRank() == 1 && destTy.getDimSize(0) == 1)) {
+     // We are splitting the pipelining dimension which may not be a power of 2
+     // so we can't use LinearLayouts
+     offset = dot(rewriter, loc, opOffsetVals, opSmemStrides);
+   } else {
+     auto dimNames = standardOutDimNames(ctx, opOffsetVals.size());
+     SmallVector<std::pair<StringAttr, Value>> logicalOffsets;
+     // This assumes the subviews are additive, in the sense that we can
+     // compute the offset of one and an add it to the offset of the previous
+     // one we computed. We check for this in the verifier.
+     for (int i = 0; i < rankReduced; i++) {
+       logicalOffsets.push_back({dimNames[i], b.i32_val(0)});
+     }
+     for (int i = rankReduced; i < opOffsetVals.size(); i++) {
+       logicalOffsets.push_back({dimNames[i], offsetVals[i - rankReduced]});
+     }
+     // The order gives us the honest-to-goodness layout rank
+     auto srcAllocShape =
+         srcTy.getAllocShape().take_back(getOrder(srcTy).size());
+     auto llInv = toLinearLayout(srcAllocShape, srcTy.getEncoding()).invert();
+     offset =
+         applyLinearLayout(loc, rewriter, llInv, logicalOffsets)[0].second;
+   }
+
    auto base = smemObj.getBase();
    auto elemPtrTy = base.getType();
    smemObj = SharedMemoryObject(b.gep(elemPtrTy, llvmElemTy, base, offset),
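
In the rank-reducing (and 1-element) case the lowering keeps the old behaviour: the offset is a plain dot product of the subview offsets with the shared-memory strides, since the pipelining dimension may not be a power of two. The same-rank case instead goes through the inverted linear layout. A minimal sketch of the dot-product arithmetic, assuming a row-major 2x4x16 buffer (plain Python, not the actual lowering):

```python
# Element offset of a subview as offsets . strides, mirroring the
# dot(rewriter, loc, opOffsetVals, opSmemStrides) fallback above.
def linear_offset(offsets, strides):
    return sum(o * s for o, s in zip(offsets, strides))

shape = (2, 4, 16)
strides = (4 * 16, 16, 1)                        # row-major strides for `shape`
assert linear_offset((1, 0, 0), strides) == 64   # start of input[1]
```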

lib/Dialect/TritonGPU/IR/Ops.cpp

Lines changed: 99 additions & 6 deletions
@@ -10,6 +10,7 @@
  #include "triton/Dialect/TritonGPU/IR/Types.h"
  #include "triton/Dialect/TritonGPU/Transforms/Utility.h"
  #include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
+ #include "triton/Tools/LayoutUtils.h"
  #include "llvm/Support/Casting.h"
  #include "llvm/Support/LogicalResult.h"

@@ -602,15 +603,107 @@ LogicalResult MemDescSubviewOp::verify() {
            "offsets other than the first one must be constant zeros");
      }
    }
+   return success();
  }

- // TODO(jlebar): Currently we generate illegal encodings, so we can't add a
- // verifier for them. In particular, we use the same encoding for the src and
- // dst of a subview op, when the subview removes a dimension. That generates
- // an illegal shared encoding (because the size of `order` doesn't match the
- // rank of the tensor), but it's not checked anywhere, and we believe the
- // resulting code ultimately works.
+ assert(isa<SharedEncodingTrait>(srcEnc));

+ // corner case: 1D -> 1D into a 1 element tensor (we don't have 0D tensors)
+ if (srcTy.getRank() == 1 && dstTy.getRank() == 1 &&
+     dstTy.getDimSize(0) == 1) {
+   return success();
+ }
+
+ // There are two cases:
+ // 1. The subview is rank-reducing
+ //    - We split along the first dimension. It can be with non-constant offsets
+ if (srcTy.getRank() != dstTy.getRank()) {
+   if (srcTy.getRank() - dstTy.getRank() != 1) {
+     return emitError(
+         "only nD -> (n-1)D rank-reducing subviews are supported");
+   }
+   for (auto offset : getOffsets().take_back(dstTy.getRank())) {
+     if (auto constOp = offset.getDefiningOp<arith::ConstantOp>()) {
+       if (auto offsetInt = dyn_cast<IntegerAttr>(constOp.getValue())) {
+         if (offsetInt.getInt() != 0) {
+           return emitError("only first offset can be non-zero for a "
+                            "rank-reducing subview");
+         }
+       } else {
+         return emitError(
+             "only integer constant values are allowed for the split");
+       }
+     } else {
+       return emitError("only constant values are allowed outside the front "
+                        "dimension in a rank-reducing subview");
+     }
+   }
+   return success();
+ }
+ assert(srcTy.getRank() == dstTy.getRank());
+ // 2. The src is non-rank-reducing
+ //    - We split along at most one dim, but just with constant values
+ //    - The values where the split happens must not be within the swizzling
+ //      pattern
+ // Check which dimension we are splitting along
+ int dim = -1;
+ for (int i = 0; i < srcTy.getRank(); i++) {
+   if (srcTy.getDimSize(i) != dstTy.getDimSize(i)) {
+     if (dim != -1) {
+       return emitError(
+           "We don't allow subviews that split along multiple dimensions");
+     }
+     dim = i;
+   }
+ }
+ SmallVector<int64_t> offsets;
+ for (auto offset : getOffsets()) {
+   if (auto constOp = offset.getDefiningOp<arith::ConstantOp>()) {
+     if (auto offsetInt = dyn_cast<IntegerAttr>(constOp.getValue())) {
+       offsets.push_back(offsetInt.getInt());
+     } else {
+       return emitError(
+           "only integer constant values are allowed for the split");
+     }
+   } else {
+     return emitError("only constant values are allowed for the split");
+   }
+ }
+ // Identity subview
+ if (dim == -1) {
+   return success();
+ }
+
+ for (auto [i, offset] : llvm::enumerate(offsets)) {
+   if (i != dim) {
+     if (offset != 0) {
+       return emitError("A non zero offset found in a dimension that is "
+                        "not being split");
+     }
+   } else {
+     if (offset & (dstTy.getDimSize(dim) - 1)) {
+       return emitError("The split offset may not touch the tile");
+     }
+   }
+ }
+ auto ctx = getContext();
+ // The order gives us the honest-to-goodness layout rank
+ auto srcAllocShape = srcTy.getAllocShape().take_back(getOrder(srcTy).size());
+ auto llInv =
+     triton::gpu::toLinearLayout(srcAllocShape, srcTy.getEncoding()).invert();
+ auto kDim = mlir::StringAttr::get(ctx, "dim" + llvm::Twine(dim));
+ llvm::SmallVector<std::pair<mlir::StringAttr, int32_t>> namedOffsets;
+ for (auto d : standardOutDimNames(ctx, srcTy.getRank())) {
+   namedOffsets.push_back({d, 0});
+ }
+ for (int dimSize = dstTy.getDimSize(dim); dimSize < srcTy.getDimSize(dim);
+      dimSize *= 2) {
+   namedOffsets[dim] = {kDim, dimSize};
+   if (!llvm::isPowerOf2_32(llInv.apply(namedOffsets)[0].second)) {
+     return emitError(
+         "We don't support splitting along the swizzling pattern");
+   }
+ }
  return success();
 }
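
The last loop in the new verifier walks the power-of-two offsets between the destination extent and the source extent along the split dimension and requires each of them to land on a power-of-two linear offset under the inverted layout; otherwise the split would cut through the swizzle pattern. A toy Python model of that check, where `linear_offset_of` is a hypothetical stand-in for applying `llInv` to an offset along the split dimension:

```python
# Toy model of the swizzle check above (not the MLIR verifier itself).
def split_respects_swizzle(dst_dim, src_dim, linear_offset_of):
    size = dst_dim
    while size < src_dim:
        off = linear_offset_of(size)
        if off <= 0 or off & (off - 1):      # must be a power of two
            return False
        size *= 2
    return True

# For an unswizzled row-major 16x8 buffer, splitting dim 0 into 8-row tiles is
# fine: offset 8 along dim 0 maps to linear offset 64, a power of two.
assert split_respects_swizzle(8, 16, lambda rows: rows * 8)
```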

lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp

Lines changed: 11 additions & 4 deletions
@@ -188,10 +188,12 @@ getSharedMemoryMMAOperand(Value v, mlir::PatternRewriter &rewriter, int opIdx,
  if (newOrder != order && op) {
    op->emitWarning("Warning: Forcing a different order [")
        << newOrder[0] << ", " << newOrder[1]
-       << "] on SMEM than the register order for the opreand " << opIdx
+       << "] on SMEM than the register order for the operand " << opIdx
        << ". Registers will be transposed before SMEM store and the pipelined "
           "load for this operand will be disabled, so poor performance is "
-          "expected.";
+          "expected. Recommendation: consider transposing the operand in "
+          "global "
+          "memory to remove the need to transpose the tensor in registers.";
  }

  Attribute SharedMemorySpace =
@@ -391,9 +393,14 @@ class BlockedToMMA : public mlir::OpRewritePattern<DotOp> {
      int bitwidth = getElementTypeOrSelf(a).getIntOrFloatBitWidth();
      a = getDotOperand(a, 0, bitwidth);
    } else {
-     a = getSharedMemoryMMAOperand(a, rewriter, 0, allowTranspose);
+     a = getSharedMemoryMMAOperand(a, rewriter, 0, allowTranspose,
+                                   /*isMMAv5Fp4Padded=*/false,
+                                   /*forceTranspose=*/false, dotOp);
    }
-   b = getSharedMemoryMMAOperand(b, rewriter, 1, allowTranspose);
+   b = getSharedMemoryMMAOperand(b, rewriter, 1, allowTranspose,
+                                 /*isMMAv5Fp4Padded=*/false,
+                                 /*forceTranspose=*/false, dotOp);
+
    newDot = rewriter.create<triton::nvidia_gpu::WarpGroupDotOp>(
        dotOp.getLoc(), newRetType, a, b, newAcc, nullptr,
        dotOp.getInputPrecision(), dotOp.getMaxNumImpreciseAcc(), false);

lib/Dialect/TritonGPU/Transforms/Pipeliner/WGMMAPipeline.cpp

Lines changed: 6 additions & 34 deletions
@@ -15,8 +15,6 @@
  #include "triton/Dialect/TritonGPU/Transforms/Utility.h"
  #include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
  #include "triton/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.h"
- #include "triton/Tools/LayoutUtils.h"
- #include "triton/Tools/LinearLayout.h"
  #include "llvm/ADT/MapVector.h"
  #include "llvm/ADT/STLExtras.h"
  #include "llvm/ADT/SetVector.h"
@@ -287,50 +285,24 @@ SmallVector<Value> splitLhs(OpBuilder &builder,
 SmallVector<Value> splitRhs(OpBuilder &builder,
                             TypedValue<ttg::MemDescType> rhs, int64_t newK) {
   auto loc = rhs.getLoc();
-  auto *ctx = builder.getContext();
   auto type = rhs.getType();
   auto rank = type.getRank();
   auto kDim = rank - 2;
   auto nSplits = type.getShape()[kDim] / newK;
-  // offset -> matrix
-  auto ll = ttg::toLinearLayout(type.getShape(), type.getEncoding());
-  auto llInv = ll.invert();
-
-  // Split into
-  auto kOffset = StringAttr::get(ctx, "offset");
-  assert(llInv.getOutDimSize(kOffset) == product(type.getShape()));
-  auto dimNames = tt::standardOutDimNames(ctx, rank);
-  SmallVector<std::pair<StringAttr, int32_t>> newOutDims;
-  for (auto d : getOrder(type)) {
-    newOutDims.push_back({dimNames[d], type.getShape()[d]});
-  }
-  // Split into shmem shape and invert
-  llInv = llInv.reshapeOuts(newOutDims);
-  llInv = llInv.transposeOuts(dimNames);
-  auto toOffsets = [&](const SmallVector<std::pair<StringAttr, int32_t>>
-                           &shape) {
-    return llvm::to_vector(
-        llvm::map_range(llvm::make_second_range(shape), [&](int32_t v) {
-          return builder.create<arith::ConstantIntOp>(loc, v, 32).getResult();
-        }));
-  };
-  // New Shape
   auto shape = llvm::to_vector(type.getShape());
   shape[kDim] = newK;
+  SmallVector<Value> offsetsVal;
+  for (int i = 0; i < rank; i++) {
+    offsetsVal.push_back(builder.create<arith::ConstantIntOp>(loc, 0, 32));
+  }
   auto newType = ttg::MemDescType::get(
       shape, type.getElementType(), type.getEncoding(), type.getMemorySpace(),
       /*isMutable=*/false, type.getAllocShape());
   SmallVector<Value> ret;
-  SmallVector<std::pair<StringAttr, int32_t>> logicalOffsets;
-  for (int i = 0; i < rank; i++) {
-    logicalOffsets.push_back({StringAttr::get(ctx, "dim" + Twine(i)), 0});
-  }
   for (int i = 0; i < nSplits; i++) {
-    logicalOffsets[kDim].second = i * newK;
-    auto shmemOffsets = toOffsets(llInv.apply(logicalOffsets));
-
+    offsetsVal[kDim] = builder.create<arith::ConstantIntOp>(loc, i * newK, 32);
    Value newSmem = builder.create<triton::gpu::MemDescSubviewOp>(
-        loc, newType, rhs, shmemOffsets);
+        loc, newType, rhs, offsetsVal);
    ret.push_back(newSmem);
  }
  return ret;
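
With this change splitRhs no longer applies the inverted linear layout itself: it simply emits logical offsets that are zero everywhere except i * newK along the K dimension and lets MemDescSubviewOp (with its new verifier and lowering) resolve the shared-memory address. A small Python sketch of the offsets it now produces, assuming a 2-D [K, N] operand; the helper name is illustrative, not part of the Triton codebase:

```python
# Logical offsets for splitting the K dimension into K/newK subviews.
def split_k_offsets(shape, new_k):
    rank = len(shape)
    k_dim = rank - 2                      # K dim, as in `kDim = rank - 2` above
    n_splits = shape[k_dim] // new_k
    return [[i * new_k if d == k_dim else 0 for d in range(rank)]
            for i in range(n_splits)]

# Splitting K = 64 into chunks of 16 yields four subviews at K offsets 0..48;
# MemDescSubviewOp turns each logical offset into the shared-memory address.
assert split_k_offsets((64, 128), new_k=16) == [[0, 0], [16, 0], [32, 0], [48, 0]]
```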

python/test/unit/language/test_matmul.py

Lines changed: 2 additions & 1 deletion
@@ -546,7 +546,8 @@ def flatten_scale(scale):
      print(f"SWP failed for M = {M}, N = {N}")


- @pytest.mark.parametrize("BLOCK_M, BLOCK_N, BLOCK_K", [(128, 128, 64), (128, 64, 128), (64, 128, 32), (128, 256, 32)])
+ @pytest.mark.parametrize("BLOCK_M, BLOCK_N, BLOCK_K", [(128, 128, 64), (128, 64, 128), (64, 128, 32), (128, 256, 32),
+                          (256, 64, 32)])
  @pytest.mark.parametrize("a_trans", [False, True])
  @pytest.mark.parametrize("dtype_src_str", ["float32", "float16", "float8e5"])
  @pytest.mark.skipif(is_hip() or (is_cuda() and torch.cuda.get_device_capability()[0] != 10),
