
Commit 8518cb3

Merge OpenAI Triton commit 449e014 (#5569)
This PR changes the Triton base from 4327b5b to 449e014 (Nov 13). Pass rate: 95.19%
2 parents: 9af0685 + 2caaac0

File tree: 122 files changed, +3295 −1424 lines


include/triton/Analysis/AxisInfo.h

Lines changed: 8 additions & 10 deletions

@@ -264,16 +264,14 @@ class ModuleAxisInfoAnalysis : public CallGraph<AxisInfoMapT> {
                          axisinfo::CallbackType callback = nullptr)
       : CallGraph<AxisInfoMapT>(moduleOp) {
     SmallVector<FunctionOpInterface> funcs;
-    for (auto root : getRoots()) {
-      walk<WalkOrder::PreOrder, WalkOrder::PostOrder>(
-          // Pre-order edge walk callback
-          [](CallOpInterface callOp, FunctionOpInterface funcOp) {},
-          // Post-order node walk callback
-          [&](FunctionOpInterface funcOp) {
-            funcs.push_back(funcOp);
-            funcMap.try_emplace(funcOp, AxisInfoMapT{});
-          });
-    }
+    walk<WalkOrder::PreOrder, WalkOrder::PostOrder>(
+        // Pre-order edge walk callback
+        [](CallOpInterface callOp, FunctionOpInterface funcOp) {},
+        // Post-order node walk callback
+        [&](FunctionOpInterface funcOp) {
+          funcs.push_back(funcOp);
+          funcMap.try_emplace(funcOp, AxisInfoMapT{});
+        });
     SetVector<FunctionOpInterface> sortedFuncs(funcs.begin(), funcs.end());
     SymbolTableCollection symbolTable;
     for (auto funcOp : llvm::reverse(sortedFuncs)) {
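
The constructor now issues a single call-graph walk instead of one walk per root. A minimal standalone sketch (plain C++, not the Triton CallGraph API) of why one post-order walk over all nodes suffices: post-order visits callees before callers, so iterating the collected list in reverse, as the constructor does with llvm::reverse(sortedFuncs), processes callers first.

#include <cstdio>
#include <functional>
#include <map>
#include <set>
#include <string>
#include <vector>

// Illustrative stand-in for a call graph: function name -> callees.
using Graph = std::map<std::string, std::vector<std::string>>;

// Collect functions in post-order (callees before callers), visiting each
// node at most once; starting from every node also covers every root.
std::vector<std::string> postOrderFuncs(const Graph &g) {
  std::vector<std::string> order;
  std::set<std::string> seen;
  std::function<void(const std::string &)> visit = [&](const std::string &f) {
    if (!seen.insert(f).second)
      return;
    if (auto it = g.find(f); it != g.end())
      for (const auto &callee : it->second)
        visit(callee);
    order.push_back(f); // post-order node callback
  };
  for (const auto &[f, callees] : g)
    visit(f);
  return order;
}

int main() {
  Graph g{{"kernel", {"helper"}}, {"helper", {"leaf"}}, {"leaf", {}}};
  auto order = postOrderFuncs(g);
  // Reverse post-order processes callers before their callees.
  for (auto it = order.rbegin(); it != order.rend(); ++it)
    std::printf("%s\n", it->c_str()); // kernel, helper, leaf
  return 0;
}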

include/triton/Dialect/TritonGPU/Transforms/Utility.h

Lines changed: 4 additions & 0 deletions

@@ -281,6 +281,10 @@ bool comesFromLoadOrBlockArg(Value v);
 // `resultIdx`th result.
 SmallVector<Value> getTiedArgs(Operation *op, int resultIdx);

+// Verifies the provided memory descriptor type used for barrier allocation
+LogicalResult verifyBarrierType(Operation *op,
+                                mlir::triton::gpu::MemDescType barrierType);
+
 } // namespace mlir::triton

 #endif // TRITON_DIALECT_TRITONGPU_TRANSFORMS_UTILITY_H_

include/triton/Dialect/TritonInstrument/IR/Utility.h

Lines changed: 5 additions & 2 deletions

@@ -15,6 +15,10 @@ constexpr int TC_THREAD_OFFSET = TMA_THREAD_OFFSET + NUM_THREADS;
 constexpr int TOTAL_NUM_THREADS = TC_THREAD_OFFSET + NUM_THREADS;
 constexpr int THREADS_BITMASK_SIZE = llvm::NextPowerOf2(TOTAL_NUM_THREADS);

+namespace CommitKind {
+enum Kind { None = -1, AsyncCp = 0, Wgmma, TmaStore, NumCommitKinds };
+}
+
 Operation *createStoreScratchMemory(OpBuilder &b, Location loc, Value alloc,
                                     Value tensor, RankedTensorType tensorType);
 Value createLoadScratchMemory(OpBuilder &b, Location loc, Value alloc,
@@ -63,8 +67,7 @@ struct AuxDataMap {
   RegionToValueMap writeTracking[numMemTypes];
   RegionToValueMap readVisibility[numMemTypes];
   RegionToValueMap readTracking[numMemTypes];
-  RegionToValueMap asyncCpCommits;
-  RegionToValueMap wgmmaCommits;
+  RegionToValueMap commits[CommitKind::NumCommitKinds];
   RegionToValueMap lock;
   RegionToValueMap waiting;
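
The two named commit maps collapse into one array indexed by the new CommitKind enum, which also introduces a TmaStore kind. A minimal standalone sketch (illustrative stand-in types, not the real RegionToValueMap) of the shape of this refactor: adding a commit kind no longer requires a new struct field, and instrumentation code can loop over all kinds.

#include <cstdio>
#include <map>
#include <string>

// Illustrative stand-in for the real RegionToValueMap.
using RegionToValueMap = std::map<int, std::string>;

namespace CommitKind {
enum Kind { None = -1, AsyncCp = 0, Wgmma, TmaStore, NumCommitKinds };
}

struct AuxDataMap {
  // Before: RegionToValueMap asyncCpCommits; RegionToValueMap wgmmaCommits;
  // After: one array indexed by kind, so passes can iterate generically.
  RegionToValueMap commits[CommitKind::NumCommitKinds];
};

int main() {
  AuxDataMap aux;
  aux.commits[CommitKind::AsyncCp][0] = "async-cp commit token";
  for (int k = 0; k < CommitKind::NumCommitKinds; ++k)
    std::printf("kind %d tracks %zu regions\n", k, aux.commits[k].size());
  return 0;
}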

include/triton/Dialect/TritonNvidiaGPU/Transforms/Utility.h

Lines changed: 0 additions & 14 deletions
This file was deleted.

lib/Analysis/AxisInfo.cpp

Lines changed: 3 additions & 13 deletions

@@ -29,10 +29,6 @@ template <typename... Args> int64_t gcd(int64_t a, int64_t b, Args... args) {
   return gcd(std::gcd(a, b), args...);
 }

-constexpr int log2Int(int64_t num) {
-  return (num > 1) ? 1 + log2Int(num / 2) : 0;
-}
-
 // If lhs * rhs overflows, return max value possible value for the type
 int64_t multiplyDivisor(int64_t lhs, int64_t rhs) {
   if (lhs > kMaxDivisor / rhs)
@@ -167,7 +163,6 @@ class AxisInfoAnalysis : public dataflow::SparseForwardDataFlowAnalysis<
                    axisinfo::CallbackType callback = nullptr);
   using dataflow::SparseForwardDataFlowAnalysis<
       dataflow::Lattice<AxisInfo>>::getLatticeElement;
-  using FuncAxisInfoMapT = DenseMap<FunctionOpInterface, AxisInfo>;

   LogicalResult
   visitOperation(Operation *op,
@@ -326,7 +321,6 @@ class AddSubOpAxisInfoVisitor final : public BinaryOpVisitorImpl<OpTy> {
       // with element locations:
       // [4, 5, 6, 7]
       // It is "strided contiguous" with a divisibility of 16 bytes
-      auto rank = lhs.getRank();
       auto elemSize = std::max<int64_t>(
           1, triton::getPointeeBitWidth(op.getPtr().getType()) / 8);
       rhsDivisibility = multiplyDivisor(rhs.getDivisibility(dim), elemSize);
@@ -345,7 +339,6 @@ class AddSubOpAxisInfoVisitor final : public BinaryOpVisitorImpl<OpTy> {
         return {lhs.getConstantValue().value() -
                 rhs.getConstantValue().value()};
     } else if constexpr (std::is_same_v<OpTy, triton::AddPtrOp>) {
-      auto rank = lhs.getRank();
       auto elemSize = std::max<int64_t>(
           1, triton::getPointeeBitWidth(op.getPtr().getType()) / 8);
       auto rhsValue = rhs.getConstantValue().value() * elemSize;
@@ -379,14 +372,12 @@ class MulIOpAxisInfoVisitor final : public BinaryOpVisitorImpl<arith::MulIOp> {
   int64_t getDivisibility(arith::MulIOp op, const AxisInfo &lhs,
                           const AxisInfo &rhs, int dim) override {
     auto lhsDivisibility = lhs.getDivisibility(dim);
-    if (lhs.getContiguity(dim) > 1 &&
-        !(rhs.getConstantValue().has_value() && rhs.getConstantValue() == 1)) {
+    if (lhs.getContiguity(dim) > 1 && rhs.getConstantValue() != 1) {
       // Treat [2^n,2^n+1,...]'s divisibility as 1 instead of 2^n
       lhsDivisibility = 1;
     }
     auto rhsDivisibility = rhs.getDivisibility(dim);
-    if (rhs.getContiguity(dim) > 1 &&
-        !(lhs.getConstantValue().has_value() && lhs.getConstantValue() == 1)) {
+    if (rhs.getContiguity(dim) > 1 && lhs.getConstantValue() != 1) {
       // Treat [2^n,2^n+1,...]'s divisibility as 1 instead of 2^n
       rhsDivisibility = 1;
     }
@@ -685,7 +676,7 @@ class CmpOpAxisInfoVisitor final : public AxisInfoVisitorImpl<OpTy> {
     AxisInfo::DimVectorT contiguity, divisibility, constancy;
     std::optional<int64_t> constantValue;
     for (short d = 0; d < rank; ++d) {
-      int64_t constHint = 1;
+      int64_t constHint;
       if (lhsInfo.getConstantValue().has_value() &&
           rhsInfo.getConstantValue().has_value()) {
         constHint = shape[d];
@@ -907,7 +898,6 @@ class ShLIOpAxisInfoVisitor final : public BinaryOpVisitorImpl<arith::ShLIOp> {
      // Treat [2^n,2^n+1,...]'s divisibility as 1 instead of 2^n
       lhsDivisibility = 1;
     }
-    auto numBits = log2Int(lhsDivisibility);
     return multiplyDivisor(lhsDivisibility, 1ll << shift);
   }
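
The MulIOp simplification leans on std::optional's mixed comparison operators: for an empty optional, opt != 1 is true, so the shortened condition is exactly equivalent to the old spelled-out !(opt.has_value() && opt == 1). Below is a minimal standalone check of that equivalence, plus the gcd fact behind the "treat [2^n,2^n+1,...]'s divisibility as 1" comment.

#include <cassert>
#include <cstdint>
#include <numeric>
#include <optional>

int main() {
  // `opt != 1` on a std::optional is !opt.has_value() || *opt != 1, which
  // matches the old condition in all three cases: empty, one, other.
  std::optional<int64_t> empty, one = 1, two = 2;
  assert((empty != 1) == !(empty.has_value() && *empty == 1));
  assert((one != 1) == !(one.has_value() && *one == 1));
  assert((two != 1) == !(two.has_value() && *two == 1));

  // Why contiguity collapses divisibility to 1: consecutive values such as
  // [8, 9, ...] share no common factor, so after a multiply only the other
  // operand's divisibility can be preserved.
  assert(std::gcd(int64_t{8}, int64_t{9}) == 1);
  return 0;
}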

lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp

Lines changed: 0 additions & 27 deletions

@@ -719,31 +719,6 @@ class ScaledBlockedToMMA : public mlir::OpRewritePattern<triton::DotScaledOp> {
                                          mmaResult.newRetType, rewriter);
     Value newB = convertDotOperandForMMA(b, 1, minBitwidth,
                                          mmaResult.newRetType, rewriter);
-
-    // Compute tiles per warp for each operand
-    auto computeTilePerWarp = [&](Value operand, int operandIdx) -> unsigned {
-      auto operandTy = cast<RankedTensorType>(operand.getType());
-      auto dotEncoding = dyn_cast<triton::gpu::DotOperandEncodingAttr>(
-          operandTy.getEncoding());
-      if (!dotEncoding)
-        return 1;
-
-      const int bitWidth = operandTy.getElementType().getIntOrFloatBitWidth();
-      const int kWidth = dotEncoding.getKWidth();
-      auto rep = mmaResult.mmaEnc.getRepForOperand(
-          triton::gpu::getShapePerCTA(operandTy), bitWidth, kWidth,
-          dotEncoding.getOpIdx());
-
-      // repA = [repM, repK], repB = [repK, repN]
-      // For operand A (idx 0): return rep[1] (repK)
-      // For operand B (idx 1): return rep[2] (repN)
-      if (operandIdx == 0) {
-        return rep.size() >= 2 ? rep[1] : 1;
-      } else {
-        return rep.size() >= 3 ? rep[2] : 1;
-      }
-    };
-
     const auto mmaWarps = mmaResult.mmaEnc.getWarpsPerCTA(); // [wM, wN]
     // Convert scales to Linear layout
     auto convertScale = [&](Value scale, int opIdx) -> Value {
@@ -808,8 +783,6 @@ class ScaledBlockedToMMAv5
     // operands
     Value a = dotOp.getA();
     Value b = dotOp.getB();
-    auto oldAType = a.getType();
-    auto oldBType = b.getType();

     bool IsAMixedPrecFp4 = false;
     bool IsBMixedPrecFp4 = false;

lib/Dialect/TritonGPU/Transforms/Utility.cpp

Lines changed: 9 additions & 0 deletions

@@ -1707,4 +1707,13 @@ SmallVector<Value> getTiedArgs(Operation *op, int resultIdx) {
   return {};
 }

+LogicalResult verifyBarrierType(Operation *op,
+                                mlir::triton::gpu::MemDescType barrierType) {
+  if (!barrierType.getElementType().isInteger(64) ||
+      barrierType.getShape() != ArrayRef<int64_t>({1}))
+    return op->emitOpError(
+        "barrier allocation must be a descriptor of 1xi64 type");
+  return success();
+}
+
 } // namespace mlir::triton
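
A minimal standalone sketch (plain C++, not the MLIR types) of the predicate verifyBarrierType enforces: the allocation must be a memory descriptor holding exactly one i64 element.

#include <cstdio>
#include <vector>

// Illustrative stand-ins for MemDescType's element type and shape.
struct MemDescSketch {
  int elementBitWidth;          // stand-in for getElementType()
  std::vector<long long> shape; // stand-in for getShape()
};

// Mirrors the helper's check: reject anything that is not a 1xi64 descriptor.
bool isValidBarrierType(const MemDescSketch &t) {
  return t.elementBitWidth == 64 && t.shape == std::vector<long long>{1};
}

int main() {
  std::printf("%d\n", isValidBarrierType({64, {1}})); // 1: a 1xi64 descriptor
  std::printf("%d\n", isValidBarrierType({32, {1}})); // 0: wrong element type
  std::printf("%d\n", isValidBarrierType({64, {2}})); // 0: wrong shape
  return 0;
}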

lib/Dialect/TritonGPU/Transforms/WarpSpecialization/PartitionScheduling.cpp

Lines changed: 16 additions & 3 deletions

@@ -823,17 +823,30 @@ SetVector<int> assignIfOpPartitions(scf::IfOp ifOp) {
   for (int i = 0; i < thenYieldPartitions.size(); ++i) {
     auto &thenIds = thenYieldPartitions[i];
     auto &elseIds = elseYieldPartitions[i];
+    auto thenYieldOpnd = ifOp.thenYield()->getOperand(i);
+    auto elseYieldOpnd = ifOp.elseYield()->getOperand(i);
+    auto thenYieldOpndDefOp = thenYieldOpnd.getDefiningOp();
+    auto elseYieldOpndDefOp = elseYieldOpnd.getDefiningOp();

-    if (auto yieldOpnd = ifOp.thenYield()->getOperand(i);
-        isa<AsyncTokenType>(yieldOpnd.getType())) {
+    if (isa<AsyncTokenType>(thenYieldOpnd.getType())) {
       // Heuristic: when if-op yields an async-token, the output partition of
       // the token is that of its producer
-      if (ifOp.thenBlock()->findAncestorOpInBlock(*yieldOpnd.getDefiningOp())) {
+      if (ifOp.thenBlock()->findAncestorOpInBlock(
+              *thenYieldOpnd.getDefiningOp())) {
         outputPartitions.push_back(elseIds);
       } else {
         outputPartitions.push_back(thenIds);
       }
+    } else if (thenYieldOpndDefOp &&
+               thenYieldOpndDefOp->getBlock() == ifOp.thenBlock()) {
+      // Heuristic: if yield operand is defined in then block, use its Ids
+      outputPartitions.push_back(thenIds);
+    } else if (elseYieldOpndDefOp &&
+               elseYieldOpndDefOp->getBlock() == ifOp.elseBlock()) {
+      // same for else block
+      outputPartitions.push_back(elseIds);
     } else {
+      // otherwise pick thenIds if available, otherwise elseIds
       outputPartitions.push_back(!thenIds.empty() ? thenIds : elseIds);
     }
   }
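
A minimal standalone sketch (simplified stand-in types, not the MLIR ops) of the extended cascade for choosing a yielded value's output partition. Note that, as in the commit, the async-token case picks the else-branch ids when the token's producer lives in the then-block.

#include <cstdio>
#include <vector>

using Ids = std::vector<int>;

// Illustrative stand-in for the properties of a yield operand.
struct YieldOperand {
  bool isAsyncToken = false;
  bool definedInThenBlock = false;
  bool definedInElseBlock = false;
};

// Mirrors the heuristic cascade in assignIfOpPartitions.
Ids pickOutputPartition(const YieldOperand &thenOpnd,
                        const YieldOperand &elseOpnd, const Ids &thenIds,
                        const Ids &elseIds) {
  if (thenOpnd.isAsyncToken) // token follows its producer, as in the source
    return thenOpnd.definedInThenBlock ? elseIds : thenIds;
  if (thenOpnd.definedInThenBlock) // value produced by the then-branch
    return thenIds;
  if (elseOpnd.definedInElseBlock) // value produced by the else-branch
    return elseIds;
  return !thenIds.empty() ? thenIds : elseIds; // fallback
}

int main() {
  Ids thenIds{0}, elseIds{1};
  YieldOperand thenOpnd{/*isAsyncToken=*/false, /*definedInThenBlock=*/true};
  YieldOperand elseOpnd{};
  // The then-branch defines the value, so its partition ids win: prints 0.
  std::printf("%d\n",
              pickOutputPartition(thenOpnd, elseOpnd, thenIds, elseIds)[0]);
  return 0;
}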
