Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions mlir/include/mlir/Dialect/Affine/IR/AffineOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

#include "mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
Expand Down
69 changes: 69 additions & 0 deletions mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1099,4 +1099,73 @@ def AffineDelinearizeIndexOp : Affine_Op<"delinearize_index",
let hasCanonicalizer = 1;
}

//===----------------------------------------------------------------------===//
// AffineLinearizeIndexOp
//===----------------------------------------------------------------------===//
def AffineLinearizeIndexOp : Affine_Op<"linearize_index",
[Pure, AttrSizedOperandSegments]> {
let summary = "linearize an index";
let description = [{
The `affine.linearize_index` operation takes a sequence of index values and a
basis of the same length and linearizes the indices using that basis.

That is, for indices `%idx_1` through `%idx_N` and basis elements `b_1` through `b_N`,
it computes

```
sum(i = 1 to N) %idx_i * product(j = i + 1 to N) B_j
```

If the `disjoint` property is present, this is an optimization hint that,
for all `i`, `0 <= %idx_i < B_i` - that is, no index affects any other index,
except that `%idx_0` may be negative to make the index as a whole negative.

Note that the outputs of `affine.delinearize_index` are, by definition, `disjoint`.

Example:

```mlir
%linear_index = affine.linearize_index [%index_0, %index_1, %index_2] (2, 3, 5) : index
```

In the above example, `%linear_index` conceptually holds the following:

```mlir
#map = affine_map<()[s0, s1, s2] -> (s0 * 15 + s1 * 5 + s2)>
%linear_index = affine.apply #map()[%index_0, %index_1, %index_2]
```
}];

let arguments = (ins Variadic<Index>:$multi_index,
Variadic<Index>:$dynamic_basis,
DenseI64ArrayAttr:$static_basis,
UnitProperty:$disjoint);
let results = (outs Index:$linear_index);

let assemblyFormat = [{
(`disjoint` $disjoint^)? ` `
`[` $multi_index `]` `by` ` `
custom<DynamicIndexList>($dynamic_basis, $static_basis, "::mlir::AsmParser::Delimiter::Paren")
attr-dict `:` type($linear_index)
}];

let builders = [
OpBuilder<(ins "ValueRange":$multi_index, "ValueRange":$basis, CArg<"bool", "false">:$disjoint)>,
OpBuilder<(ins "ValueRange":$multi_index, "ArrayRef<OpFoldResult>":$basis, CArg<"bool", "false">:$disjoint)>,
OpBuilder<(ins "ValueRange":$multi_index, "ArrayRef<int64_t>":$basis, CArg<"bool", "false">:$disjoint)>
];

let extraClassDeclaration = [{
/// Return a vector with all the static and dynamic basis values.
SmallVector<OpFoldResult> getMixedBasis() {
OpBuilder builder(getContext());
return ::mlir::getMixedValues(getStaticBasis(), getDynamicBasis(), builder);
}

}];

let hasVerifier = 1;
let hasCanonicalizer = 1;
}

