Skip to content

Commit c039b31

Browse files
MoerafaatGoogle-ML-Automation
authored andcommitted
[Triton] Fixing getLastInductionValue utility to also accept Index type. This would otherwise crash when warp specialization is enabled.
PiperOrigin-RevId: 820159796
1 parent 9338ffd commit c039b31

File tree

2 files changed

+23
-0
lines changed

2 files changed

+23
-0
lines changed

third_party/triton/temporary/series.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,5 +14,6 @@ those to this list.
1414
"""
1515

1616
temporary_patch_list = [
17+
"//third_party/triton:temporary/utility-fix.patch",
1718
# Add new patches just above this line
1819
]
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
This patch would probably not be accepted upstream because our infrastructure
2+
uses Index type for indexing, while they use Integer type. Triton frontend
3+
wouldn't generate a situation that would run into this issue.
4+
5+
diff --git a/lib/Dialect/Triton/IR/Utility.cpp b/lib/Dialect/Triton/IR/Utility.cpp
6+
--- a/lib/Dialect/Triton/IR/Utility.cpp
7+
+++ b/lib/Dialect/Triton/IR/Utility.cpp
8+
@@ -97,8 +97,12 @@ Value tt::getLastInductionValue(OpBuilde
9+
// (ub - lb -1) // step * step + lb
10+
Value diff =
11+
b.create<arith::SubIOp>(loc, loop.getUpperBound(), loop.getLowerBound());
12+
- diff = b.create<arith::SubIOp>(
13+
- loc, diff, b.create<arith::ConstantOp>(loc, b.getI32IntegerAttr(1)));
14+
+ Value one;
15+
+ if (diff.getType().isIndex())
16+
+ one = b.create<arith::ConstantIndexOp>(loc, 1);
17+
+ else
18+
+ one = b.create<arith::ConstantOp>(loc, b.getIntegerAttr(diff.getType(), 1));
19+
+ diff = b.create<arith::SubIOp>(loc, diff, one);
20+
Value ceilStep = b.create<arith::MulIOp>(
21+
loc, b.create<arith::DivSIOp>(loc, diff, loop.getStep()), loop.getStep());
22+
return b.create<arith::AddIOp>(loc, ceilStep, loop.getLowerBound());

0 commit comments

Comments
 (0)