Skip to content

Commit e640096

Browse files
committed
[mlir][affine] Define affine.linearize_index
`affine.linearize_index` is the inverse of `affine.delinearize_index` and general useful for representing computations (like those needed to move from N-D to 1-D memrefs) that put together indices. This commit introduces `affine.linearize_index` and one simple canonicalization for it. There are plans to add `affine.linearize_index` and `affine.delinearize_index` pair canonicalizations, but we are saving those for a followup PR (especially since having llvm#113846 landed would make them nicer). Note while `affine` may not be the natural home for this operation, https://discourse.llvm.org/t/better-location-of-affine-delinearize-operation/80565/13 didn't come to any better consensus location.
1 parent e241964 commit e640096

File tree

12 files changed

+335
-5
lines changed

12 files changed

+335
-5
lines changed

mlir/include/mlir/Dialect/Affine/IR/AffineOps.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
#include "mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.h"
1818
#include "mlir/Dialect/Arith/IR/Arith.h"
19+
#include "mlir/Dialect/Utils/StaticValueUtils.h"
1920
#include "mlir/IR/AffineMap.h"
2021
#include "mlir/IR/Builders.h"
2122
#include "mlir/Interfaces/ControlFlowInterfaces.h"

mlir/include/mlir/Dialect/Affine/IR/AffineOps.td

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1099,4 +1099,73 @@ def AffineDelinearizeIndexOp : Affine_Op<"delinearize_index",
10991099
let hasCanonicalizer = 1;
11001100
}
11011101

1102+
//===----------------------------------------------------------------------===//
1103+
// AffineLinearizeIndexOp
1104+
//===----------------------------------------------------------------------===//
1105+
def AffineLinearizeIndexOp : Affine_Op<"linearize_index",
1106+
[Pure, AttrSizedOperandSegments]> {
1107+
let summary = "linearize an index";
1108+
let description = [{
1109+
The `affine.linearize_index` operation takes a sequence of index values and a
1110+
basis of the same length and linearizes the indices using that basis.
1111+
1112+
That is, for indices %idx_1 through %i_N and basis elements b_1 through b_N,
1113+
it computes
1114+
1115+
```
1116+
sum(i = 1 to N) %idx_i * product(j = i + 1 to N) B_j
1117+
```
1118+
1119+
If the `disjoint` property is present, this is an optimization hint that,
1120+
for all i, 0 <= %idx_i < B_i - that is, no index affects any other index,
1121+
except that %idx_0 may be negative to make the index as a whole negative.
1122+
1123+
Note that the outputs of `affine.delinearize_index` are, by definition, `disjoint`.
1124+
1125+
Example:
1126+
1127+
```
1128+
%linear_index = affine.delinearize_index [%index_0, %index_1, %index_2] (16, 224, 224) : index
1129+
```
1130+
1131+
In the above example, `%linear_index` conceptually holds the following:
1132+
1133+
```
1134+
#map = affine_map<()[s0, s1, s2] -> (s0 * 50176 + s1 * 224 + s2)>
1135+
%linear_index = affine.apply #map()[%index_0, %index_1, %index_2]
1136+
```
1137+
}];
1138+
1139+
let arguments = (ins Variadic<Index>:$multi_index,
1140+
Variadic<Index>:$dynamic_basis,
1141+
DenseI64ArrayAttr:$static_basis,
1142+
UnitProperty:$disjoint);
1143+
let results = (outs Index:$linear_index);
1144+
1145+
let assemblyFormat = [{
1146+
(`disjoint` $disjoint^)? ` `
1147+
`[` $multi_index `]` `by` ` `
1148+
custom<DynamicIndexList>($dynamic_basis, $static_basis, "::mlir::AsmParser::Delimiter::Paren")
1149+
attr-dict `:` type($linear_index)
1150+
}];
1151+
1152+
let builders = [
1153+
OpBuilder<(ins "ValueRange":$multi_index, "ValueRange":$basis, CArg<"bool", "false">:$disjoint)>,
1154+
OpBuilder<(ins "ValueRange":$multi_index, "ArrayRef<OpFoldResult>":$basis, CArg<"bool", "false">:$disjoint)>,
1155+
OpBuilder<(ins "ValueRange":$multi_index, "ArrayRef<int64_t>":$basis, CArg<"bool", "false">:$disjoint)>
1156+
];
1157+
1158+
let extraClassDeclaration = [{
1159+
/// Return a vector with all the static and dynamic basis values.
1160+
SmallVector<OpFoldResult> getMixedBasis() {
1161+
OpBuilder builder(getContext());
1162+
return ::mlir::getMixedValues(getStaticBasis(), getDynamicBasis(), builder);
1163+
}
1164+
1165+
}];
1166+
1167+
let hasVerifier = 1;
1168+
let hasCanonicalizer = 1;
1169+
}
1170+
11021171
#endif // AFFINE_OPS

