Skip to content

Commit d9b6052

Browse files
Merge OpenAI Triton commit 607c50c (#4284)
This PR changes the Triton base from e6b9efd to 607c50c (May 18). Pass rate: 95.34%->93.75% (#4289)
2 parents 0c81300 + 2db25c5 commit d9b6052

File tree

117 files changed

+9378
-2840
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

117 files changed

+9378
-2840
lines changed

.github/workflows/integration-tests-amd.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ jobs:
109109
echo "Could not find '${INSTRUMENTATION_LIB_DIR}'" ; exit -1
110110
fi
111111
pytest --capture=tee-sys -rfs python/tutorials/06-fused-attention.py
112-
pytest --capture=tee-sys -rfs third_party/amd/python/test/test_extract_slice.py
112+
pytest --capture=tee-sys -rfs third_party/amd/python/test/test_extract_slice_concat_op.py
113113
TRITON_ALWAYS_COMPILE=1 pytest --capture=tee-sys -rfs third_party/amd/python/test/test_scalarize_packed_fops.py
114114
cd python/test/unit
115115
pytest --capture=tee-sys -rfs -n 12 language runtime \

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ llvm-project-*/
1010
dist/
1111
triton*.egg-info/
1212
*.whl
13+
python/triton_kernels/triton*.egg-info/
1314

1415
python/triton/_C/*.pyd
1516
python/triton/_C/*.so

Makefile

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,8 @@ test-unit: all
3636
$(PYTEST) -s -n 8 python/test/unit/test_debug.py --forked
3737
$(PYTEST) -s -n 8 python/triton_kernels/tests/
3838
TRITON_DISABLE_LINE_INFO=0 $(PYTEST) -s python/test/unit/language/test_line_info.py
39-
# Run cuda/test_flashattention.py separately to avoid out of gpu memory
40-
$(PYTEST) -s python/test/unit/cuda/test_flashattention.py
39+
# Run attention separately to avoid out of gpu memory
40+
$(PYTEST) -vs python/tutorials/06-fused-attention.py
4141
TRITON_ALWAYS_COMPILE=1 TRITON_DISABLE_LINE_INFO=0 LLVM_PASS_PLUGIN_PATH=python/triton/instrumentation/libGPUInstrumentationTestLib.so \
4242
$(PYTEST) --capture=tee-sys -rfs -vvv python/test/unit/instrumentation/test_gpuhello.py
4343

bin/RegisterTritonDialects.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
6161
mlir::registerAllPasses();
6262
mlir::triton::registerTritonPasses();
6363
mlir::triton::gpu::registerTritonGPUPasses();
64-
mlir::registerTritonNvidiaGPUPasses();
64+
mlir::triton::nvidia_gpu::registerTritonNvidiaGPUPasses();
6565
mlir::test::intel::registerTestAxisInfoPass();
6666
mlir::test::registerTestAliasPass();
6767
mlir::test::registerTestAlignmentPass();

include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -208,14 +208,26 @@ def TTG_MemDescSubviewOp : TTG_Op<"memdesc_subview", [Pure, MemDescViewTrait]> {
208208

209209
let description = [{
210210
This operation returns a new descriptor representing a subview of the buffer.
211-
It doesn't affect the underlying memory. The subview can be rank-reduced.
211+
It doesn't affect the underlying memory.
212212

213213
For example, suppose that
214214
- the input shape is 2x4x16xf16,
215-
- the output shape is 4x4xf16, and
216-
- offsets = [1, 0, 4].
217-
218-
Then in Python syntax, the subview covers input[1][0:4][4:8].
215+
- the output shape is 4x16xf16, and
216+
- offsets = [1, 0, 0].
217+
218+
Then in Python syntax, the subview covers input[1].
219+
220+
Just one dimension may be split (at most one non-zero offset).
221+
222+
When the input shape and the output shape have different rank:
223+
Or the output is a 1D tensor of 1 element:
224+
- The rank of the output must be one less than the rank of the input.
225+
- We assume the input is split along the 0th dimension.
226+
- The offset along the 0th dimension may be a runtime value.
227+
When the input and the output have the same rank:
228+
- The offset must be a compile-time constant
229+
- Larger or equal to the tile of the tensor (or zero)
230+
- That does not split the input along the swizzling pattern (if any)
219231
}];
220232
let arguments = (
221233
ins TTG_MemDescType:$src, Variadic<I32>:$offsets);

include/triton/Dialect/TritonGPU/Transforms/Passes.td

Lines changed: 10 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,16 @@ def TritonGPUOptimizePartitionWarps : Pass<"tritongpu-optimize-partition-warps",
165165
}];
166166
}
167167

168+
def TritonGPUPartitionScheduling : Pass<"tritongpu-partition-scheduling", "mlir::ModuleOp"> {
169+
let summary = "warp specialization partitioning pass";
170+
171+
let description = [{
172+
The `tritongpu-partition-scheduling` analyzes the loads, MMAs, and other
173+
operations in a loop that is meant to be warp specialized and determines
174+
which partitions to assign to each operation.
175+
}];
176+
}
177+
168178
def TritonGPULoadMMASpecialization : Pass<"tritongpu-load-mma-specialization", "mlir::ModuleOp"> {
169179
let summary = "load MMA specialization";
170180

@@ -219,23 +229,6 @@ def TritonGPUPrefetch : Pass<"tritongpu-prefetch", "mlir::ModuleOp"> {
219229
"mlir::arith::ArithDialect"];
220230
}
221231

222-
def TritonGPUWGMMAPrefetch : Pass<"tritongpu-wgmma-prefetch", "mlir::ModuleOp"> {
223-
let summary = "prefetch for wgmma mixed precision";
224-
225-
let description = [{
226-
This pass attempts to prefetch from shared memory for mixed-precision
227-
wgmma when operand A is in the shared memory and needs to be loaded
228-
to the local registers.
229-
}];
230-
231-
let dependentDialects = [ "mlir::triton::gpu::TritonGPUDialect",
232-
"mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect",
233-
"mlir::scf::SCFDialect",
234-
"mlir::arith::ArithDialect"];
235-
}
236-
237-
238-
239232
def TritonGPUAccelerateMatmul : Pass<"tritongpu-accelerate-matmul", "mlir::ModuleOp"> {
240233
let summary = "accelerate matmul";
241234

include/triton/Dialect/TritonGPU/Transforms/Utility.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,10 @@ getNumElementsPerThread(Operation *op, SmallVector<unsigned> order,
5454
// Returns whether the op is a "view op", i.e. doesn't move any data
5555
bool isView(Operation *op);
5656

57+
// Returns whether the op is a "noop op", i.e. has one input and one output
58+
// and lowers to llvm as the identity function (returns the input)
59+
bool isNoop(Operation *op);
60+
5761
/* Dump Triton IR in graphviz dot format.
5862
*
5963
* You can override `onValue` and `onOperation` in a subclass to mark

include/triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h

Lines changed: 5 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -38,38 +38,19 @@ struct ClusterInfo {
3838
int clusterDimZ;
3939
};
4040

41-
} // namespace nvidia_gpu
42-
} // namespace triton
43-
} // namespace mlir
44-
45-
namespace mlir {
46-
4741
std::unique_ptr<Pass> createTritonNvidiaGPUPlanCTAPass(
4842
mlir::triton::nvidia_gpu::ClusterInfo *clusterInfo = nullptr);
4943

50-
std::unique_ptr<Pass>
51-
createTritonNvidiaGPUFenceInsertionPass(int computeCapability = 90);
52-
53-
std::unique_ptr<Pass> createTritonNvidiaGPUTMALoweringPass();
54-
55-
std::unique_ptr<Pass> createTensorMemoryAllocationPass();
56-
57-
std::unique_ptr<Pass> createTritonNvidiaGPUMMALoweringPass();
58-
59-
std::unique_ptr<Pass> createTritonNvidiaGPUPromoteLHSToTMemPass();
60-
61-
std::unique_ptr<Pass> createTritonNvidiaGPURemoveTMEMTokensPass();
62-
63-
std::unique_ptr<Pass> createTritonNvidiaGPUOptimizeDescriptorEncodingPass();
64-
65-
std::unique_ptr<Pass> createTritonNvidiaGPUOptimizeTMemLayoutsPass();
66-
67-
std::unique_ptr<Pass> createTritonNvidiaGPUInterleaveTMemPass();
44+
#define GEN_PASS_DECL
45+
#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h.inc"
6846

6947
/// Generate the code for registering passes.
7048
#define GEN_PASS_REGISTRATION
7149
#define GEN_PASS_DECL_TRITONNVIDIAGPULEGALIZETMALAYOUTS
7250
#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h.inc"
7351

52+
} // namespace nvidia_gpu
53+
} // namespace triton
7454
} // namespace mlir
55+
7556
#endif // TRITON_DIALECT_TRITONNVIDIAGPU_TRANSFORMS_PASSES_H_

include/triton/Dialect/TritonNvidiaGPU/Transforms/Passes.td

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def TritonGPUPlanCTAPass : Pass<"triton-nvidia-gpu-plan-cta", "mlir::ModuleOp">
3232
and StoreLikeOps operations.
3333
}];
3434

35-
let constructor = "mlir::createTritonNvidiaGPUPlanCTAPass()";
35+
let constructor = "mlir::triton::nvidia_gpu::createTritonNvidiaGPUPlanCTAPass()";
3636

3737
let dependentDialects = [
3838
"mlir::triton::gpu::TritonGPUDialect",
@@ -48,8 +48,6 @@ def TritonGPUFenceInsertion : Pass<"triton-nvidia-gpu-fence-insertion", "mlir::M
4848
properly ordered across generic and async operations.
4949
}];
5050

51-
let constructor = "mlir::createTritonNvidiaGPUFenceInsertionPass()";
52-
5351
let dependentDialects = [
5452
"mlir::triton::gpu::TritonGPUDialect",
5553
"mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect"
@@ -69,22 +67,18 @@ def TritonNvidiaGPUTMALoweringPass : Pass<"triton-nvidia-tma-lowering", "mlir::M
6967
Lower Triton experimental descriptor load to TMA load/store operations in TritonNvidiaGPUDialect.
7068
}];
7169

72-
let constructor = "mlir::createTritonNvidiaGPUTMALoweringPass()";
73-
7470
let dependentDialects = [
7571
"mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect"
7672
];
7773
}
7874

79-
def TritionTensorMemoryAllocationPass : Pass<"triton-tensor-memory-allocation", "mlir::ModuleOp"> {
75+
def TritonTensorMemoryAllocationPass : Pass<"triton-tensor-memory-allocation", "mlir::ModuleOp"> {
8076
let summary = "Assign tensor memory allocation";
8177

8278
let description = [{
8379
Decide on tensor memory allocation and assign attributes to each allocation.
8480
}];
8581

86-
let constructor = "mlir::createTensorMemoryAllocationPass()";
87-
8882
let dependentDialects = [
8983
"mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect"
9084
];
@@ -97,8 +91,6 @@ def TritonNvidiaGPUMMALoweringPass : Pass<"triton-nvidia-mma-lowering", "mlir::M
9791
Lower MMA ops to prepare for conversion to LLVM.
9892
}];
9993

100-
let constructor = "mlir::createTritonNvidiaGPUMMALoweringPass()";
101-
10294
let dependentDialects = [
10395
"mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect"
10496
];
@@ -111,8 +103,6 @@ def TritonNvidiaGPUPromoteLHSToTMemPass : Pass<"tritongpu-promote-lhs-to-tmem",
111103
Promote LHS operand of MMAv5 op to Tensor Memory.
112104
}];
113105

114-
let constructor = "mlir::createTritonNvidiaGPUPromoteLHSToTMemPass()";
115-
116106
let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
117107
"mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect",
118108
"mlir::triton::TritonDialect"];

lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp

Lines changed: 27 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#include "triton/Conversion/TritonGPUToLLVM/Utility.h"
44
#include "triton/Dialect/TritonGPU/IR/Attributes.h"
55
#include "triton/Dialect/TritonGPU/IR/Types.h"
6+
#include "triton/Tools/LayoutUtils.h"
67

78
using namespace mlir;
89
using namespace mlir::triton;
@@ -421,6 +422,7 @@ struct MemDescSubviewOpConversion
421422
matchAndRewrite(triton::gpu::MemDescSubviewOp op, OpAdaptor adaptor,
422423
ConversionPatternRewriter &rewriter) const override {
423424
Location loc = op->getLoc();
425+
auto *ctx = op->getContext();
424426
auto b = TritonLLVMOpBuilder(loc, rewriter);
425427
auto srcTy = op.getSrc().getType();
426428
auto destTy = op.getResult().getType();
@@ -433,53 +435,42 @@ struct MemDescSubviewOpConversion
433435
llvmElemTy, rewriter);
434436
auto smemStrides = smemObj.getStrides(srcTy, loc, rewriter);
435437
SmallVector<Value> opOffsetVals = op.getOffsets();
438+
// We assume we always create a subview of the last dimensions
436439
SmallVector<Value> opSmemStrides(smemStrides.end() - opOffsetVals.size(),
437440
smemStrides.end());
441+
// Compute total offset
438442
SmallVector<Value> offsetVals;
439443
auto destRank = op.getResult().getType().getRank();
440444
auto rankReduced = srcTy.getRank() - destRank;
441445
for (int i = rankReduced; i < opOffsetVals.size(); i++) {
442446
offsetVals.push_back(b.add(opOffsetVals[i], smemObj.getOffsets()[i]));
443447
}
448+
444449
Value offset;
445-
auto allocShape = srcTy.getAllocShape();
446-
auto nvmmaEnc = dyn_cast<NVMMASharedEncodingAttr>(enc);
447-
bool isSimpleSubview =
448-
(!nvmmaEnc || allocShape.take_back(destRank) == destTy.getShape() ||
449-
nvmmaEnc.getSwizzlingByteWidth() == 0);
450-
if (!isSimpleSubview) {
451-
assert(destRank >= 2 &&
452-
"Shape size should be >= 2 when using NVMMAShared encoding");
453-
auto swizzleStride = b.i32_val((nvmmaEnc.getSwizzlingByteWidth() * 8) /
454-
llvmElemTy.getIntOrFloatBitWidth());
455-
offset = b.i32_val(0);
456-
for (auto i = 0; i < opOffsetVals.size() - 2; ++i) {
457-
offset = b.add(offset, b.mul(opOffsetVals[i], opSmemStrides[i]));
458-
}
459-
// newOffset = offset - (stridedOff * swizzledStride + contigOff /
460-
// swizzledStride * tileSize + contigOff % swizzledStride)
461-
// + stridedInc * swizzledStride + contigInc / swizzledStride *
462-
// tileSize + contigInc % swizzledStride
463-
auto stridedDim = destRank - 1 - layoutOrder[0];
464-
auto contigDim = destRank - 1 - layoutOrder[1];
465-
auto stridedOff = smemObj.getOffsets()[stridedDim];
466-
auto contigOff = smemObj.getOffsets()[contigDim];
467-
auto stridedInc = offsetVals[stridedDim];
468-
auto contigInc = offsetVals[contigDim];
469-
int allocStridedDim = allocShape.size() - 1 - layoutOrder[0];
470-
auto tileSize =
471-
b.mul(b.i32_val(allocShape[allocStridedDim]), swizzleStride);
472-
offset = b.sub(offset, b.mul(stridedOff, swizzleStride));
473-
offset = b.sub(offset, b.mul(b.udiv(contigOff, swizzleStride), tileSize));
474-
offset = b.sub(offset, b.urem(contigOff, swizzleStride));
475-
offset = b.add(offset, b.mul(stridedInc, swizzleStride));
476-
offset = b.add(offset, b.mul(b.udiv(contigInc, swizzleStride), tileSize));
477-
offset = b.add(offset, b.urem(contigInc, swizzleStride));
478-
} else {
479-
// Compute the offset based on the original strides of the shared memory
480-
// object
450+
if (rankReduced || (destTy.getRank() == 1 && destTy.getDimSize(0) == 1)) {
451+
// We are splitting the pipelining dimension which may not be a power of 2
452+
// so we can't use LinearLayouts
481453
offset = dot(rewriter, loc, opOffsetVals, opSmemStrides);
454+
} else {
455+
auto dimNames = standardOutDimNames(ctx, opOffsetVals.size());
456+
SmallVector<std::pair<StringAttr, Value>> logicalOffsets;
457+
// This assumes the subviews are additive, in the sense that we can
458+
// compute the offset of one and add it to the offset of the previous
459+
// one we computed. We check for this in the verifier.
460+
for (int i = 0; i < rankReduced; i++) {
461+
logicalOffsets.push_back({dimNames[i], b.i32_val(0)});
462+
}
463+
for (int i = rankReduced; i < opOffsetVals.size(); i++) {
464+
logicalOffsets.push_back({dimNames[i], offsetVals[i - rankReduced]});
465+
}
466+
// The order gives us the honest-to-goodness layout rank
467+
auto srcAllocShape =
468+
srcTy.getAllocShape().take_back(getOrder(srcTy).size());
469+
auto llInv = toLinearLayout(srcAllocShape, srcTy.getEncoding()).invert();
470+
offset =
471+
applyLinearLayout(loc, rewriter, llInv, logicalOffsets)[0].second;
482472
}
473+
483474
auto base = smemObj.getBase();
484475
auto elemPtrTy = base.getType();
485476
smemObj = SharedMemoryObject(b.gep(elemPtrTy, llvmElemTy, base, offset),

0 commit comments

Comments
 (0)