Skip to content

Commit e1380d7

Browse files
authored
[AMD] Fix program id/count range in RangeAnalysis (#8103)
The program id/count starts from 0 albeit it is a signed integer. This change fixes some places where don't take into account.
1 parent 4e7dc91 commit e1380d7

File tree

2 files changed

+14
-2
lines changed

2 files changed

+14
-2
lines changed

test/TritonGPU/amd/amd-range-analysis.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ module attributes {"ttg.num-warps" = 4 : i32} {
3535
tt.func @assumepid(%arg0: !tt.ptr<f32>) -> tensor<1024xf32> {
3636
%c0 = arith.constant 0 : i32
3737
%c1024_i32 = arith.constant 1024 : i32
38-
// expected-remark@+2 {{unsigned : [0, 2147483647] signed : [0, 1024]}}
38+
// expected-remark@+2 {{unsigned : [0, 1024] signed : [0, 1024]}}
3939
// expected-remark@+1 {{non-neg}}
4040
%pid = tt.get_program_id x : i32
4141
// expected-remark@+2 {{unsigned : [1, 1] signed : [-1, -1]}}
@@ -519,7 +519,7 @@ module attributes {"ttg.num-warps" = 4 : i32} {
519519
// expected-remark@+1 {{non-neg}}
520520
%0 = tt.get_program_id x : i32
521521
%c65535_i32 = arith.constant 65535 : i32
522-
%cmpule_pid = arith.cmpi ule, %0, %c65535_i32 : i32
522+
%cmpule_pid = arith.cmpi sle, %0, %c65535_i32 : i32
523523
llvm.intr.assume %cmpule_pid : i1
524524
// expected-remark@+2 {{unsigned : [0, 8388480] signed : [0, 8388480]}}
525525
// expected-remark@+1 {{non-neg}}

third_party/amd/lib/Analysis/RangeAnalysis.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,12 @@ std::optional<ConstantIntRanges> maybeGetAssumedRange(Operation *assumption,
193193
APInt min, max;
194194
if (isSigned) {
195195
min = APInt::getSignedMinValue(bitWidth);
196+
if (llvm::isa_and_nonnull<mlir::triton::GetProgramIdOp,
197+
mlir::triton::GetNumProgramsOp>(
198+
anchor.getDefiningOp())) {
199+
min = APInt::getZero(bitWidth);
200+
} else
201+
min = APInt::getSignedMinValue(bitWidth);
196202
max = APInt::getSignedMaxValue(bitWidth);
197203
} else {
198204
min = APInt::getMinValue(bitWidth);
@@ -297,6 +303,12 @@ TritonIntegerRangeAnalysis::maybeGetAssumedRange(Value anchor) const {
297303
unsigned bitWidth = ConstantIntRanges::getStorageBitwidth(anchor.getType());
298304
assert(bitWidth > 0 && "expected non-zero bitwidth");
299305
ConstantIntRanges constIntRange = ConstantIntRanges::maxRange(bitWidth);
306+
if (llvm::isa_and_nonnull<GetProgramIdOp>(anchor.getDefiningOp())) {
307+
constIntRange = ConstantIntRanges::range(
308+
APInt::getZero(bitWidth),
309+
APInt(bitWidth, kDefaultMaxPrograms - 1, true), true);
310+
}
311+
300312
for (auto assumption : matchingAssumptions) {
301313
if (auto constIntRange_ = ::maybeGetAssumedRange(assumption, anchor))
302314
constIntRange = constIntRange.intersection(*constIntRange_);

0 commit comments

Comments
 (0)