Skip to content

Commit 0fd934f

Browse files
committed
[mlir][AffineExpr] Order arguments in the commutative affine exprs
Order symbol/dim arguments by position, put dims before symbols and put constants to the right. This is to help affine simplefications.
1 parent dc1a79a commit 0fd934f

File tree

4 files changed

+62
-5
lines changed

4 files changed

+62
-5
lines changed

mlir/lib/IR/AffineExpr.cpp

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -784,16 +784,48 @@ static AffineExpr simplifyAdd(AffineExpr lhs, AffineExpr rhs) {
784784
return nullptr;
785785
}
786786

787+
static std::pair<AffineExpr, AffineExpr>
788+
orderCommutativeArgs(AffineExpr expr1, AffineExpr expr2) {
789+
auto sym1 = dyn_cast<AffineSymbolExpr>(expr1);
790+
auto sym2 = dyn_cast<AffineSymbolExpr>(expr2);
791+
// Try to order by symbol/dim position first
792+
if (sym1 && sym2)
793+
return sym1.getPosition() < sym2.getPosition() ? std::pair{expr1, expr2}
794+
: std::pair{expr2, expr1};
795+
796+
auto dim1 = dyn_cast<AffineDimExpr>(expr1);
797+
auto dim2 = dyn_cast<AffineDimExpr>(expr2);
798+
if (dim1 && dim2)
799+
return dim1.getPosition() < dim2.getPosition() ? std::pair{expr1, expr2}
800+
: std::pair{expr2, expr1};
801+
802+
// Put dims before symbols
803+
if (dim1 && sym2)
804+
return {dim1, sym2};
805+
806+
if (sym1 && dim2)
807+
return {dim2, sym1};
808+
809+
// Move constants to the right
810+
if (isa<AffineConstantExpr>(expr1) && !isa<AffineConstantExpr>(expr2))
811+
return {expr2, expr1};
812+
813+
// Otherwise, keep original order
814+
return {expr1, expr2};
815+
}
816+
787817
AffineExpr AffineExpr::operator+(int64_t v) const {
788818
return *this + getAffineConstantExpr(v, getContext());
789819
}
790820
AffineExpr AffineExpr::operator+(AffineExpr other) const {
791821
if (auto simplified = simplifyAdd(*this, other))
792822
return simplified;
793823

824+
auto [lhs, rhs] = orderCommutativeArgs(*this, other);
825+
794826
StorageUniquer &uniquer = getContext()->getAffineUniquer();
795827
return uniquer.get<AffineBinaryOpExprStorage>(
796-
/*initFn=*/{}, static_cast<unsigned>(AffineExprKind::Add), *this, other);
828+
/*initFn=*/{}, static_cast<unsigned>(AffineExprKind::Add), lhs, rhs);
797829
}
798830

799831
/// Simplify a multiply expression. Return nullptr if it can't be simplified.
@@ -856,9 +888,11 @@ AffineExpr AffineExpr::operator*(AffineExpr other) const {
856888
if (auto simplified = simplifyMul(*this, other))
857889
return simplified;
858890

891+
auto [lhs, rhs] = orderCommutativeArgs(*this, other);
892+
859893
StorageUniquer &uniquer = getContext()->getAffineUniquer();
860894
return uniquer.get<AffineBinaryOpExprStorage>(
861-
/*initFn=*/{}, static_cast<unsigned>(AffineExprKind::Mul), *this, other);
895+
/*initFn=*/{}, static_cast<unsigned>(AffineExprKind::Mul), lhs, rhs);
862896
}
863897

864898
// Unary minus, delegate to operator*.