#endif // AFFINE_OPS
3 changes: 3 additions & 0 deletions mlir/include/mlir/Dialect/Affine/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,9 @@ FailureOr<SmallVector<Value>> delinearizeIndex(OpBuilder &b, Location loc,
OpFoldResult linearizeIndex(ArrayRef<OpFoldResult> multiIndex,
ArrayRef<OpFoldResult> basis,
ImplicitLocOpBuilder &builder);
OpFoldResult linearizeIndex(OpBuilder &builder, Location loc,
ArrayRef<OpFoldResult> multiIndex,
ArrayRef<OpFoldResult> basis);

/// Ensure that all operations that could be executed after `start`
/// (noninclusive) and prior to `memOp` (e.g. on a control flow/op path
Expand Down
16 changes: 16 additions & 0 deletions mlir/include/mlir/Interfaces/ViewLikeInterface.h
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,13 @@ void printDynamicIndexList(
ArrayRef<int64_t> integers, ArrayRef<bool> scalables,
TypeRange valueTypes = TypeRange(),
AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square);
inline void printDynamicIndexList(OpAsmPrinter &printer, Operation *op,
OperandRange values,
ArrayRef<int64_t> integers,
AsmParser::Delimiter delimiter) {
return printDynamicIndexList(printer, op, values, integers, {}, TypeRange(),
delimiter);
}
inline void printDynamicIndexList(
OpAsmPrinter &printer, Operation *op, OperandRange values,
ArrayRef<int64_t> integers, TypeRange valueTypes = TypeRange(),
Expand Down Expand Up @@ -144,6 +151,15 @@ ParseResult parseDynamicIndexList(
DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalableVals,
SmallVectorImpl<Type> *valueTypes = nullptr,
AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square);
inline ParseResult
parseDynamicIndexList(OpAsmParser &parser,
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
DenseI64ArrayAttr &integers,
AsmParser::Delimiter delimiter) {
DenseBoolArrayAttr scalableVals = {};
return parseDynamicIndexList(parser, values, integers, scalableVals, nullptr,
delimiter);
}
inline ParseResult parseDynamicIndexList(
OpAsmParser &parser,
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
Expand Down
109 changes: 109 additions & 0 deletions mlir/lib/Dialect/Affine/IR/AffineOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4664,6 +4664,115 @@ void affine::AffineDelinearizeIndexOp::getCanonicalizationPatterns(
patterns.insert<DropDelinearizeOfSingleLoop, DropUnitExtentBasis>(context);
}

//===----------------------------------------------------------------------===//
// LinearizeIndexOp
//===----------------------------------------------------------------------===//

void AffineLinearizeIndexOp::build(OpBuilder &odsBuilder,
OperationState &odsState,
ValueRange multiIndex, ValueRange basis,
bool disjoint) {
SmallVector<Value> dynamicBasis;
SmallVector<int64_t> staticBasis;
dispatchIndexOpFoldResults(getAsOpFoldResult(basis), dynamicBasis,
staticBasis);
build(odsBuilder, odsState, multiIndex, dynamicBasis, staticBasis, disjoint);
}

void AffineLinearizeIndexOp::build(OpBuilder &odsBuilder,
OperationState &odsState,
ValueRange multiIndex,
ArrayRef<OpFoldResult> basis,
bool disjoint) {
SmallVector<Value> dynamicBasis;
SmallVector<int64_t> staticBasis;
dispatchIndexOpFoldResults(basis, dynamicBasis, staticBasis);
build(odsBuilder, odsState, multiIndex, dynamicBasis, staticBasis, disjoint);
}

void AffineLinearizeIndexOp::build(OpBuilder &odsBuilder,
OperationState &odsState,
ValueRange multiIndex,
ArrayRef<int64_t> basis, bool disjoint) {
build(odsBuilder, odsState, multiIndex, ValueRange{}, basis, disjoint);
}

LogicalResult AffineLinearizeIndexOp::verify() {
if (getStaticBasis().empty())
return emitOpError("basis should not be empty");

if (getMultiIndex().size() != getStaticBasis().size())
return emitOpError("should be passed an index for each basis element");

auto dynamicMarkersCount =
llvm::count_if(getStaticBasis(), ShapedType::isDynamic);
if (static_cast<size_t>(dynamicMarkersCount) != getDynamicBasis().size())
return emitOpError(
"mismatch between dynamic and static basis (kDynamic marker but no "
"corresponding dynamic basis entry) -- this can only happen due to an "
"incorrect fold/rewrite");

return success();
}

namespace {
/// Rewrite `affine.linearize_index disjoint [%...a, %x, %...b] by (%...c, 1,
/// %...d)` to `affine.linearize_index disjoint [%...a, %...b] by (%...c,
/// %...d)`.

/// Note that `disjoint` is required here, because, without it, we could have
/// `affine.linearize_index [%...a, %c64, %...b] by (%...c, 1, %...d)`
/// is a valid operation where the `%c64` cannot be trivially dropped.
///
/// Alternatively, if `%x` in the above is a known constant 0, remove it even if
/// the operation isn't asserted to be `disjoint`.
struct DropLinearizeUnitComponentsIfDisjointOrZero final
: OpRewritePattern<affine::AffineLinearizeIndexOp> {
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(affine::AffineLinearizeIndexOp op,
PatternRewriter &rewriter) const override {
size_t numIndices = op.getMultiIndex().size();
SmallVector<Value> newIndices;
newIndices.reserve(numIndices);
SmallVector<OpFoldResult> newBasis;
newBasis.reserve(numIndices);

SmallVector<OpFoldResult> basis = op.getMixedBasis();
for (auto [index, basisElem] : llvm::zip_equal(op.getMultiIndex(), basis)) {
std::optional<int64_t> basisEntry = getConstantIntValue(basisElem);
if (!basisEntry || *basisEntry != 1) {
newIndices.push_back(index);
newBasis.push_back(basisElem);
continue;
}

std::optional<int64_t> indexValue = getConstantIntValue(index);
if (!op.getDisjoint() && (!indexValue || *indexValue != 0)) {
newIndices.push_back(index);
newBasis.push_back(basisElem);
continue;
}
}
if (newIndices.size() == numIndices)
return failure();

if (newIndices.size() == 0) {
rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op, 0);
return success();
}
rewriter.replaceOpWithNewOp<affine::AffineLinearizeIndexOp>(
op, newIndices, newBasis, op.getDisjoint());
return success();
}
};
} // namespace

void affine::AffineLinearizeIndexOp::getCanonicalizationPatterns(
RewritePatternSet &patterns, MLIRContext *context) {
patterns.add<DropLinearizeUnitComponentsIfDisjointOrZero>(context);
}

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
Expand Down
21 changes: 20 additions & 1 deletion mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/Transforms/Transforms.h"
#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir {
Expand Down Expand Up @@ -45,6 +46,23 @@ struct LowerDelinearizeIndexOps
}
};

