
Commit 48b23bb

Merge OpenAI Triton commit 34a2120 (#4617)
This PR changes the Triton base from 36b3473 to 34a2120 (Jun 27). Pass rate: 97.14%. Please do not squash and merge this PR.
2 parents 51a925c + 2fae2c5 commit 48b23bb

File tree: 38 files changed, +1108 −218 lines changed


README.md

Lines changed: 6 additions & 4 deletions
@@ -119,15 +119,17 @@ Alternatively, follow these steps to build LLVM from source manually.
   Without this, every invocation of `pip install` uses a different symlink to
   cmake, and this forces ninja to rebuild most of the `.a` files.
 
-- vscode intellisense has some difficulty figuring out how to build Triton's C++
-  (probably because, in our build, users don't invoke cmake directly, but
-  instead use setup.py). Teach vscode how to compile Triton as follows.
+- The build system creates a `compile_commands.json` file under the Triton repo
+  directory. This file is used by VSCode IntelliSense and clangd to provide
+  code completion and other features for C++ code.
+
+  If IntelliSense does not work, you can try the following steps:
 
   - Do a local build. Run command `pip install -e .`
   - Get the full path to the `compile_commands.json` file produced by the build:
    `find ./build -name 'compile_commands.json' | xargs readlink -f`.
    You might get a full path similar to `/Users/{username}/triton/build/cmake.macosx-11.1-arm64-cpython-3.12/compile_commands.json`
-  - In vscode, install the
+  - In VSCode, install the
    [C/C++
    extension](https://marketplace.visualstudio.com/items?itemName=ms-vscode.cpptools),
    then open the command palette (`Shift + Command + P` on Mac, or `Shift +

include/triton/Analysis/Utility.h

Lines changed: 2 additions & 0 deletions
@@ -64,6 +64,8 @@ class ReduceOpHelper {
 
   bool isReduceWithinCTA();
 
+  bool isAssociative();
+
 private:
   triton::ReduceOp op;
   ArrayRef<int64_t> srcShape;

include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td

Lines changed: 11 additions & 19 deletions
@@ -286,25 +286,10 @@ When vec=2, elements are swizzled in pairs of 2. In other words, the element at
   }
 
   // ---- begin WMMA ----
-  if (mlir::isa<AMDWmmaEncodingAttr>(dotOpEnc.getParent())) {
-    if (dotOpEnc.getOpIdx() == 0) {
-      const int numBanks = 32;
-      const int bankBitWidth = 32;
-
-      // number of inner dimension rows per one pattern repeat
-      int innerDimLength = shape[order[0]];
-      int elemsPerOneBanksRow = (numBanks * bankBitWidth) / typeWidthInBit;
-
-      int perPhase = std::max(1, elemsPerOneBanksRow / innerDimLength);
-      int vecSize = ((typeWidthInBit == 16) ? 64 : 32 ) / typeWidthInBit;
-      int maxPhase = 16 / perPhase;
-
-      return get(context, vecSize, perPhase, maxPhase, order, CTALayout);
-    } else {
-      // Do not swizzle in case k dimension is not innermost.
-      // In this case accesses will go in different banks even without swizzling.
-      return get(context, 1, 1, 1, order, CTALayout);
-    }
+  if (auto wmmaEnc = mlir::dyn_cast<AMDWmmaEncodingAttr>(dotOpEnc.getParent())) {
+    return wmmaEnc.composeSharedLayoutForOperand(
+        CTALayout, dotOpEnc.getOpIdx(), shape, order, dotOpEnc.getKWidth(),
+        typeWidthInBit, needTrans);
   }
 
 
@@ -1230,6 +1215,13 @@ Row |
        Type elemType, int kWidth, int kDim, int opIdx) const;
    SmallVector<unsigned> getRepOrderForOperand(int opIdx) const;
    static SmallVector<unsigned> getMNKDimPerInstr();
+
+    // Returns a swizzled shared layout matching this WMMA layout for the
+    // dot operand at the given |operandIdx| with |operandShape|.
+    SwizzledSharedEncodingAttr composeSharedLayoutForOperand(
+        CTALayoutAttr ctaLayout, int operandIdx, ArrayRef<int64_t> operandShape,
+        ArrayRef<unsigned> sharedOrder, unsigned kWidth,
+        unsigned elemBitWidth, bool needTrans) const;
  }];
 }
 

include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td

Lines changed: 46 additions & 0 deletions
@@ -523,6 +523,52 @@ def TTNG_TCGen5MMAScaledOp : TTNG_Op<"tc_gen5_mma_scaled", [
   }];
 }
 
+def TTNG_TCGen5CommitOp : TTNG_Op<"tc_gen5_commit"> {
+  let summary = "make an mbarrier track completion of all prior async tcgen5 ops";
+
+  let description = [{
+    The `ttng.tc_gen5_commit` is an asynchronous operation that makes the
+    mbarrier object track the completion of all prior asynchronous tcgen5
+    operations. Upon completion of all asynchronous operations, the mbarrier
+    arrive operation is performed on the mbarrier with a count of 1.
+
+    If `two_ctas` is set, then the mbarrier tracks all prior operations
+    initiated with `two_ctas` set as well. Otherwise, it tracks all prior
+    operations initiated without `two_ctas`.
+
+    Note that the completion mechanisms are guaranteed to occur sequentially in
+    the order the commit operations were issued. This means, for example:
+
+    ```mlir
+    ttng.tmem_copy
+    ttng.tc_gen5_mma
+    ttng.tc_gen5_commit %barrierA
+    ttng.tc_gen5_commit %barrierB
+    ```
+
+    `%barrierA` tracks the completion of the previous TMEM copy and MMA
+    operations, but since the commit groups are sequential, the arrive-on
+    operation on `%barrierA` is guaranteed to be performed before the arrive-on
+    operation on `%barrierB`, even though its commit group is empty.
+  }];
+
+  let arguments = (ins
+    Arg<TTG_MemDescType, "", [MemWrite<SharedMemory>]>:$barrier,
+    Optional<I1>:$pred,
+    UnitAttr:$two_ctas
+  );
+
+  let assemblyFormat = [{
+    $barrier (`,` $pred^)? attr-dict `:` qualified(type($barrier))
+  }];
+
+  let builders = [
+    OpBuilder<(ins "Value":$barrier, CArg<"bool", "false">:$two_ctas), [{
+      build($_builder, $_state, barrier, /*pred=*/Value(), two_ctas);
+    }]>,
+  ];
+}
+
 def TTNG_TMEMLoadOp : TTNG_Op<"tmem_load"> {
   let summary = "Load a buffer from tensor memory into a distributed tensor";
 
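To illustrate the builder declared above, here is a minimal, hypothetical C++ sketch (not part of this commit) of emitting the op from lowering code. The `emitCommit` helper name, the include path, and the `ttng` namespace alias are assumptions borrowed from conventions elsewhere in the tree (for example, `python/src/gluon_ir.cc` below uses `ttng::TCGen5CommitOp`); the barrier value is expected to be an mbarrier memdesc in shared memory.

```c++
// Hypothetical usage sketch of the TCGen5CommitOp builder declared above.
#include "mlir/IR/Builders.h"
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"

namespace ttng = mlir::triton::nvidia_gpu;

// Arrives on `barrier` (an mbarrier memdesc in shared memory) once all
// previously issued async tcgen5 operations have completed.
static void emitCommit(mlir::OpBuilder &b, mlir::Location loc,
                       mlir::Value barrier, bool twoCTAs = false) {
  // Uses the builder added in this commit: pred is left unset, two_ctas is
  // forwarded as a unit attribute.
  b.create<ttng::TCGen5CommitOp>(loc, barrier, twoCTAs);
}
```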

lib/Analysis/Utility.cpp

Lines changed: 20 additions & 0 deletions
@@ -142,6 +142,26 @@ bool ReduceOpHelper::isReduceWithinCTA() {
   return getCTASplitNum(srcEncoding)[axis] == 1;
 }
 
+bool ReduceOpHelper::isAssociative() {
+  auto dtype = srcElementTypes[0];
+  if (!type::isFloat(dtype))
+    return true;
+  size_t reduce_size = srcShape[axis];
+  if (reduce_size <= 2)
+    return true;
+  bool hasNoAssociativeOp = false;
+  op.walk([&](Operation *nestedOp) -> WalkResult {
+    if (isa<arith::AddFOp, arith::MulFOp>(nestedOp)) {
+      // Only when the element type is floating point, the reduce size is
+      // greater than 2, and the combine region contains an addf or mulf op
+      // do we treat the reduction as non-associative.
+      hasNoAssociativeOp = true;
+      return WalkResult::interrupt();
+    }
+    return WalkResult::advance();
+  });
+  return !hasNoAssociativeOp;
+}
+
 unsigned ScanLoweringHelper::getAxisNumElementsPerThread() {
   return getEncoding().getContigPerThread()[getAxis()];
 }
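The reason this check only fires for floating-point element types is that `addf` and `mulf` are not associative in floating point, so reassociating a reduction can change its result. A small standalone illustration (not part of the patch):

```c++
// Standalone illustration of why addf/mulf reductions are treated as
// non-associative: reassociating float addition can change the result.
#include <cstdio>

int main() {
  float a = 1e8f, b = -1e8f, c = 1.0f;
  // (a + b) + c == 1.0f, but a + (b + c) == 0.0f because c is absorbed
  // when added to b at this magnitude (the ulp of 1e8f is 8).
  std::printf("%g vs %g\n", (a + b) + c, a + (b + c));
  return 0;
}
```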

lib/Dialect/Triton/Transforms/LoopAwareCSE.cpp

Lines changed: 3 additions & 0 deletions
@@ -93,6 +93,9 @@ bool LoopCSEDriver::areEqualInLoop(Value a, Value b) {
 
   Operation *aDef = a.getDefiningOp();
   Operation *bDef = b.getDefiningOp();
+  if (cast<OpResult>(a).getResultNumber() !=
+      cast<OpResult>(b).getResultNumber())
+    return false;
   // For it to be known that the operation results have the same value, they
   // must be side effect free.
   if (!isMemoryEffectFree(aDef) || !isMemoryEffectFree(bDef))

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 48 additions & 31 deletions
@@ -427,15 +427,6 @@ getDefaultBlockedEncoding(MLIRContext *context, ArrayRef<int64_t> shape,
   return encoding;
 }
 
-bool isSplitCompatible(MLIRContext *ctx, const LinearLayout &ll) {
-  auto lastDim = ll.getNumOutDims() - 1;
-  auto kReg = StringAttr::get(ctx, "register");
-  auto kLastDim = StringAttr::get(ctx, "dim" + std::to_string(lastDim));
-  auto sublayout =
-      ll.sublayout({kReg}, {kLastDim}).removeZeroBasesAlongDim(kReg);
-  return sublayout == LinearLayout::identity1D(2, kReg, kLastDim);
-}
-
 LogicalResult tryJoinOnAxis(MLIRContext *ctx, const LinearLayout &inLl,
                             LinearLayout &outLl, bool fwdInference, int axis,
                             std::optional<Location> loc) {
@@ -2056,6 +2047,42 @@ SmallVector<unsigned> AMDWmmaEncodingAttr::getMNKDimPerInstr() {
   return {16, 16, 16};
 }
 
+SwizzledSharedEncodingAttr AMDWmmaEncodingAttr::composeSharedLayoutForOperand(
+    CTALayoutAttr ctaLayout, int operandIdx, ArrayRef<int64_t> operandShape,
+    ArrayRef<unsigned> sharedOrder, unsigned kWidth, unsigned elemBitWidth,
+    bool needTrans) const {
+  int kDimIndex = operandIdx == 0 ? 1 : 0;
+  bool isKContig = sharedOrder[0] == kDimIndex;
+
+  if (!isKContig) {
+    // Do not swizzle. In this case accesses will go in different banks even
+    // without swizzling.
+    return SwizzledSharedEncodingAttr::get(getContext(), 1, 1, 1, sharedOrder,
+                                           ctaLayout);
+  }
+
+  // max vectorization size for ds_load is 128 bits
+  int vectorSize = std::min(kWidth * elemBitWidth, 128u) / elemBitWidth;
+
+  const int numBanks = 32;
+  const int bankBitWidth = 32;
+
+  // Number of inner dimension rows per one pattern repeat
+  int innerDimLength = operandShape[sharedOrder[0]];
+  int elemsPerOneBanksRow = (numBanks * bankBitWidth) / elemBitWidth;
+
+  int perPhase = std::max(1, elemsPerOneBanksRow / innerDimLength);
+  // for both RDNA3 and RDNA4, the M/N dimension of wmma is 16
+  // This represents the max number of rows that can be accessed
+  // at the same time
+  int mDim = getMNKDimPerInstr()[0];
+  int maxPhase =
+      std::max(std::min(mDim / perPhase, innerDimLength / vectorSize), 1);
+
+  return SwizzledSharedEncodingAttr::get(getContext(), vectorSize, perPhase,
+                                         maxPhase, sharedOrder, ctaLayout);
+}
+
 //===----------------------------------------------------------------------===//
 // Mma encoding
 //===----------------------------------------------------------------------===//
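As a sanity check on the arithmetic in the hunk above, the following standalone sketch reproduces the swizzle-parameter computation for one illustrative configuration (a 16-bit element type, kWidth = 8, and a K dimension of length 64). These inputs are examples chosen here and are not taken from the patch or its tests.

```c++
// Standalone sketch of the swizzle-parameter arithmetic used by
// composeSharedLayoutForOperand above, with illustrative inputs.
#include <algorithm>
#include <cstdio>

int main() {
  unsigned kWidth = 8, elemBitWidth = 16; // e.g. fp16 operand, 8 elems/thread along K
  int innerDimLength = 64;                // K is the contiguous (inner) dimension

  // ds_load can vectorize at most 128 bits.
  int vectorSize = std::min(kWidth * elemBitWidth, 128u) / elemBitWidth;

  const int numBanks = 32;
  const int bankBitWidth = 32;
  int elemsPerOneBanksRow = (numBanks * bankBitWidth) / elemBitWidth;

  int perPhase = std::max(1, elemsPerOneBanksRow / innerDimLength);
  int mDim = 16; // WMMA M/N tile size on RDNA3/RDNA4
  int maxPhase =
      std::max(std::min(mDim / perPhase, innerDimLength / vectorSize), 1);

  // Prints: vec=8 perPhase=1 maxPhase=8 for these inputs.
  std::printf("vec=%d perPhase=%d maxPhase=%d\n", vectorSize, perPhase, maxPhase);
  return 0;
}
```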
@@ -2659,7 +2686,9 @@ struct TritonGPUInferLayoutInterface
     auto parent = enc.getParent();
     auto parentLL = toLinearLayout(joinedShape, parent);
 
-    if (isSplitCompatible(ctx, parentLL)) {
+    Attribute splitEnc;
+    auto result = inferSplitOpEncoding(parent, splitEnc, joinedShape, loc);
+    if (succeeded(result) && areLayoutsEquivalent(shape, splitEnc, srcEnc)) {
       dstEnc = parent;
       return success();
     }
@@ -2709,28 +2738,16 @@ struct TritonGPUInferLayoutInterface
   inferSplitOpEncoding(Attribute srcEnc, Attribute &dstEnc,
                        ArrayRef<int64_t> shape,
                        std::optional<Location> loc) const override {
+    // SplitOp takes a tensor of shape AxBxCx2 and generates two tensors of
+    // shape AxBxC. The input must have 2 elements per thread in the last
+    // dimension, which must be the fastest running dimension. The result
+    // encoding is the same as the input, but with the last dimension removed.
     auto enc = mlir::dyn_cast<BlockedEncodingAttr>(srcEnc);
-    if (enc) {
-      // SplitOp takes a tensor of shape AxBxCx2 and generates two tensors of
-      // shape AxBxC. The input must have 2 elements per thread in the last
-      // dimension, which must be the fastest running dimension. The result
-      // encoding is the same as the input, but with the last dimension removed.
-      if (enc.getSizePerThread().back() != 2) {
-        return emitOptionalError(
-            loc, "SplitOp requires 2 elements per thread in the "
-                 "last dimension of the input");
-      }
-      if (enc.getThreadsPerWarp().back() != 1 ||
-          enc.getWarpsPerCTA().back() != 1 || enc.getCTAsPerCGA().back() != 1) {
-        return emitOptionalError(
-            loc, "SplitOp requires threadsPerWarp, warpsPerCTA, "
-                 "and CTAsPerCGA = 1 for the last dimension of the input");
-      }
-      if (enc.getCTALayout().getCTAsPerCGA().back() != 1) {
-        return emitOptionalError(
-            loc,
-            "SplitOp requires the last dimension to be most-minor in CTAOrder");
-      }
+    bool isSimpleSplit = (enc && (enc.getSizePerThread().back() == 2) &&
+                          (enc.getThreadsPerWarp().back() == 1) &&
+                          (enc.getWarpsPerCTA().back() == 1) &&
+                          (enc.getCTAsPerCGA().back() == 1));
+    if (isSimpleSplit) {
       SmallVector<unsigned> newOrder(enc.getOrder());
       int splitDim = newOrder.size() - 1;
       // Remove splitDim from order.

lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp

Lines changed: 7 additions & 0 deletions
@@ -1229,6 +1229,13 @@ void LayoutRematerialization::backwardRematerialization(
       // Reduce op introduce much cost.
       auto reduceOp = dyn_cast<ReduceOp>(op);
       ReduceOpHelper helper(reduceOp);
+      if (!helper.isAssociative()) {
+        // We shouldn't rematerialize a non-associative reduce op if it has
+        // multiple use chains.
+        LDBG(" skipped rematerialization due to non-associative reduce in the "
+             "slice");
+        return;
+      }
       rematerialisationCost += helper.getIntraWarpSizeWithUniqueData();
       rematerialisationCost += 8 * helper.getInterWarpSizeWithUniqueData();
     }

lib/Dialect/TritonGPU/Transforms/Utility.cpp

Lines changed: 4 additions & 0 deletions
@@ -1573,6 +1573,10 @@ bool comesFromLoadOrBlockArg(Value v) {
       v = cvtOp.getSrc();
       continue;
     }
+    if (auto transOp = dyn_cast<tt::TransOp>(def)) {
+      v = transOp.getSrc();
+      continue;
+    }
     if (def->hasTrait<OpTrait::MemDescViewTrait>()) {
       v = def->getOperand(0);
       continue;

python/src/gluon_ir.cc

Lines changed: 4 additions & 0 deletions
@@ -448,6 +448,10 @@ void init_gluon_ir(py::module &&m) {
                                              pred, two_ctas, mbarriers,
                                              mbarrier_preds);
       })
+      .def("create_tcgen05_commit",
+           [](GluonOpBuilder &self, Value &barrier) {
+             self.create<ttng::TCGen5CommitOp>(barrier);
+           })
 
       .def("create_async_tma_copy_global_to_local",
            [](GluonOpBuilder &self, Value descPtr, std::vector<Value> &coord,

0 commit comments
