Commit ecd33fe

[BACKEND] Improve constant analysis in AxisInfo (#8502)
We improve the constancy analysis in a number of cases. The code is quite messy and the whole pass could do with a full rewrite, but we are not doing that at the moment. This PR was mostly vibecoded, with a cleanup pass from me afterwards.
1 parent bad2576 commit ecd33fe
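As an illustration of the kind of case this improves (a sketch mirroring the new @mul_zero_constancy test added below; the function name here is made up), multiplying a contiguous range by an all-zero splat now produces a result with a known constant value, so its constancy covers the whole tensor instead of collapsing to 1:

tt.func @zero_times_range_sketch() {
  // Contiguous range: contiguity = [128], constancy = [1].
  %range = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
  // All-zero splat: constancy = [128], constant_value = 0.
  %zeros = arith.constant dense<0> : tensor<128xi32>
  // MulIOp's getConstantValue now folds x * 0 to 0 even though %range is not
  // constant, so the analysis reports constancy = [128] and constant_value = 0
  // here, where the old gcd-of-constancies fallback gave constancy = [1].
  %product = arith.muli %zeros, %range : tensor<128xi32>
  tt.return
}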

2 files changed: +96 −57 lines

lib/Analysis/AxisInfo.cpp

Lines changed: 49 additions & 56 deletions
@@ -91,23 +91,26 @@ class BinaryOpVisitorImpl : public AxisInfoVisitorImpl<OpTy> {
     auto lhsInfo = operands[0]->getValue();
     auto rhsInfo = operands[1]->getValue();
     auto rank = lhsInfo.getRank();
+    assert(isa<RankedTensorType>(op.getType()) ||
+           rank == 1 && "Expected ranked tensor or scalar");
     assert(operands.size() == 2 && "Expected two operands");
+    auto constantValue = getConstantValue(op, lhsInfo, rhsInfo);
+    if (constantValue.has_value()) {
+      auto resTy = dyn_cast<RankedTensorType>(op.getType());
+      AxisInfo::DimVectorT constancy =
+          resTy ? to_vector(resTy.getShape()) : AxisInfo::DimVectorT(rank, 1);
+      AxisInfo::DimVectorT contiguity(rank, 1);
+      AxisInfo::DimVectorT divisibility(
+          rank, highestPowOf2Divisor<int64_t>(constantValue.value()));
+      return AxisInfo(contiguity, divisibility, constancy, constantValue);
+    }
     AxisInfo::DimVectorT contiguity;
     AxisInfo::DimVectorT divisibility;
     AxisInfo::DimVectorT constancy;
-    auto constantValue = getConstantValue(op, lhsInfo, rhsInfo);
     for (auto d = 0; d < rank; ++d) {
-      if (constantValue.has_value()) {
-        contiguity.push_back(1);
-        constancy.push_back(
-            std::max(lhsInfo.getConstancy(d), rhsInfo.getConstancy(d)));
-        divisibility.push_back(
-            highestPowOf2Divisor<int64_t>(constantValue.value()));
-      } else {
-        contiguity.push_back(getContiguity(op, lhsInfo, rhsInfo, d));
-        constancy.push_back(getConstancy(op, lhsInfo, rhsInfo, d));
-        divisibility.push_back(getDivisibility(op, lhsInfo, rhsInfo, d));
-      }
+      contiguity.push_back(getContiguity(op, lhsInfo, rhsInfo, d));
+      constancy.push_back(getConstancy(op, lhsInfo, rhsInfo, d));
+      divisibility.push_back(getDivisibility(op, lhsInfo, rhsInfo, d));
     }
     return AxisInfo(contiguity, divisibility, constancy, constantValue);
   }
@@ -125,9 +128,8 @@ class BinaryOpVisitorImpl : public AxisInfoVisitorImpl<OpTy> {
 
   virtual int64_t getConstancy(OpTy op, const AxisInfo &lhs,
                                const AxisInfo &rhs, int dim) {
-    return 1;
+    return gcd(lhs.getConstancy(dim), rhs.getConstancy(dim));
   }
-
   virtual std::optional<int64_t> getConstantValue(OpTy op, const AxisInfo &lhs,
                                                   const AxisInfo &rhs) {
     return {};
@@ -328,11 +330,6 @@ class AddSubOpAxisInfoVisitor final : public BinaryOpVisitorImpl<OpTy> {
     return gcd(lhs.getDivisibility(dim), rhsDivisibility);
   }
 
-  int64_t getConstancy(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs,
-                       int dim) override {
-    return gcd(lhs.getConstancy(dim), rhs.getConstancy(dim));
-  }
-
   std::optional<int64_t> getConstantValue(OpTy op, const AxisInfo &lhs,
                                           const AxisInfo &rhs) override {
     if (lhs.getConstantValue().has_value() &&
@@ -375,11 +372,6 @@ class MulIOpAxisInfoVisitor final : public BinaryOpVisitorImpl<arith::MulIOp> {
     return std::max(lhsContiguity, rhsContiguity);
   }
 
-  int64_t getConstancy(arith::MulIOp op, const AxisInfo &lhs,
-                       const AxisInfo &rhs, int dim) override {
-    return gcd(lhs.getConstancy(dim), rhs.getConstancy(dim));
-  }
-
   int64_t getDivisibility(arith::MulIOp op, const AxisInfo &lhs,
                           const AxisInfo &rhs, int dim) override {
     auto lhsDivisibility = lhs.getDivisibility(dim);
@@ -399,9 +391,13 @@ class MulIOpAxisInfoVisitor final : public BinaryOpVisitorImpl<arith::MulIOp> {
 
   std::optional<int64_t> getConstantValue(arith::MulIOp op, const AxisInfo &lhs,
                                           const AxisInfo &rhs) override {
-    if (lhs.getConstantValue().has_value() &&
-        rhs.getConstantValue().has_value())
-      return {lhs.getConstantValue().value() * rhs.getConstantValue().value()};
+    auto lhsConst = lhs.getConstantValue();
+    auto rhsConst = rhs.getConstantValue();
+    if (lhsConst.has_value() && rhsConst.has_value())
+      return {lhsConst.value() * rhsConst.value()};
+    if ((lhsConst.has_value() && lhsConst.value() == 0) ||
+        (rhsConst.has_value() && rhsConst.value() == 0))
+      return 0;
     return {};
   }
 };
@@ -424,12 +420,11 @@ class DivOpAxisInfoVisitor final : public BinaryOpVisitorImpl<OpTy> {
   int64_t getConstancy(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs,
                        int dim) override {
     auto resTy = dyn_cast<RankedTensorType>(op.getType());
+    auto constancy = BinaryOpVisitorImpl<OpTy>::getConstancy(op, lhs, rhs, dim);
     if (!resTy)
-      return BinaryOpVisitorImpl<OpTy>::getConstancy(op, lhs, rhs, dim);
+      return constancy;
     auto shape = resTy.getShape();
-    // Case 1: both lhs and rhs are constants.
-    auto constancy = gcd(lhs.getConstancy(dim), rhs.getConstancy(dim));
-    // Case 2: lhs contiguous, rhs constant.
+    // Case: lhs contiguous, rhs constant.
     // lhs: d_lhs * k, d_lhs * k + 1, ..., d_lhs * k + n
     // rhs: d_rhs * p, d_rhs * p, ..., d_rhs * p
     // lhs / rhs = d_lhs * k / (d_rhs * p), (d_lhs * k + 1) / (d_rhs * p),
@@ -526,15 +521,15 @@ class RemOpAxisInfoVisitor final : public BinaryOpVisitorImpl<OpTy> {
 
   int64_t getConstancy(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs,
                        int dim) override {
+    auto constancy = BinaryOpVisitorImpl<OpTy>::getConstancy(op, lhs, rhs, dim);
     auto resTy = dyn_cast<RankedTensorType>(op.getType());
     if (!resTy)
-      return BinaryOpVisitorImpl<OpTy>::getConstancy(op, lhs, rhs, dim);
-    auto shape = resTy.getShape();
-    // lhs % 1 = 0
-    return rhs.getConstantValue().has_value() &&
-                   rhs.getConstantValue().value() == 1
-               ? shape[dim]
-               : gcd(lhs.getConstancy(dim), rhs.getConstancy(dim));
+      return constancy;
+    // Case: lhs % 1 = 0
+    if (rhs.getConstantValue().has_value() &&
+        rhs.getConstantValue().value() == 1)
+      return resTy.getDimSize(dim);
+    return constancy;
   }
 
   std::optional<int64_t> getConstantValue(OpTy op, const AxisInfo &lhs,
@@ -689,7 +684,7 @@ class CmpOpAxisInfoVisitor final : public AxisInfoVisitorImpl<OpTy> {
       int64_t constHint = 1;
       if (lhsInfo.getConstantValue().has_value() &&
          rhsInfo.getConstantValue().has_value()) {
-        constHint = lhsInfo.getConstancy(d);
+        constHint = shape[d];
         constantValue =
             compare(getPredicate(op), lhsInfo.getConstantValue().value(),
                     rhsInfo.getConstantValue().value())
@@ -848,6 +843,13 @@ class SelectOpAxisInfoVisitor final : public AxisInfoVisitorImpl<OpTy> {
           rhsInfo.getConstantValue().has_value() &&
          lhsInfo.getConstantValue() == rhsInfo.getConstantValue())
         constantValue = lhsInfo.getConstantValue();
+
+      if (constantValue.has_value()) {
+        auto resTy = dyn_cast<RankedTensorType>(op.getType());
+        assert(resTy || rank == 1);
+        constancy =
+            resTy ? to_vector(resTy.getShape()) : AxisInfo::DimVectorT(rank, 1);
+      }
     }
 
     return AxisInfo(contiguity, divisibility, constancy, constantValue);
@@ -860,11 +862,6 @@ class LogicalOpAxisInfoVisitor final : public BinaryOpVisitorImpl<OpTy> {
   using BinaryOpVisitorImpl<OpTy>::BinaryOpVisitorImpl;
 
 private:
-  int64_t getConstancy(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs,
-                       int dim) override {
-    return gcd(lhs.getConstancy(dim), rhs.getConstancy(dim));
-  }
-
   std::optional<int64_t> getConstantValue(OpTy op, const AxisInfo &lhs,
                                           const AxisInfo &rhs) override {
     if (lhs.getConstantValue().has_value() &&
@@ -910,11 +907,6 @@ class ShLIOpAxisInfoVisitor final : public BinaryOpVisitorImpl<arith::ShLIOp> {
     return multiplyDivisor(lhsDivisibility, 1ll << shift);
   }
 
-  int64_t getConstancy(arith::ShLIOp op, const AxisInfo &lhs,
-                       const AxisInfo &rhs, int dim) override {
-    return gcd(lhs.getConstancy(dim), rhs.getConstancy(dim));
-  }
-
   std::optional<int64_t> getConstantValue(arith::ShLIOp op, const AxisInfo &lhs,
                                           const AxisInfo &rhs) override {
     if (lhs.getConstantValue().has_value() &&
@@ -952,11 +944,6 @@ class ShROpAxisInfoVisitor final : public BinaryOpVisitorImpl<OpTy> {
     return std::max<int64_t>(1, lhsDivisibility / (int64_t(1) << shift));
   }
 
-  int64_t getConstancy(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs,
-                       int dim) override {
-    return gcd(lhs.getConstancy(dim), rhs.getConstancy(dim));
-  }
-
   std::optional<int64_t> getConstantValue(OpTy op, const AxisInfo &lhs,
                                           const AxisInfo &rhs) override {
     if (lhs.getConstantValue().has_value() &&
@@ -989,9 +976,15 @@ class MaxMinOpAxisInfoVisitor final : public AxisInfoVisitorImpl<OpTy> {
         constantValue = {std::min(lhsInfo.getConstantValue().value(),
                                   rhsInfo.getConstantValue().value())};
       }
+      auto resTy = dyn_cast<RankedTensorType>(op.getType());
+      assert(resTy || rank == 1);
+      AxisInfo::DimVectorT constancy =
+          resTy ? to_vector(resTy.getShape()) : AxisInfo::DimVectorT(rank, 1);
+      AxisInfo::DimVectorT divisibility(
+          rank, highestPowOf2Divisor<int64_t>(constantValue.value()));
       return AxisInfo(/*knownContiguity=*/AxisInfo::DimVectorT(rank, 1),
-                      /*knownDivisibility=*/AxisInfo::DimVectorT(rank, 1),
-                      /*knownConstancy=*/AxisInfo::DimVectorT(rank, 1),
+                      /*knownDivisibility=*/divisibility,
+                      /*knownConstancy=*/constancy,
                       /*constantValue=*/constantValue);
     } else {
       AxisInfo::DimVectorT contiguity, divisibility, constancy;
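A quick worked example of the default that now lives in BinaryOpVisitorImpl::getConstancy: it returns gcd(lhs.getConstancy(dim), rhs.getConstancy(dim)), so if the lhs is constant over aligned blocks of 8 elements along a dimension and the rhs over blocks of 4, both are constant over every aligned block of gcd(8, 4) = 4 elements, and the result of an elementwise op is therefore constant over blocks of at least 4. The per-op getConstancy overrides removed above (add/sub, mul, logical, shift) were each duplicating exactly this computation.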

test/Analysis/test-alignment.mlir

Lines changed: 47 additions & 1 deletion
@@ -458,7 +458,7 @@ tt.func @max_min() {
   %4 = arith.constant dense<8> : tensor<128xi32>
   // expected-remark @below {{contiguity = [1], divisibility = [4], constancy = [128], constant_value = 4}}
   %5 = arith.constant dense<4> : tensor<128xi32>
-  // expected-remark @below {{contiguity = [1], divisibility = [1], constancy = [1], constant_value = 8}}
+  // expected-remark @below {{contiguity = [1], divisibility = [8], constancy = [128], constant_value = 8}}
   %6 = arith.maxsi %4, %5 : tensor<128xi32>
   tt.return
 }
@@ -1011,3 +1011,49 @@ tt.func @caller() {
   tt.call @callee(%1) : (tensor<128x1xi32>) -> ()
   tt.return
 }
+
+// -----
+
+tt.func @mul_zero_constancy() {
+  %range = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
+  %zeros = arith.constant dense<0> : tensor<128xi32>
+  // expected-remark @below {{constancy = [128]}}
+  %product = arith.muli %zeros, %range : tensor<128xi32>
+  tt.return
+}
+
+// -----
+
+tt.func @max_constancy() {
+  %c5 = arith.constant dense<5> : tensor<4xi32>
+  %c7 = arith.constant dense<7> : tensor<4xi32>
+  // expected-remark @below {{constancy = [4], constant_value = 7}}
+  %max = arith.maxsi %c5, %c7 : tensor<4xi32>
+  tt.return
+}
+
+// -----
+
+tt.func @select_same_value_constancy() {
+  %range = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32>
+  %two = arith.constant dense<2> : tensor<4xi32>
+  %mod = arith.remsi %range, %two : tensor<4xi32>
+  %zero = arith.constant dense<0> : tensor<4xi32>
+  %cond = arith.cmpi ne, %mod, %zero : tensor<4xi32>
+  %lhs = arith.constant dense<42> : tensor<4xi32>
+  %rhs = arith.constant dense<42> : tensor<4xi32>
+  // expected-remark @below {{constancy = [4], constant_value = 42}}
+  %sel = arith.select %cond, %lhs, %rhs : tensor<4xi1>, tensor<4xi32>
+  tt.return
+}
+
+// -----
+
+tt.func @cmp_after_max_constancy() {
+  %c5 = arith.constant dense<5> : tensor<4xi32>
+  %c7 = arith.constant dense<7> : tensor<4xi32>
+  %max = arith.maxsi %c5, %c7 : tensor<4xi32>
+  // expected-remark @below {{constancy = [4], constant_value = 1}}
+  %cmp = arith.cmpi sgt, %max, %c5 : tensor<4xi32>
+  tt.return
+}