/// Lowers `affine.linearize_index` into a sequence of multiplications and
/// additions.
struct LowerLinearizeIndexOps final : OpRewritePattern<AffineLinearizeIndexOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AffineLinearizeIndexOp op,
PatternRewriter &rewriter) const override {
SmallVector<OpFoldResult> multiIndex =
getAsOpFoldResult(op.getMultiIndex());
OpFoldResult linearIndex =
linearizeIndex(rewriter, op.getLoc(), multiIndex, op.getMixedBasis());
Value linearIndexValue =
getValueOrCreateConstantIntOp(rewriter, op.getLoc(), linearIndex);
rewriter.replaceOp(op, linearIndexValue);
return success();
}
};

class ExpandAffineIndexOpsPass
: public affine::impl::AffineExpandIndexOpsBase<ExpandAffineIndexOpsPass> {
public:
Expand All @@ -64,7 +82,8 @@ class ExpandAffineIndexOpsPass

void mlir::affine::populateAffineExpandIndexOpsPatterns(
RewritePatternSet &patterns) {
patterns.insert<LowerDelinearizeIndexOps>(patterns.getContext());
patterns.insert<LowerDelinearizeIndexOps, LowerLinearizeIndexOps>(
patterns.getContext());
}

std::unique_ptr<Pass> mlir::affine::createAffineExpandIndexOpsPass() {
Expand Down
14 changes: 10 additions & 4 deletions mlir/lib/Dialect/Affine/Utils/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1973,6 +1973,12 @@ mlir::affine::delinearizeIndex(OpBuilder &b, Location loc, Value linearIndex,
OpFoldResult mlir::affine::linearizeIndex(ArrayRef<OpFoldResult> multiIndex,
ArrayRef<OpFoldResult> basis,
ImplicitLocOpBuilder &builder) {
return linearizeIndex(builder, builder.getLoc(), multiIndex, basis);
}

OpFoldResult mlir::affine::linearizeIndex(OpBuilder &builder, Location loc,
ArrayRef<OpFoldResult> multiIndex,
ArrayRef<OpFoldResult> basis) {
assert(multiIndex.size() == basis.size());
SmallVector<AffineExpr> basisAffine;
for (size_t i = 0; i < basis.size(); ++i) {
Expand All @@ -1983,13 +1989,13 @@ OpFoldResult mlir::affine::linearizeIndex(ArrayRef<OpFoldResult> multiIndex,
SmallVector<OpFoldResult> strides;
strides.reserve(stridesAffine.size());
llvm::transform(stridesAffine, std::back_inserter(strides),
[&builder, &basis](AffineExpr strideExpr) {
[&builder, &basis, loc](AffineExpr strideExpr) {
return affine::makeComposedFoldedAffineApply(
builder, builder.getLoc(), strideExpr, basis);
builder, loc, strideExpr, basis);
});

auto &&[linearIndexExpr, multiIndexAndStrides] = computeLinearIndex(
OpFoldResult(builder.getIndexAttr(0)), strides, multiIndex);
return affine::makeComposedFoldedAffineApply(
builder, builder.getLoc(), linearIndexExpr, multiIndexAndStrides);
return affine::makeComposedFoldedAffineApply(builder, loc, linearIndexExpr,
multiIndexAndStrides);
}
17 changes: 17 additions & 0 deletions mlir/test/Conversion/AffineToStandard/lower-affine.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -981,3 +981,20 @@ func.func @test_dilinearize_index(%linear_index: index) -> (index, index, index)
// CHECK: %[[VAL_40:.*]] = arith.select %[[VAL_38]], %[[VAL_39]], %[[VAL_36]] : index
// CHECK: return %[[VAL_13]], %[[VAL_34]], %[[VAL_40]] : index, index, index
// CHECK: }

/////////////////////////////////////////////////////////////////////

func.func @test_linearize_index(%arg0: index, %arg1: index, %arg2: index) -> index {
%ret = affine.linearize_index disjoint [%arg0, %arg1, %arg2] by (2, 3, 5) : index
return %ret : index
}

// CHECK-LABEL: @test_linearize_index
// CHECK-SAME: (%[[arg0:.+]]: index, %[[arg1:.+]]: index, %[[arg2:.+]]: index)
// CHECK: %[[c15:.+]] = arith.constant 15 : index
// CHECK-NEXT: %[[tmp0:.+]] = arith.muli %[[arg0]], %[[c15]] : index
// CHECK-NEXT: %[[c5:.+]] = arith.constant 5 : index
// CHECK-NEXT: %[[tmp1:.+]] = arith.muli %[[arg1]], %[[c5]] : index
// CHECK-NEXT: %[[tmp2:.+]] = arith.addi %[[tmp0]], %[[tmp1]] : index
// CHECK-NEXT: %[[ret:.+]] = arith.addi %[[tmp2]], %[[arg2]] : index
// CHECK-NEXT: return %[[ret]]
26 changes: 26 additions & 0 deletions mlir/test/Dialect/Affine/affine-expand-index-ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,29 @@ func.func @dynamic_basis(%linear_index: index, %src: memref<?x?x?xf32>) -> (inde
%1:3 = affine.delinearize_index %linear_index into (%b0, %b1, %b2) : index, index, index
return %1#0, %1#1, %1#2 : index, index, index
}

// -----

// CHECK-DAG: #[[$map0:.+]] = affine_map<()[s0, s1, s2] -> (s0 * 15 + s1 * 5 + s2)>

// CHECK-LABEL: @linearize_static
// CHECK-SAME: (%[[arg0:.+]]: index, %[[arg1:.+]]: index, %[[arg2:.+]]: index)
// CHECK: %[[val_0:.+]] = affine.apply #[[$map0]]()[%[[arg0]], %[[arg1]], %[[arg2]]]
// CHECK: return %[[val_0]]
func.func @linearize_static(%arg0: index, %arg1: index, %arg2: index) -> index {
%0 = affine.linearize_index [%arg0, %arg1, %arg2] by (2, 3, 5) : index
func.return %0 : index
}

// -----

// CHECK-DAG: #[[$map0:.+]] = affine_map<()[s0, s1, s2, s3, s4] -> (s1 * s2 + s3 + s0 * (s2 * s4))>

// CHECK-LABEL: @linearize_dynamic
// CHECK-SAME: (%[[arg0:.+]]: index, %[[arg1:.+]]: index, %[[arg2:.+]]: index, %[[arg3:.+]]: index, %[[arg4:.+]]: index, %[[arg5:.+]]: index)
// CHECK: %[[val_0:.+]] = affine.apply #[[$map0]]()[%[[arg0]], %[[arg1]], %[[arg5]], %[[arg2]], %[[arg4]]]
// CHECK: return %[[val_0]]
func.func @linearize_dynamic(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index, %arg5: index) -> index {
%0 = affine.linearize_index [%arg0, %arg1, %arg2] by (%arg3, %arg4, %arg5) : index
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am just curious if we even need %arg3. Would it be used if disjoint is false?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't need it - but we also don't need the first basis element on delinearize_index either.

This is kept for symmetry with memref.load and its friends

func.return %0 : index
}
34 changes: 34 additions & 0 deletions mlir/test/Dialect/Affine/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1533,3 +1533,37 @@ func.func @delinearize_non_loop_like(%arg0: memref<?xi32>, %i : index) -> index
%2 = affine.delinearize_index %i into (%c1024) : index
return %2 : index
}

// -----

// CHECK-LABEL: @linearize_unit_basis_disjoint
// CHECK-SAME: (%[[arg0:.+]]: index, %[[arg1:.+]]: index, %[[arg2:.+]]: index, %[[arg3:.+]]: index)
// CHECK: %[[ret:.+]] = affine.linearize_index disjoint [%[[arg0]], %[[arg2]]] by (3, %[[arg3]]) : index
// CHECK: return %[[ret]]
func.func @linearize_unit_basis_disjoint(%arg0: index, %arg1: index, %arg2: index, %arg3: index) -> index {
%ret = affine.linearize_index disjoint [%arg0, %arg1, %arg2] by (3, 1, %arg3) : index
return %ret : index
}

// -----

// CHECK-LABEL: @linearize_unit_basis_zero
// CHECK-SAME: (%[[arg0:.+]]: index, %[[arg1:.+]]: index, %[[arg2:.+]]: index)
// CHECK: %[[ret:.+]] = affine.linearize_index [%[[arg0]], %[[arg1]]] by (3, %[[arg2]]) : index
// CHECK: return %[[ret]]
func.func @linearize_unit_basis_zero(%arg0: index, %arg1: index, %arg2: index) -> index {
%c0 = arith.constant 0 : index
%ret = affine.linearize_index [%arg0, %c0, %arg1] by (3, 1, %arg2) : index
return %ret : index
}

// -----

// CHECK-LABEL: @linearize_all_zero_unit_basis
// CHECK: arith.constant 0 : index
// CHECK-NOT: affine.linearize_index
func.func @linearize_all_zero_unit_basis() -> index {
%c0 = arith.constant 0 : index
%ret = affine.linearize_index [%c0, %c0] by (1, 1) : index
return %ret : index
}
Loading
Loading