Skip to content

Commit b137b65

Browse files
Merge OpenAI Triton commit 2a86177 (#5021)
This PR changes the Triton base from 17be6e1 to 2a86177 (Aug 19). Pass rate: 98.85%
2 parents 6f12fd4 + 3297b19 commit b137b65

File tree

82 files changed

+1591
-1157
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

82 files changed

+1591
-1157
lines changed

.github/workflows/integration-tests-amd.yml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,10 @@ jobs:
125125
pytest --capture=tee-sys -rfs third_party/amd/python/test/test_extract_slice_concat_op.py
126126
TRITON_ALWAYS_COMPILE=1 pytest --capture=tee-sys -rfs third_party/amd/python/test/test_scalarize_packed_fops.py
127127
cd python/test/unit
128-
pytest --capture=tee-sys -rfs -n 12 language runtime tools \
128+
pytest --capture=tee-sys -rfs -n 12 \
129+
--ignore=blackwell \
130+
--ignore=cuda \
131+
--ignore=instrumentation \
129132
--ignore=language/test_line_info.py \
130133
--ignore=test_debug.py
131134
# TODO: uncomment

include/triton/Dialect/Triton/IR/Utility.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ SmallVector<T> convertType(const VecU &in) {
2323
}
2424

2525
template <typename Int> Int product(llvm::ArrayRef<Int> arr) {
26-
return std::accumulate(arr.begin(), arr.end(), 1, std::multiplies{});
26+
return std::accumulate(arr.begin(), arr.end(), 1, std::multiplies<Int>());
2727
}
2828
template <typename VecT> auto product(const VecT &vec) {
2929
return product(llvm::ArrayRef(vec));
@@ -132,8 +132,8 @@ template <typename VecT, typename IdxT>
132132
// Is `vec` [0, 1, ..., n]? Returns true on empty list.
133133
template <typename T> bool isIota(ArrayRef<T> vec) {
134134
static_assert(std::is_integral_v<T>);
135-
for (T i = 0; i < vec.size(); ++i) {
136-
if (vec[i] != i) {
135+
for (size_t i = 0; i < vec.size(); ++i) {
136+
if (vec[i] != static_cast<T>(i)) {
137137
return false;
138138
}
139139
}

include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,7 @@ def TTG_MemDescIndexOp : TTG_Op<"memdesc_index", [Pure, MemDescViewTrait]> {
218218

219219
let results = (outs TTG_MemDescType:$result);
220220

221-
let assemblyFormat = [{$src `,` $index attr-dict `:` qualified(type($src)) `->` qualified(type($result))}];
221+
let assemblyFormat = [{$src `[` $index `]` attr-dict `:` qualified(type($src)) `->` qualified(type($result))}];
222222

223223
let hasVerifier = 1;
224224
}

lib/Analysis/AxisInfo.cpp

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1261,16 +1261,20 @@ unsigned ModuleAxisInfoAnalysis::getAlignment(Value offsetsValue,
12611261
return 1;
12621262
auto linAttr = gpu::toLinearEncoding(tensorTy);
12631263
auto order = linAttr.getOrder();
1264-
auto maxMultipleBytes = axisInfo->getDivisibility(order[0]);
1265-
auto maxContig = axisInfo->getContiguity(order[0]);
12661264

1265+
auto divisibility = axisInfo->getDivisibility(order[0]);
12671266
auto elemNumBytes = std::max<unsigned>(elementBitWidth / 8, 1);
1268-
auto maxMultiple = std::max<int64_t>(maxMultipleBytes / elemNumBytes, 1);
1267+
auto elemTy = tensorTy.getElementType();
1268+
auto maxMultiple = isa<PointerType>(elemTy)
1269+
? std::max<int64_t>(divisibility / elemNumBytes, 1)
1270+
: divisibility;
1271+
1272+
auto maxContig = axisInfo->getContiguity(order[0]);
12691273
unsigned alignment = std::min(maxMultiple, maxContig);
1270-
LDBG("getAlignment order[0] "
1271-
<< order[0] << " maxMultipleBytes = " << maxMultipleBytes
1272-
<< " maxContig = " << maxContig << " elemNumBits = " << elementBitWidth
1273-
<< " maxMultiple = " << maxMultiple << " alignment " << alignment);
1274+
LDBG("getAlignment order[0] " << order[0] << " maxContig = " << maxContig
1275+
<< " elemNumBits = " << elementBitWidth
1276+
<< " maxMultiple = " << maxMultiple
1277+
<< " alignment " << alignment);
12741278
LLVM_DEBUG({
12751279
std::string axisStr;
12761280
llvm::raw_string_ostream os(axisStr);

lib/Dialect/Triton/IR/OpInterfaces.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ LogicalResult verifyTransposeOpInterface(Operation *op) {
1313
TransposeOpInterface transposeOp = cast<TransposeOpInterface>(op);
1414
auto rank = cast<ShapedType>(transposeOp.getSrc().getType()).getRank();
1515
auto order = transposeOp.getOrder();
16-
if (rank != order.size()) {
16+
if (static_cast<size_t>(rank) != order.size()) {
1717
return op->emitError(
1818
"order must have the same size as the rank of the operand and result");
1919
}

lib/Dialect/Triton/IR/Ops.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -465,7 +465,6 @@ static LogicalResult verifyRegionsImpl(Op &op) {
465465
<< " arguments, but given block with "
466466
<< block.getNumArguments() << " arguments";
467467
}
468-
unsigned i = 0;
469468
const auto &blockArgTypes = block.getArgumentTypes();
470469
for (unsigned i = 0; i < numArgs; ++i) {
471470
const auto &blockArgTy = blockArgTypes[i];
@@ -966,7 +965,7 @@ LogicalResult BroadcastOp::verify() {
966965
if (srcShape.size() != resultShape.size()) {
967966
return emitError("rank of source must be same as rank of result");
968967
}
969-
for (int i = 0; i < srcShape.size(); i++) {
968+
for (size_t i = 0; i < srcShape.size(); i++) {
970969
if (srcShape[i] != 1 && srcShape[i] != resultShape[i]) {
971970
return emitError("Different dimensions at index ")
972971
<< i << " between source and result. "
@@ -1288,7 +1287,7 @@ LogicalResult GatherOp::verify() {
12881287
if (getAxis() >= srcTy.getRank()) {
12891288
return emitOpError("gather dimension must be less than the input rank");
12901289
}
1291-
for (int dim = 0; dim < indicesTy.getRank(); ++dim) {
1290+
for (uint32_t dim = 0; dim < indicesTy.getRank(); ++dim) {
12921291
if (dim == getAxis())
12931292
continue;
12941293
if (indicesTy.getShape()[dim] != srcTy.getShape()[dim]) {

lib/Dialect/Triton/Transforms/ArithTypeConversion.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ struct RewriteArithSelectOp : mlir::OpConversionPattern<mlir::arith::SelectOp> {
1313

1414
mlir::LogicalResult
1515
matchAndRewrite(mlir::arith::SelectOp op, OneToNOpAdaptor adaptor,
16-
mlir::ConversionPatternRewriter &rewriter) const {
16+
mlir::ConversionPatternRewriter &rewriter) const override {
1717
// Note we're replacing the select op with an if op because we are
1818
// converting one value into many values.
1919
auto newIf = rewriter.create<mlir::scf::IfOp>(

lib/Dialect/Triton/Transforms/Combine.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ class CombineBroadcastMulReducePattern : public RewritePattern {
123123
: RewritePattern(ReduceOp::getOperationName(), 1, context) {}
124124

125125
LogicalResult matchAndRewrite(Operation *op,
126-
PatternRewriter &rewriter) const {
126+
PatternRewriter &rewriter) const override {
127127
auto reduceOp = llvm::dyn_cast<ReduceOp>(op);
128128
if (!reduceOp)
129129
return failure();

lib/Dialect/Triton/Transforms/LoopAwareCSE.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,7 @@ bool LoopCSEDriver::areIterArgsEqual(int i, int j) {
5454
return false;
5555
if (llvm::is_contained(argStack, std::make_pair(i, j)))
5656
return true;
57-
BlockArgument aArg = loop.getRegionIterArg(i);
58-
BlockArgument bArg = loop.getRegionIterArg(j);
57+
5958
// First, assume the arguments are equal. This is how recursion is broken.
6059
argStack.push_back({i, j});
6160
bool result =

lib/Dialect/Triton/Transforms/LoopInvariantCodeMotion.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,6 @@ class LoopInvariantCodeMotionPass
5555
[&](Operation *op, Region *) {
5656
// Create the new mask for load op.
5757
if (auto loadOp = dyn_cast<LoadOp>(op)) {
58-
Value mask = loadOp.getMask();
5958
IRRewriter rewriter(loopLike);
6059
Location loc = loopLike->getLoc();
6160
Value cond;

0 commit comments

Comments
 (0)