Skip to content

Commit b5626ae

Browse files
arnab-polymagebondhugula
authored andcommitted
[MLIR] Fix bug in the method constructing semi affine expression from flattened form
Set proper offset to the second element of the index pair when either lhs or rhs of a local expression is a dimensional identifier, so that we do not have same index values for more than one local expression. Reviewed By: springerm, hanchung Differential Revision: https://reviews.llvm.org/D137389
1 parent 05a165b commit b5626ae

File tree

2 files changed

+27
-15
lines changed

2 files changed

+27
-15
lines changed

mlir/lib/IR/AffineExpr.cpp

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -986,18 +986,9 @@ static AffineExpr getSemiAffineExprFromFlatForm(ArrayRef<int64_t> flatExprs,
986986
// constant coefficient corresponding to the indices in `coefficients` map,
987987
// and affine expression corresponding to indices in `indexToExprMap` map.
988988

989-
for (unsigned j = 0; j < numDims; ++j) {
990-
if (flatExprs[j] == 0)
991-
continue;
992-
// For dimensional expressions we set the index as <position number of the
993-
// dimension, 0>, as we want dimensional expressions to appear before
994-
// symbolic ones and products of dimensional and symbolic expressions
995-
// having the dimension with the same position number.
996-
std::pair<unsigned, signed> indexEntry(j, -1);
997-
addEntry(indexEntry, flatExprs[j], getAffineDimExpr(j, context));
998-
}
999989
// Ensure we do not have duplicate keys in `indexToExpr` map.
1000-
unsigned offset = 0;
990+
unsigned offsetSym = 0;
991+
signed offsetDim = -1;
1001992
for (unsigned j = numDims; j < numDims + numSymbols; ++j) {
1002993
if (flatExprs[j] == 0)
1003994
continue;
@@ -1006,7 +997,7 @@ static AffineExpr getSemiAffineExprFromFlatForm(ArrayRef<int64_t> flatExprs,
1006997
// as we want symbolic expressions with the same positional number to
1007998
// appear after dimensional expressions having the same positional number.
1008999
std::pair<unsigned, signed> indexEntry(
1009-
j - numDims, std::max(numDims, numSymbols) + offset++);
1000+
j - numDims, std::max(numDims, numSymbols) + offsetSym++);
10101001
addEntry(indexEntry, flatExprs[j],
10111002
getAffineSymbolExpr(j - numDims, context));
10121003
}
@@ -1038,13 +1029,13 @@ static AffineExpr getSemiAffineExprFromFlatForm(ArrayRef<int64_t> flatExprs,
10381029
// constructing. When rhs is constant, we place 0 in place of keyB.
10391030
if (lhs.isa<AffineDimExpr>()) {
10401031
lhsPos = lhs.cast<AffineDimExpr>().getPosition();
1041-
std::pair<unsigned, signed> indexEntry(lhsPos, -1);
1032+
std::pair<unsigned, signed> indexEntry(lhsPos, offsetDim--);
10421033
addEntry(indexEntry, flatExprs[numDims + numSymbols + it.index()],
10431034
expr);
10441035
} else {
10451036
lhsPos = lhs.cast<AffineSymbolExpr>().getPosition();
10461037
std::pair<unsigned, signed> indexEntry(
1047-
lhsPos, std::max(numDims, numSymbols) + offset++);
1038+
lhsPos, std::max(numDims, numSymbols) + offsetSym++);
10481039
addEntry(indexEntry, flatExprs[numDims + numSymbols + it.index()],
10491040
expr);
10501041
}
@@ -1066,12 +1057,23 @@ static AffineExpr getSemiAffineExprFromFlatForm(ArrayRef<int64_t> flatExprs,
10661057
lhsPos = lhs.cast<AffineSymbolExpr>().getPosition();
10671058
rhsPos = rhs.cast<AffineSymbolExpr>().getPosition();
10681059
std::pair<unsigned, signed> indexEntry(
1069-
lhsPos, std::max(numDims, numSymbols) + offset++);
1060+
lhsPos, std::max(numDims, numSymbols) + offsetSym++);
10701061
addEntry(indexEntry, flatExprs[numDims + numSymbols + it.index()], expr);
10711062
}
10721063
addedToMap[it.index()] = true;
10731064
}
10741065

1066+
for (unsigned j = 0; j < numDims; ++j) {
1067+
if (flatExprs[j] == 0)
1068+
continue;
1069+
// For dimensional expressions we set the index as <position number of the
1070+
// dimension, 0>, as we want dimensional expressions to appear before
1071+
// symbolic ones and products of dimensional and symbolic expressions
1072+
// having the dimension with the same position number.
1073+
std::pair<unsigned, signed> indexEntry(j, offsetDim--);
1074+
addEntry(indexEntry, flatExprs[j], getAffineDimExpr(j, context));
1075+
}
1076+
10751077
// Constructing the simplified semi-affine sum of product/division/mod
10761078
// expression from the flattened form in the desired sorted order of indices
10771079
// of the various individual product/division/mod expressions.

mlir/test/Dialect/Affine/simplify-structures.mlir

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -557,3 +557,13 @@ func.func @semiaffine_modulo(%arg0: index) -> index {
557557
// CHECK: affine.apply #[[$MAP]]()[%{{.*}}]
558558
return %a : index
559559
}
560+
561+
// -----
562+
563+
// CHECK-DAG: #[[$MAP:.*]] = affine_map<()[s0, s1, s2] -> (s2 mod 2 + (s1 floordiv 2) * 2 + ((s2 floordiv 2) * s0) * 2)>
564+
// CHECK-LABEL: func @semiaffine_modulo_dim
565+
func.func @semiaffine_modulo_dim(%arg0: index, %arg1: index, %arg2: index) -> index {
566+
%a = affine.apply affine_map<(d0)[s0, s1] -> (((d0 floordiv 2) * s0 + s1 floordiv 2) * 2 + d0 mod 2)> (%arg0)[%arg1, %arg2]
567+
//CHECK: affine.apply #[[$MAP]]()[%{{.*}}, %{{.*}}, %{{.*}}]
568+
return %a : index
569+
}

0 commit comments

Comments
 (0)