mlir/include/mlir/Dialect/Affine/Utils.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -316,6 +316,9 @@ FailureOr<SmallVector<Value>> delinearizeIndex(OpBuilder &b, Location loc,
316316
OpFoldResult linearizeIndex(ArrayRef<OpFoldResult> multiIndex,
317317
ArrayRef<OpFoldResult> basis,
318318
ImplicitLocOpBuilder &builder);
319+
OpFoldResult linearizeIndex(OpBuilder &builder, Location loc,
320+
ArrayRef<OpFoldResult> multiIndex,
321+
ArrayRef<OpFoldResult> basis);
319322

320323
/// Ensure that all operations that could be executed after `start`
321324
/// (noninclusive) and prior to `memOp` (e.g. on a control flow/op path

mlir/include/mlir/Interfaces/ViewLikeInterface.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,13 @@ void printDynamicIndexList(
109109
ArrayRef<int64_t> integers, ArrayRef<bool> scalables,
110110
TypeRange valueTypes = TypeRange(),
111111
AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square);
112+
inline void printDynamicIndexList(OpAsmPrinter &printer, Operation *op,
113+
OperandRange values,
114+
ArrayRef<int64_t> integers,
115+
AsmParser::Delimiter delimiter) {
116+
return printDynamicIndexList(printer, op, values, integers, {}, TypeRange(),
117+
delimiter);
118+
}
112119
inline void printDynamicIndexList(
113120
OpAsmPrinter &printer, Operation *op, OperandRange values,
114121
ArrayRef<int64_t> integers, TypeRange valueTypes = TypeRange(),
@@ -144,6 +151,15 @@ ParseResult parseDynamicIndexList(
144151
DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalableVals,
145152
SmallVectorImpl<Type> *valueTypes = nullptr,
146153
AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square);
154+
inline ParseResult
155+
parseDynamicIndexList(OpAsmParser &parser,
156+
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
157+
DenseI64ArrayAttr &integers,
158+
AsmParser::Delimiter delimiter) {
159+
DenseBoolArrayAttr scalableVals = {};
160+
return parseDynamicIndexList(parser, values, integers, scalableVals, nullptr,
161+
delimiter);
162+
}
147163
inline ParseResult parseDynamicIndexList(
148164
OpAsmParser &parser,
149165
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,

mlir/lib/Dialect/Affine/IR/AffineOps.cpp

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4664,6 +4664,112 @@ void affine::AffineDelinearizeIndexOp::getCanonicalizationPatterns(
46644664
patterns.insert<DropDelinearizeOfSingleLoop, DropUnitExtentBasis>(context);
46654665
}
46664666

4667+
//===----------------------------------------------------------------------===//
4668+
// LinearizeIndexOp
4669+
//===----------------------------------------------------------------------===//
4670+
4671+
void AffineLinearizeIndexOp::build(OpBuilder &odsBuilder,
4672+
OperationState &odsState,
4673+
ValueRange multiIndex, ValueRange basis,
4674+
bool disjoint) {
4675+
SmallVector<Value> dynamicBasis;
4676+
SmallVector<int64_t> staticBasis;
4677+
dispatchIndexOpFoldResults(getAsOpFoldResult(basis), dynamicBasis,
4678+
staticBasis);
4679+
build(odsBuilder, odsState, multiIndex, dynamicBasis, staticBasis, disjoint);
4680+
}
4681+
4682+
void AffineLinearizeIndexOp::build(OpBuilder &odsBuilder,
4683+
OperationState &odsState,
4684+
ValueRange multiIndex,
4685+
ArrayRef<OpFoldResult> basis,
4686+
bool disjoint) {
4687+
SmallVector<Value> dynamicBasis;
4688+
SmallVector<int64_t> staticBasis;
4689+
dispatchIndexOpFoldResults(basis, dynamicBasis, staticBasis);
4690+
build(odsBuilder, odsState, multiIndex, dynamicBasis, staticBasis, disjoint);
4691+
}
4692+
4693+
void AffineLinearizeIndexOp::build(OpBuilder &odsBuilder,
4694+
OperationState &odsState,
4695+
ValueRange multiIndex,
4696+
ArrayRef<int64_t> basis, bool disjoint) {
4697+
build(odsBuilder, odsState, multiIndex, ValueRange{}, basis, disjoint);
4698+
}
4699+
4700+
LogicalResult AffineLinearizeIndexOp::verify() {
4701+
if (getStaticBasis().empty())
4702+
return emitOpError("basis should not be empty");
4703+
if (getMultiIndex().size() != getStaticBasis().size())
4704+
return emitOpError("should be passed an index for each basis element");
4705+
auto dynamicMarkersCount =
4706+
llvm::count_if(getStaticBasis(), ShapedType::isDynamic);
4707+
if (static_cast<size_t>(dynamicMarkersCount) != getDynamicBasis().size())
4708+
return emitOpError(
4709+
"mismatch between dynamic and static basis (kDynamic marker but no "
4710+
"corresponding dynamic basis entry) -- this can only happen due to an "
4711+
"incorrect fold/rewrite");
4712+
return success();
4713+
}
4714+
4715+
namespace {
4716+
/// Rewrite `affine.linearize_index disjoint [%...a, %x, %...b] by (%...c, 1,
4717+
/// %...d)` to `affine.linearize_index disjoint [%...a, %...b] by (%...c,
4718+
/// %...d)`.
4719+
4720+
/// Note that `disjoint` is required here, because, without it, we could have
4721+
/// `affine.linearize_index [%...a, %c64, %...b] by (%...c, 1, %...d)`
4722+
/// is a valid operation where the `%c64` cannot be trivially dropped.
4723+
///
4724+
/// Alternatively, if `%x` in the above is a known constant 0, remove it even if
4725+
/// the operation isn't asserted to be `disjoint`.
4726+
struct DropLinearizeUnitComponentsIfDisjointOrZero
4727+
: public OpRewritePattern<affine::AffineLinearizeIndexOp> {
4728+
using OpRewritePattern::OpRewritePattern;
4729+
4730+
LogicalResult matchAndRewrite(affine::AffineLinearizeIndexOp op,
4731+
PatternRewriter &rewriter) const override {
4732+
size_t numIndices = op.getMultiIndex().size();
4733+
SmallVector<Value> newIndices;
4734+
newIndices.reserve(numIndices);
4735+
SmallVector<OpFoldResult> newBasis;
4736+
newBasis.reserve(numIndices);
4737+
4738+
SmallVector<OpFoldResult> basis = op.getMixedBasis();
4739+
for (auto [index, basisElem] : llvm::zip_equal(op.getMultiIndex(), basis)) {
4740+
std::optional<int64_t> basisEntry = getConstantIntValue(basisElem);
4741+
if (!basisEntry || *basisEntry != 1) {
4742+
newIndices.push_back(index);
4743+
newBasis.push_back(basisElem);
4744+
continue;
4745+
}
4746+
4747+
std::optional<int64_t> indexValue = getConstantIntValue(index);
4748+
if (!op.getDisjoint() && (!indexValue || *indexValue != 0)) {
4749+
newIndices.push_back(index);
4750+
newBasis.push_back(basisElem);
4751+
continue;
4752+
}
4753+
}
4754+
if (newIndices.size() == numIndices)
4755+
return failure();
4756+
4757+
if (newIndices.size() == 0) {
4758+
rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op, 0);
4759+
return success();
4760+
}
4761+
rewriter.replaceOpWithNewOp<affine::AffineLinearizeIndexOp>(
4762+
op, newIndices, newBasis, op.getDisjoint());
4763+
return success();
4764+
}
4765+
};
4766+
} // namespace
4767+
4768+
void affine::AffineLinearizeIndexOp::getCanonicalizationPatterns(
4769+
RewritePatternSet &patterns, MLIRContext *context) {
4770+
patterns.add<DropLinearizeUnitComponentsIfDisjointOrZero>(context);
4771+
}
4772+
46674773
//===----------------------------------------------------------------------===//
46684774
// TableGen'd op method definitions
46694775
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include "mlir/Dialect/Affine/IR/AffineOps.h"
1616
#include "mlir/Dialect/Affine/Transforms/Transforms.h"
1717
#include "mlir/Dialect/Affine/Utils.h"
18+
#include "mlir/Dialect/Arith/Utils/Utils.h"
1819
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
1920

2021
namespace mlir {
@@ -45,6 +46,24 @@ struct LowerDelinearizeIndexOps
4546
}
4647
};
4748

49+
/// Lowers `affine.linearize_index` into a sequence of multiplications and
50+
/// additions.
51+
struct LowerLinearizeIndexOps
52+
: public OpRewritePattern<AffineLinearizeIndexOp> {
53+
using OpRewritePattern<AffineLinearizeIndexOp>::OpRewritePattern;
54+
LogicalResult matchAndRewrite(AffineLinearizeIndexOp op,
55+
PatternRewriter &rewriter) const override {
56+
SmallVector<OpFoldResult> multiIndex =
57+
getAsOpFoldResult(op.getMultiIndex());
58+
OpFoldResult linearIndex =
59+
linearizeIndex(rewriter, op.getLoc(), multiIndex, op.getMixedBasis());
60+
Value linearIndexValue =
61+
getValueOrCreateConstantIntOp(rewriter, op.getLoc(), linearIndex);
62+
rewriter.replaceOp(op, linearIndexValue);
63+
return success();
64+
}
65+
};
66+
4867
class ExpandAffineIndexOpsPass
4968
: public affine::impl::AffineExpandIndexOpsBase<ExpandAffineIndexOpsPass> {
5069
public:
@@ -64,7 +83,8 @@ class ExpandAffineIndexOpsPass
6483

6584
void mlir::affine::populateAffineExpandIndexOpsPatterns(
6685
RewritePatternSet &patterns) {
67-
patterns.insert<LowerDelinearizeIndexOps>(patterns.getContext());
86+
patterns.insert<LowerDelinearizeIndexOps, LowerLinearizeIndexOps>(
87+
patterns.getContext());
6888
}
6989

7090
std::unique_ptr<Pass> mlir::affine::createAffineExpandIndexOpsPass() {

mlir/lib/Dialect/Affine/Utils/Utils.cpp

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1973,6 +1973,12 @@ mlir::affine::delinearizeIndex(OpBuilder &b, Location loc, Value linearIndex,
19731973
OpFoldResult mlir::affine::linearizeIndex(ArrayRef<OpFoldResult> multiIndex,
19741974
ArrayRef<OpFoldResult> basis,
19751975
ImplicitLocOpBuilder &builder) {
1976+
return linearizeIndex(builder, builder.getLoc(), multiIndex, basis);
1977+
}
1978+
1979+
OpFoldResult mlir::affine::linearizeIndex(OpBuilder &builder, Location loc,
1980+
ArrayRef<OpFoldResult> multiIndex,
1981+
ArrayRef<OpFoldResult> basis) {
19761982
assert(multiIndex.size() == basis.size());
19771983
SmallVector<AffineExpr> basisAffine;
19781984
for (size_t i = 0; i < basis.size(); ++i) {
@@ -1983,13 +1989,13 @@ OpFoldResult mlir::affine::linearizeIndex(ArrayRef<OpFoldResult> multiIndex,
19831989
SmallVector<OpFoldResult> strides;
19841990
strides.reserve(stridesAffine.size());
19851991
llvm::transform(stridesAffine, std::back_inserter(strides),
1986-
[&builder, &basis](AffineExpr strideExpr) {
1992+
[&builder, &basis, loc](AffineExpr strideExpr) {
19871993
return affine::makeComposedFoldedAffineApply(
1988-
builder, builder.getLoc(), strideExpr, basis);
1994+
builder, loc, strideExpr, basis);
19891995
});
19901996

19911997
auto &&[linearIndexExpr, multiIndexAndStrides] = computeLinearIndex(
19921998
OpFoldResult(builder.getIndexAttr(0)), strides, multiIndex);
1993-
return affine::makeComposedFoldedAffineApply(
1994-
builder, builder.getLoc(), linearIndexExpr, multiIndexAndStrides);
1999+
return affine::makeComposedFoldedAffineApply(builder, loc, linearIndexExpr,
2000+
multiIndexAndStrides);
19952001
}

mlir/test/Conversion/AffineToStandard/lower-affine.mlir

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -981,3 +981,20 @@ func.func @test_dilinearize_index(%linear_index: index) -> (index, index, index)
981981
// CHECK: %[[VAL_40:.*]] = arith.select %[[VAL_38]], %[[VAL_39]], %[[VAL_36]] : index
982982
// CHECK: return %[[VAL_13]], %[[VAL_34]], %[[VAL_40]] : index, index, index
983983
// CHECK: }
984+
985+
/////////////////////////////////////////////////////////////////////
986+
987+
func.func @test_linearize_index(%arg0: index, %arg1: index, %arg2: index) -> index {
988+
%ret = affine.linearize_index disjoint [%arg0, %arg1, %arg2] by (2, 3, 5) : index
989+
return %ret : index
990+
}
991+
992+
// CHECK-LABEL: @test_linearize_index
993+
// CHECK-SAME: (%[[arg0:.+]]: index, %[[arg1:.+]]: index, %[[arg2:.+]]: index)
994+
// CHECK: %[[c15:.+]] = arith.constant 15 : index
995+
// CHECK-NEXT: %[[tmp0:.+]] = arith.muli %[[arg0]], %[[c15]] : index
996+
// CHECK-NEXT: %[[c5:.+]] = arith.constant 5 : index
997+
// CHECK-NEXT: %[[tmp1:.+]] = arith.muli %[[arg1]], %[[c5]] : index
998+
// CHECK-NEXT: %[[tmp2:.+]] = arith.addi %[[tmp0]], %[[tmp1]] : index
999+
// CHECK-NEXT: %[[ret:.+]] = arith.addi %[[tmp2]], %[[arg2]] : index
1000+
// CHECK-NEXT: return %[[ret]]

mlir/test/Dialect/Affine/affine-expand-index-ops.mlir

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,3 +44,29 @@ func.func @dynamic_basis(%linear_index: index, %src: memref<?x?x?xf32>) -> (inde
4444
%1:3 = affine.delinearize_index %linear_index into (%b0, %b1, %b2) : index, index, index
4545
return %1#0, %1#1, %1#2 : index, index, index
4646
}
47+
48+
// -----
49+
50+
// CHECK-DAG: #[[$map0:.+]] = affine_map<()[s0, s1, s2] -> (s0 * 15 + s1 * 5 + s2)>
51+
52+
// CHECK-LABEL: @linearize_static
53+
// CHECK-SAME: (%[[arg0:.+]]: index, %[[arg1:.+]]: index, %[[arg2:.+]]: index)
54+
// CHECK: %[[val_0:.+]] = affine.apply #[[$map0]]()[%[[arg0]], %[[arg1]], %[[arg2]]]
55+
// CHECK: return %[[val_0]]
56+
func.func @linearize_static(%arg0: index, %arg1: index, %arg2: index) -> index {
57+
%0 = affine.linearize_index [%arg0, %arg1, %arg2] by (2, 3, 5) : index
58+
func.return %0 : index
59+
}
60+
61+
// -----
62+
63+
// CHECK-DAG: #[[$map0:.+]] = affine_map<()[s0, s1, s2, s3, s4] -> (s1 * s2 + s3 + s0 * (s2 * s4))>
64+
65+
// CHECK-LABEL: @linearize_dynamic
66+
// CHECK-SAME: (%[[arg0:.+]]: index, %[[arg1:.+]]: index, %[[arg2:.+]]: index, %[[arg3:.+]]: index, %[[arg4:.+]]: index, %[[arg5:.+]]: index)
67+
// CHECK: %[[val_0:.+]] = affine.apply #[[$map0]]()[%[[arg0]], %[[arg1]], %[[arg5]], %[[arg2]], %[[arg4]]]
68+
// CHECK: return %[[val_0]]
69+
func.func @linearize_dynamic(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index, %arg5: index) -> index {
70+
%0 = affine.linearize_index [%arg0, %arg1, %arg2] by (%arg3, %arg4, %arg5) : index
71+
func.return %0 : index
72+
}

mlir/test/Dialect/Affine/canonicalize.mlir

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1533,3 +1533,37 @@ func.func @delinearize_non_loop_like(%arg0: memref<?xi32>, %i : index) -> index
15331533
%2 = affine.delinearize_index %i into (%c1024) : index
15341534
return %2 : index
15351535
}
1536+
1537+
// -----
1538+
1539+
// CHECK-LABEL: @linearize_unit_basis_disjoint
1540+
// CHECK-SAME: (%[[arg0:.+]]: index, %[[arg1:.+]]: index, %[[arg2:.+]]: index, %[[arg3:.+]]: index)
1541+
// CHECK: %[[ret:.+]] = affine.linearize_index disjoint [%[[arg0]], %[[arg2]]] by (3, %[[arg3]]) : index
1542+
// CHECK: return %[[ret]]
1543+
func.func @linearize_unit_basis_disjoint(%arg0: index, %arg1: index, %arg2: index, %arg3: index) -> index {
1544+
%ret = affine.linearize_index disjoint [%arg0, %arg1, %arg2] by (3, 1, %arg3) : index
1545+
return %ret : index
1546+
}
1547+
1548+
// -----
1549+
1550+
// CHECK-LABEL: @linearize_unit_basis_zero
1551+
// CHECK-SAME: (%[[arg0:.+]]: index, %[[arg1:.+]]: index, %[[arg2:.+]]: index)
1552+
// CHECK: %[[ret:.+]] = affine.linearize_index [%[[arg0]], %[[arg1]]] by (3, %[[arg2]]) : index
1553+
// CHECK: return %[[ret]]
1554+
func.func @linearize_unit_basis_zero(%arg0: index, %arg1: index, %arg2: index) -> index {
1555+
%c0 = arith.constant 0 : index
1556+
%ret = affine.linearize_index [%arg0, %c0, %arg1] by (3, 1, %arg2) : index
1557+
return %ret : index
1558+
}
1559+
1560+
// -----
1561+
1562+
// CHECK-LABEL: @linearize_all_zero_unit_basis
1563+
// CHECK: arith.constant 0 : index
1564+
// CHECK-NOT: affine.linearize_index
1565+
func.func @linearize_all_zero_unit_basis() -> index {
1566+
%c0 = arith.constant 0 : index
1567+
%ret = affine.linearize_index [%c0, %c0] by (1, 1) : index
1568+
return %ret : index
1569+
}

0 commit comments

Comments
 (0)