mlir/test/Dialect/Affine/simplify-structures.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -508,7 +508,7 @@ func.func @test_not_trivially_true_or_false_returning_three_results() -> (index,
508508
// -----
509509

510510
// Test simplification of mod expressions.
511-
// CHECK-DAG: #[[$MOD:.*]] = affine_map<()[s0, s1, s2, s3, s4] -> (s3 + s4 * s1 + (s0 - s1) mod s2)>
511+
// CHECK-DAG: #[[$MOD:.*]] = affine_map<()[s0, s1, s2, s3, s4] -> (s1 * s4 + s3 + (s0 - s1) mod s2)>
512512
// CHECK-DAG: #[[$SIMPLIFIED_MOD_RHS:.*]] = affine_map<()[s0, s1, s2, s3] -> (s3 mod (s2 - s0 * s1))>
513513
// CHECK-DAG: #[[$MODULO_AND_PRODUCT:.*]] = affine_map<()[s0, s1, s2, s3] -> (s0 * s1 + s3 - (-s0 + s3) mod s2)>
514514
// CHECK-LABEL: func @semiaffine_simplification_mod
@@ -547,7 +547,7 @@ func.func @semiaffine_simplification_floordiv_and_ceildiv(%arg0: index, %arg1: i
547547

548548
// Test simplification of product expressions.
549549
// CHECK-DAG: #[[$PRODUCT:.*]] = affine_map<()[s0, s1, s2, s3, s4] -> (s3 + s4 + (s0 - s1) * s2)>
550-
// CHECK-DAG: #[[$SUM_OF_PRODUCTS:.*]] = affine_map<()[s0, s1, s2, s3, s4] -> (s2 + s2 * s0 + s3 + s3 * s0 + s3 * s1 + s4 + s4 * s1)>
550+
// CHECK-DAG: #[[$SUM_OF_PRODUCTS:.*]] = affine_map<()[s0, s1, s2, s3, s4] -> (s0 * s2 + s0 * s3 + s1 * s3 + s1 * s4 + s2 + s3 + s4)>
551551
// CHECK-LABEL: func @semiaffine_simplification_product
552552
// CHECK-SAME: (%[[ARG0:.*]]: index, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index, %[[ARG5:.*]]: index)
553553
func.func @semiaffine_simplification_product(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index, %arg5: index) -> (index, index) {

mlir/test/IR/affine-map.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@
139139
#map44 = affine_map<(i, j) -> (i - 2*j, j * 6 floordiv 4)>
140140

141141
// Simplifications
142-
// CHECK: #map{{[0-9]*}} = affine_map<(d0, d1, d2)[s0] -> (d0 + d1 + d2 + 1, d2 + d1, (d0 * s0) * 8)>
142+
// CHECK: #map{{[0-9]*}} = affine_map<(d0, d1, d2)[s0] -> (d0 + d1 + d2 + 1, d1 + d2, (d0 * s0) * 8)>
143143
#map45 = affine_map<(i, j, k) [N] -> (1 + i + 3 + j - 3 + k, k + 5 + j - 5, 2*i*4*N)>
144144

145145
// CHECK: #map{{[0-9]*}} = affine_map<(d0, d1, d2) -> (0, d1, d0 * 2, 0)>

mlir/unittests/IR/AffineExprTest.cpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,20 @@ TEST(AffineExprTest, constantFolding) {
8484
ASSERT_EQ(cminfloordivcn1.getKind(), AffineExprKind::FloorDiv);
8585
}
8686

87+
TEST(AffineExprTest, commutative) {
88+
MLIRContext ctx;
89+
OpBuilder b(&ctx);
90+
auto c2 = b.getAffineConstantExpr(1);
91+
auto d0 = b.getAffineDimExpr(0);
92+
auto d1 = b.getAffineDimExpr(1);
93+
auto s0 = b.getAffineSymbolExpr(0);
94+
auto s1 = b.getAffineSymbolExpr(1);
95+
96+
ASSERT_EQ(d0 * d1, d1 * d0);
97+
ASSERT_EQ(s0 + s1, s1 + s0);
98+
ASSERT_EQ(s0 * c2, c2 * s0);
99+
}
100+
87101
TEST(AffineExprTest, divisionSimplification) {
88102
MLIRContext ctx;
89103
OpBuilder b(&ctx);
@@ -147,3 +161,12 @@ TEST(AffineExprTest, simpleAffineExprFlattenerRegression) {
147161
ASSERT_TRUE(isa<AffineConstantExpr>(result));
148162
ASSERT_EQ(cast<AffineConstantExpr>(result).getValue(), 7);
149163
}
164+
165+
TEST(AffineExprTest, simplifyCommutative) {
166+
MLIRContext ctx;
167+
OpBuilder b(&ctx);
168+
auto s0 = b.getAffineSymbolExpr(0);
169+
auto s1 = b.getAffineSymbolExpr(1);
170+
171+
ASSERT_EQ(s0 * s1 - s1 * s0 + 1, 1);
172+
}

0 commit comments

Comments
 (0)