Skip to content

Commit 973461c

Browse files
authored
[ANALYSIS] Fix the divisibility of rem op with pessimistic analysis (#7441)
When `rhs` has a contiguity > 1, it's difficult to specialize conditions to get optimal divisibility estimates. For example, `[128, 128 128, ..., 128] % [0, 1, 2, 3, ...] = [0, 0, 0, 2, ...]`
1 parent e5ea25e commit 973461c

File tree

2 files changed

+30
-6
lines changed

2 files changed

+30
-6
lines changed

lib/Analysis/AxisInfo.cpp

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -480,11 +480,19 @@ class RemOpAxisInfoVisitor final : public BinaryOpVisitorImpl<OpTy> {
480480

481481
int64_t getDivisibility(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs,
482482
int dim) override {
483-
// lhs: d_lhs * k = gcd(d_lhs, d_rhs) * k' * k = gcd(d_lhs, d_rhs) * k''
484-
// rhs: d_rhs * p = gcd(d_lhs, d_rhs) * p' * p = gcd(d_lhs, d_rhs) * p''
485-
// lhs = gcd(d_lhs, d_rhs) * k'' = gcd(d_lhs, d_rhs) * d + r
486-
// r must be divisible by gcd(d_lhs, d_rhs)
487-
return gcd(lhs.getDivisibility(dim), rhs.getDivisibility(dim));
483+
auto resTy = dyn_cast<RankedTensorType>(op.getType());
484+
if (rhs.getConstancy(dim) > 1) {
485+
// lhs: d_lhs * k = gcd(d_lhs, d_rhs) * k' * k = gcd(d_lhs, d_rhs) * k''
486+
// rhs: d_rhs * p = gcd(d_lhs, d_rhs) * p' * p = gcd(d_lhs, d_rhs) * p''
487+
// lhs = gcd(d_lhs, d_rhs) * k'' = gcd(d_lhs, d_rhs) * d + r
488+
// r must be divisible by gcd(d_lhs, d_rhs)
489+
return gcd(lhs.getDivisibility(dim), rhs.getDivisibility(dim));
490+
}
491+
// Otherwise we shouldn't assume any divisibility.
492+
// For example:
493+
// lhs: [2, 2, 4, 4], rhs: [0, 1, 2, 3]
494+
// lhs % rhs = [0, 0, 0, 1]
495+
return 1;
488496
};
489497

490498
int64_t getConstancy(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs,

test/Analysis/test-alignment.mlir

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -190,12 +190,28 @@ tt.func @rem() {
190190
%4 = arith.constant dense<64> : tensor<128xi32>
191191
// expected-remark @below {{contiguity = [64], divisibility = [64], constancy = [1], constant_value = <none>}}
192192
%5 = arith.remsi %0, %4 : tensor<128xi32>
193-
// expected-remark @below {{contiguity = [1], divisibility = [64], constancy = [1], constant_value = <none>}}
193+
// expected-remark @below {{contiguity = [1], divisibility = [1], constancy = [1], constant_value = <none>}}
194194
%6 = arith.remsi %4, %0 : tensor<128xi32>
195195
// expected-remark @below {{contiguity = [1], divisibility = [2], constancy = [128], constant_value = 66}}
196196
%7 = arith.constant dense<66> : tensor<128xi32>
197197
// expected-remark @below {{contiguity = [2], divisibility = [2], constancy = [1], constant_value = <none>}}
198198
%8 = arith.remui %0, %7 : tensor<128xi32>
199+
// expected-remark @below {{contiguity = [1], divisibility = [64], constancy = [128], constant_value = 192}}
200+
%9 = arith.constant dense<192> : tensor<128xi32>
201+
// expected-remark @below {{contiguity = [64], divisibility = [64], constancy = [1], constant_value = <none>}}
202+
%10 = arith.remsi %0, %9 : tensor<128xi32>
203+
// expected-remark @below {{contiguity = [1], divisibility = [1], constancy = [1], constant_value = <none>}}
204+
%11 = arith.remsi %9, %0 : tensor<128xi32>
205+
// expected-remark @below {{contiguity = [128], divisibility = [32], constancy = [1], constant_value = <none>}}
206+
%12 = tt.make_range {end = 160 : i32, start = 32 : i32} : tensor<128xi32>
207+
// expected-remark @below {{contiguity = [1], divisibility = [1], constancy = [1], constant_value = <none>}}
208+
%13 = arith.remsi %0, %12 : tensor<128xi32>
209+
// expected-remark @below {{contiguity = [1], divisibility = [1], constancy = [1], constant_value = <none>}}
210+
%14 = arith.remsi %12, %0 : tensor<128xi32>
211+
// expected-remark @below {{contiguity = [32], divisibility = [32], constancy = [1], constant_value = <none>}}
212+
%15 = arith.remsi %12, %4 : tensor<128xi32>
213+
// expected-remark @below {{contiguity = [1], divisibility = [1], constancy = [1], constant_value = <none>}}
214+
%16 = arith.remsi %4, %12 : tensor<128xi32>
199215
tt.return
200216
}
201217

0 commit comments

Comments
 (0)