Skip to content

Commit 0b5c483

Browse files
authored
[AMD] Fixed pid range analysis assumption (#7793)
Fixes a bug in RangeAnalysis where the assumptions about the max number of programs were wrong for the X dimension. This is the correct information based on rocminfo. ``` Grid Max Size: 4294967295(0xffffffff) Grid Max Size per Dimension: x 2147483647(0x7fffffff) y 65535(0xffff) z 65535(0xffff) ``` This was leading to an IMA in inductor generated code when it generated a 1D grid of 72,000 programs.
1 parent 48d46bb commit 0b5c483

File tree

2 files changed

+76
-1
lines changed

2 files changed

+76
-1
lines changed

test/TritonGPU/amd/amd-range-analysis.mlir

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,18 @@ module attributes {"ttg.num-warps" = 4 : i32} {
99
// expected-remark@+2 {{unsigned : [0, 65535] signed : [0, 65535]}}
1010
// expected-remark@+1 {{non-neg}}
1111
%0 = tt.get_program_id x : i32
12+
%c65535_i32 = arith.constant 65535 : i32
13+
%cmpule_pid = arith.cmpi ule, %0, %c65535_i32 : i32
14+
llvm.intr.assume %cmpule_pid : i1
1215
// expected-remark@+2 {{unsigned : [0, 67107840] signed : [0, 67107840]}}
1316
// expected-remark@+1 {{non-neg}}
1417
%1 = arith.muli %0, %c1024_i32 : i32
1518
// expected-remark@+2 {{unsigned : [0, 65536] signed : [0, 65536]}}
1619
// expected-remark@+1 {{non-neg}}
1720
%numps = tt.get_num_programs x : i32
21+
%c65536_i32 = arith.constant 65536 : i32
22+
%cmpule_programs = arith.cmpi ule, %numps, %c65536_i32 : i32
23+
llvm.intr.assume %cmpule_programs : i1
1824
%2 = tt.addptr %arg0, %1 : !tt.ptr<f32>, i32
1925
%3 = tt.splat %2 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
2026
%4 = tt.load %3 : tensor<1024x!tt.ptr<f32>>
@@ -59,6 +65,9 @@ module attributes {"ttg.num-warps" = 4 : i32} {
5965
// expected-remark@+2 {{unsigned : [0, 65535] signed : [0, 65535]}}
6066
// expected-remark@+1 {{non-neg}}
6167
%0 = tt.get_program_id x : i32
68+
%c65535_i32 = arith.constant 65535 : i32
69+
%cmpule_pid = arith.cmpi ule, %0, %c65535_i32 : i32
70+
llvm.intr.assume %cmpule_pid : i1
6271
// expected-remark@+2 {{unsigned : [0, 67107840] signed : [0, 67107840]}}
6372
// expected-remark@+1 {{non-neg}}
6473
%1 = arith.muli %0, %c1024_i32 : i32
@@ -80,6 +89,9 @@ module attributes {"ttg.num-warps" = 4 : i32} {
8089
// expected-remark@+2 {{unsigned : [0, 65535] signed : [0, 65535]}}
8190
// expected-remark@+1 {{non-neg}}
8291
%0 = tt.get_program_id x : i32
92+
%c65535_i32 = arith.constant 65535 : i32
93+
%cmpule_pid = arith.cmpi ule, %0, %c65535_i32 : i32
94+
llvm.intr.assume %cmpule_pid : i1
8395
// expected-remark@+2 {{unsigned : [0, 67107840] signed : [0, 67107840]}}
8496
// expected-remark@+1 {{non-neg}}
8597
%1 = arith.muli %0, %c1024_i32 : i32
@@ -111,6 +123,9 @@ module attributes {"ttg.num-warps" = 4 : i32} {
111123
// expected-remark@+2 {{unsigned : [0, 65535] signed : [0, 65535]}}
112124
// expected-remark@+1 {{non-neg}}
113125
%0 = tt.get_program_id x : i32
126+
%c65535_i32 = arith.constant 65535 : i32
127+
%cmpule_pid = arith.cmpi ule, %0, %c65535_i32 : i32
128+
llvm.intr.assume %cmpule_pid : i1
114129
// expected-remark@+2 {{unsigned : [0, 67107840] signed : [0, 67107840]}}
115130
// expected-remark@+1 {{non-neg}}
116131
%1 = arith.muli %0, %c1024_i32 : i32
@@ -139,6 +154,9 @@ module attributes {"ttg.num-warps" = 4 : i32} {
139154
// expected-remark@+2 {{unsigned : [0, 65535] signed : [0, 65535]}}
140155
// expected-remark@+1 {{non-neg}}
141156
%0 = tt.get_program_id x : i32
157+
%c65535_i32 = arith.constant 65535 : i32
158+
%cmpule_pid = arith.cmpi ule, %0, %c65535_i32 : i32
159+
llvm.intr.assume %cmpule_pid : i1
142160
// expected-remark@+2 {{unsigned : [0, 67107840] signed : [0, 67107840]}}
143161
// expected-remark@+1 {{non-neg}}
144162
%1 = arith.muli %0, %c1024_i32 : i32
@@ -191,6 +209,9 @@ module attributes {"ttg.num-warps" = 4 : i32} {
191209
// expected-remark@+2 {{unsigned : [0, 65535] signed : [0, 65535]}}
192210
// expected-remark@+1 {{non-neg}}
193211
%0 = tt.get_program_id x : i32
212+
%c65535_i32 = arith.constant 65535 : i32
213+
%cmpule_pid = arith.cmpi ule, %0, %c65535_i32 : i32
214+
llvm.intr.assume %cmpule_pid : i1
194215
// expected-remark@+2 {{unsigned : [0, 67107840] signed : [0, 67107840]}}
195216
// expected-remark@+1 {{non-neg}}
196217
%1 = arith.muli %0, %c1024_i32 : i32
@@ -239,6 +260,9 @@ module attributes {"ttg.num-warps" = 4 : i32} {
239260
// expected-remark@+2 {{unsigned : [0, 65535] signed : [0, 65535]}}
240261
// expected-remark@+1 {{non-neg}}
241262
%0 = tt.get_program_id x : i32
263+
%c65535_i32 = arith.constant 65535 : i32
264+
%cmpule_pid = arith.cmpi ule, %0, %c65535_i32 : i32
265+
llvm.intr.assume %cmpule_pid : i1
242266
// expected-remark@+2 {{unsigned : [0, 67107840] signed : [0, 67107840]}}
243267
// expected-remark@+1 {{non-neg}}
244268
%1 = arith.muli %0, %c1024_i32 : i32
@@ -293,6 +317,9 @@ module attributes {"ttg.num-warps" = 4 : i32} {
293317
// expected-remark@+2 {{unsigned : [0, 65535] signed : [0, 65535]}}
294318
// expected-remark@+1 {{non-neg}}
295319
%0 = tt.get_program_id x : i32
320+
%c65535_i32 = arith.constant 65535 : i32
321+
%cmpule_pid = arith.cmpi ule, %0, %c65535_i32 : i32
322+
llvm.intr.assume %cmpule_pid : i1
296323
// expected-remark@+2 {{unsigned : [0, 67107840] signed : [0, 67107840]}}
297324
// expected-remark@+1 {{non-neg}}
298325
%1 = arith.muli %0, %c1024_i32 : i32
@@ -341,6 +368,9 @@ module attributes {"ttg.num-warps" = 4 : i32} {
341368
// expected-remark@+2 {{unsigned : [0, 65535] signed : [0, 65535]}}
342369
// expected-remark@+1 {{non-neg}}
343370
%0 = tt.get_program_id x : i32
371+
%c65535_i32 = arith.constant 65535 : i32
372+
%cmpule_pid = arith.cmpi ule, %0, %c65535_i32 : i32
373+
llvm.intr.assume %cmpule_pid : i1
344374
// expected-remark@+2 {{unsigned : [0, 67107840] signed : [0, 67107840]}}
345375
// expected-remark@+1 {{non-neg}}
346376
%1 = arith.muli %0, %c1024_i32 : i32
@@ -378,6 +408,9 @@ module attributes {"ttg.num-warps" = 4 : i32} {
378408
// expected-remark@+2 {{unsigned : [0, 65535] signed : [0, 65535]}}
379409
// expected-remark@+1 {{non-neg}}
380410
%0 = tt.get_program_id x : i32
411+
%c65535_i32 = arith.constant 65535 : i32
412+
%cmpule_pid = arith.cmpi ule, %0, %c65535_i32 : i32
413+
llvm.intr.assume %cmpule_pid : i1
381414
// expected-remark@+2 {{unsigned : [0, 67107840] signed : [0, 67107840]}}
382415
// expected-remark@+1 {{non-neg}}
383416
%1 = arith.muli %0, %c1024_i32 : i32
@@ -416,6 +449,9 @@ module attributes {"ttg.num-warps" = 4 : i32} {
416449
// expected-remark@+2 {{unsigned : [0, 65535] signed : [0, 65535]}}
417450
// expected-remark@+1 {{non-neg}}
418451
%0 = tt.get_program_id x : i32
452+
%c65535_i32 = arith.constant 65535 : i32
453+
%cmpule_pid = arith.cmpi ule, %0, %c65535_i32 : i32
454+
llvm.intr.assume %cmpule_pid : i1
419455
// expected-remark@+2 {{unsigned : [0, 67107840] signed : [0, 67107840]}}
420456
// expected-remark@+1 {{non-neg}}
421457
%1 = arith.muli %0, %c1024_i32 : i32
@@ -439,6 +475,9 @@ module attributes {"ttg.num-warps" = 4 : i32} {
439475
// expected-remark@+2 {{unsigned : [0, 65535] signed : [0, 65535]}}
440476
// expected-remark@+1 {{non-neg}}
441477
%0 = tt.get_program_id x : i32
478+
%c65535_i32 = arith.constant 65535 : i32
479+
%cmpule_pid = arith.cmpi ule, %0, %c65535_i32 : i32
480+
llvm.intr.assume %cmpule_pid : i1
442481
// expected-remark@+2 {{unsigned : [0, 16776960] signed : [0, 16776960]}}
443482
// expected-remark@+1 {{non-neg}}
444483
%1 = arith.muli %0, %c256_i32 : i32
@@ -479,6 +518,9 @@ module attributes {"ttg.num-warps" = 4 : i32} {
479518
// expected-remark@+2 {{unsigned : [0, 65535] signed : [0, 65535]}}
480519
// expected-remark@+1 {{non-neg}}
481520
%0 = tt.get_program_id x : i32
521+
%c65535_i32 = arith.constant 65535 : i32
522+
%cmpule_pid = arith.cmpi ule, %0, %c65535_i32 : i32
523+
llvm.intr.assume %cmpule_pid : i1
482524
// expected-remark@+2 {{unsigned : [0, 8388480] signed : [0, 8388480]}}
483525
// expected-remark@+1 {{non-neg}}
484526
%1 = arith.muli %0, %c128_i32 : i32
@@ -517,6 +559,9 @@ module attributes {"ttg.num-warps" = 4 : i32} {
517559
// expected-remark@+2 {{unsigned : [0, 65535] signed : [0, 65535]}}
518560
// expected-remark@+1 {{non-neg}}
519561
%0 = tt.get_program_id x : i32
562+
%c65535_i32 = arith.constant 65535 : i32
563+
%cmpule_pid = arith.cmpi ule, %0, %c65535_i32 : i32
564+
llvm.intr.assume %cmpule_pid : i1
520565
// expected-remark@+2 {{unsigned : [0, 67107840] signed : [0, 67107840]}}
521566
// expected-remark@+1 {{non-neg}}
522567
%1 = arith.muli %0, %c1024_i32 : i32
@@ -550,6 +595,9 @@ module attributes {"ttg.num-ctas" = 1 : i32} {
550595
// expected-remark@+2 {{unsigned : [0, 65535] signed : [0, 65535]}}
551596
// expected-remark@+1 {{non-neg}}
552597
%0 = tt.get_program_id x : i32
598+
%c65535_i32 = arith.constant 65535 : i32
599+
%cmpule_pid = arith.cmpi ule, %0, %c65535_i32 : i32
600+
llvm.intr.assume %cmpule_pid : i1
553601
// expected-remark@+2 {{unsigned : [0, 67107840] signed : [0, 67107840]}}
554602
// expected-remark@+1 {{non-neg}}
555603
%1 = arith.muli %0, %c1024_i32 : i32
@@ -577,6 +625,9 @@ module attributes {"ttg.num-warps" = 4 : i32} {
577625
// expected-remark@+2 {{unsigned : [0, 65535] signed : [0, 65535]}}
578626
// expected-remark@+1 {{non-neg}}
579627
%0 = tt.get_program_id x : i32
628+
%c65535_i32 = arith.constant 65535 : i32
629+
%cmpule_pid = arith.cmpi ule, %0, %c65535_i32 : i32
630+
llvm.intr.assume %cmpule_pid : i1
580631
%1 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32>
581632
%2 = tt.addptr %arg0, %0 : !tt.ptr<f32>, i32
582633
// expected-remark@+2 {{unsigned : [0, 1024] signed : [0, 1024]}}
@@ -685,6 +736,9 @@ module attributes {"ttg.num-warps" = 4 : i32} {
685736
// expected-remark@+2 {{unsigned : [0, 65535] signed : [0, 65535]}}
686737
// expected-remark@+1 {{non-neg}}
687738
%0 = tt.get_program_id x : i32
739+
%c65535_i32 = arith.constant 65535 : i32
740+
%cmpule_pid = arith.cmpi ule, %0, %c65535_i32 : i32
741+
llvm.intr.assume %cmpule_pid : i1
688742
// expected-remark@+2 {{unsigned : [0, 67107840] signed : [0, 67107840]}}
689743
// expected-remark@+1 {{non-neg}}
690744
%1 = arith.muli %0, %c1024_i32 : i32
@@ -742,6 +796,9 @@ module attributes {"ttg.num-warps" = 4 : i32} {
742796
// expected-remark@+2 {{unsigned : [0, 65535] signed : [0, 65535]}}
743797
// expected-remark@+1 {{non-neg}}
744798
%0 = tt.get_program_id x : i32
799+
%c65535_i32 = arith.constant 65535 : i32
800+
%cmpule_pid = arith.cmpi ule, %0, %c65535_i32 : i32
801+
llvm.intr.assume %cmpule_pid : i1
745802
// expected-remark@+2 {{unsigned : [0, 67107840] signed : [0, 67107840]}}
746803
// expected-remark@+1 {{non-neg}}
747804
%1 = arith.muli %0, %c1024_i32 : i32
@@ -821,6 +878,9 @@ module attributes {"ttg.num-warps" = 4 : i32} {
821878
// expected-remark@+2 {{unsigned : [0, 65535] signed : [0, 65535]}}
822879
// expected-remark@+1 {{non-neg}}
823880
%0 = tt.get_program_id x : i32
881+
%c65535_i32 = arith.constant 65535 : i32
882+
%cmpule_pid = arith.cmpi ule, %0, %c65535_i32 : i32
883+
llvm.intr.assume %cmpule_pid : i1
824884
// expected-remark@+2 {{unsigned : [0, 67107840] signed : [0, 67107840]}}
825885
// expected-remark@+1 {{non-neg}}
826886
%1 = arith.muli %0, %c1024_i32 : i32
@@ -1230,12 +1290,19 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
12301290
// expected-remark@+2 {{unsigned : [0, 65536] signed : [0, 65536]}}
12311291
// expected-remark@+1 {{non-neg}}
12321292
%0 = tt.get_num_programs x : i32
1293+
%c65536_i32 = arith.constant 65536 : i32
1294+
%cmpule_num_program0 = arith.cmpi ule, %0, %c65536_i32 : i32
1295+
llvm.intr.assume %cmpule_num_program0 : i1
12331296
// expected-remark@+2 {{unsigned : [0, 65536] signed : [0, 65536]}}
12341297
// expected-remark@+1 {{non-neg}}
12351298
%1 = tt.get_num_programs y : i32
1299+
%cmpule_num_program1 = arith.cmpi ule, %1, %c65536_i32 : i32
1300+
llvm.intr.assume %cmpule_num_program1 : i1
12361301
// expected-remark@+2 {{unsigned : [0, 65536] signed : [0, 65536]}}
12371302
// expected-remark@+1 {{non-neg}}
12381303
%2 = tt.get_num_programs z : i32
1304+
%cmpule_num_program2 = arith.cmpi ule, %2, %c65536_i32 : i32
1305+
llvm.intr.assume %cmpule_num_program2 : i1
12391306
// expected-remark@+2 {{unsigned : [0, 65536] signed : [0, 65536]}}
12401307
// expected-remark@+1 {{non-neg}}
12411308
%3 = arith.minsi %0, %1 : i32
@@ -1564,12 +1631,20 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
15641631
// expected-remark@+2 {{unsigned : [0, 65535] signed : [0, 65535]}}
15651632
// expected-remark@+1 {{non-neg}}
15661633
%0 = tt.get_program_id x : i32
1634+
%c65535_i32 = arith.constant 65535 : i32
1635+
%cmpule_pid0 = arith.cmpi ule, %0, %c65535_i32 : i32
1636+
llvm.intr.assume %cmpule_pid0 : i1
15671637
// expected-remark@+2 {{unsigned : [0, 65535] signed : [0, 65535]}}
15681638
// expected-remark@+1 {{non-neg}}
15691639
%1 = tt.get_program_id y : i32
1640+
%cmpule_pid1 = arith.cmpi ule, %1, %c65535_i32 : i32
1641+
llvm.intr.assume %cmpule_pid1 : i1
15701642
// expected-remark@+2 {{unsigned : [0, 65536] signed : [0, 65536]}}
15711643
// expected-remark@+1 {{non-neg}}
15721644
%2 = tt.get_num_programs y : i32
1645+
%c65536_i32 = arith.constant 65536 : i32
1646+
%cmpule_num_program1 = arith.cmpi ule, %2, %c65536_i32 : i32
1647+
llvm.intr.assume %cmpule_num_program1 : i1
15731648
// expected-remark@+2 {{unsigned : [0, 2097120] signed : [0, 2097120]}}
15741649
// expected-remark@+1 {{non-neg}}
15751650
%3 = arith.muli %0, %c32_i32 : i32

third_party/amd/lib/Analysis/RangeAnalysis.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ triton::AMD::TritonIntegerRangeAnalysis::maybeGetTripCount(
8787
namespace {
8888

8989
constexpr int64_t kDefaultMaxTripCount = 1024;
90-
constexpr int64_t kDefaultMaxPrograms = 2 << 15; // 65536
90+
constexpr uint64_t kDefaultMaxPrograms = 1L << 31; // 2147483648
9191

9292
void getEnclosingLoops(Operation &op, SmallVector<LoopLikeOpInterface> &ops) {
9393
Operation *currOp = op.getParentOp();

0 commit comments

Comments
 (0)