Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions mlir/include/mlir/Dialect/Affine/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ namespace mlir {
namespace func {
class FuncOp;
} // namespace func
namespace memref {
class MemRefDialect;
} // namespace memref

namespace affine {
class AffineForOp;
Expand All @@ -48,6 +51,9 @@ createAffineLoopInvariantCodeMotionPass();
/// ops.
std::unique_ptr<OperationPass<func::FuncOp>> createAffineParallelizePass();

/// Creates a pass that converts some memref operators to affine operators.
std::unique_ptr<OperationPass<func::FuncOp>> createRaiseMemrefToAffine();

/// Apply normalization transformations to affine loop-like ops. If
/// `promoteSingleIter` is true, single iteration loops are promoted (i.e., the
/// loop is replaced by its loop body).
Expand Down
12 changes: 12 additions & 0 deletions mlir/include/mlir/Dialect/Affine/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -397,6 +397,18 @@ def LoopCoalescing : Pass<"affine-loop-coalescing", "func::FuncOp"> {
let dependentDialects = ["affine::AffineDialect","arith::ArithDialect"];
}

def RaiseMemrefDialect : Pass<"affine-raise-from-memref", "func::FuncOp"> {
let summary = "Turn some memref operators to affine operators where supported";
let description = [{
Raise memref.load and memref.store to affine.store and affine.load, inferring
the affine map of those operators if needed. This allows passes like --affine-scalrep
to optimize those loads and stores (forwarding them or eliminating them).
They can be turned back to memref dialect ops with --lower-affine.
}];
let constructor = "mlir::affine::createRaiseMemrefToAffine()";
let dependentDialects = ["affine::AffineDialect"];
}

def SimplifyAffineStructures : Pass<"affine-simplify-structures", "func::FuncOp"> {
let summary = "Simplify affine expressions in maps/sets and normalize "
"memrefs";
Expand Down
13 changes: 7 additions & 6 deletions mlir/lib/Dialect/Affine/IR/AffineOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -284,10 +284,12 @@ bool mlir::affine::isValidDim(Value value) {
return isValidDim(value, getAffineScope(defOp));

// This value has to be a block argument for an op that has the
// `AffineScope` trait or for an affine.for or affine.parallel.
// `AffineScope` trait or an induction var of an affine.for or
// affine.parallel.
if (isAffineInductionVar(value))
return true;
auto *parentOp = llvm::cast<BlockArgument>(value).getOwner()->getParentOp();
return parentOp && (parentOp->hasTrait<OpTrait::AffineScope>() ||
isa<AffineForOp, AffineParallelOp>(parentOp));
return parentOp && parentOp->hasTrait<OpTrait::AffineScope>();
}

// Value can be used as a dimension id iff it meets one of the following
Expand All @@ -306,10 +308,9 @@ bool mlir::affine::isValidDim(Value value, Region *region) {

auto *op = value.getDefiningOp();
if (!op) {
// This value has to be a block argument for an affine.for or an
// This value has to be an induction var for an affine.for or an
// affine.parallel.
auto *parentOp = llvm::cast<BlockArgument>(value).getOwner()->getParentOp();
return isa<AffineForOp, AffineParallelOp>(parentOp);
return isAffineInductionVar(value);
}

// Affine apply operation is ok if all of its operands are ok.
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ add_mlir_dialect_library(MLIRAffineTransforms
LoopUnroll.cpp
LoopUnrollAndJam.cpp
PipelineDataTransfer.cpp
RaiseMemrefDialect.cpp
ReifyValueBounds.cpp
SuperVectorize.cpp
SimplifyAffineStructures.cpp
Expand Down
187 changes: 187 additions & 0 deletions mlir/lib/Dialect/Affine/Transforms/RaiseMemrefDialect.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
//===- RaiseMemrefDialect.cpp - raise memref.store and load to affine ops -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements functionality to convert memref load and store ops to
// the corresponding affine ops, inferring the affine map as needed.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/Analysis/Utils.h"
#include "mlir/Dialect/Affine/Passes.h"
#include "mlir/Dialect/Affine/Transforms/Transforms.h"
#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/Operation.h"
#include "mlir/Pass/Pass.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"

namespace mlir {
namespace affine {
#define GEN_PASS_DEF_RAISEMEMREFDIALECT
#include "mlir/Dialect/Affine/Passes.h.inc"
} // namespace affine
} // namespace mlir

#define DEBUG_TYPE "raise-memref-to-affine"

using namespace mlir;
using namespace mlir::affine;

namespace {

/// Find the index of the given value in the `dims` list,
/// and append it if it was not already in the list. The
/// dims list is a list of symbols or dimensions of the
/// affine map. Within the results of an affine map, they
/// are identified by their index, which is why we need
/// this function.
static std::optional<size_t>
findInListOrAdd(Value value, llvm::SmallVectorImpl<Value> &dims,
function_ref<bool(Value)> isValidElement) {

Value *loopIV = std::find(dims.begin(), dims.end(), value);
if (loopIV != dims.end()) {
// We found an IV that already has an index, return that index.
return {std::distance(dims.begin(), loopIV)};
}
if (isValidElement(value)) {
// This is a valid element for the dim/symbol list, push this as a
// parameter.
size_t idx = dims.size();
dims.push_back(value);
return idx;
}
return std::nullopt;
}

/// Convert a value to an affine expr if possible. Adds dims and symbols
/// if needed.
static AffineExpr toAffineExpr(Value value,
llvm::SmallVectorImpl<Value> &affineDims,
llvm::SmallVectorImpl<Value> &affineSymbols) {
using namespace matchers;
IntegerAttr::ValueType cst;
if (matchPattern(value, m_ConstantInt(&cst))) {
return getAffineConstantExpr(cst.getSExtValue(), value.getContext());
}

Operation *definingOp = value.getDefiningOp();
if (llvm::isa_and_nonnull<arith::AddIOp>(definingOp) ||
llvm::isa_and_nonnull<arith::MulIOp>(definingOp)) {
// TODO: replace recursion with explicit stack.
// For the moment this can be tolerated as we only recurse on
// arith.addi and arith.muli, so there cannot be any infinite
// recursion. The depth of these expressions should be in most
// cases very manageable, as affine expressions should be as
// simple as `a + b * c`.
AffineExpr lhsE =
toAffineExpr(definingOp->getOperand(0), affineDims, affineSymbols);
AffineExpr rhsE =
toAffineExpr(definingOp->getOperand(1), affineDims, affineSymbols);

if (lhsE && rhsE) {
AffineExprKind kind;
if (isa<arith::AddIOp>(definingOp)) {
kind = mlir::AffineExprKind::Add;
} else {
kind = mlir::AffineExprKind::Mul;

if (!lhsE.isSymbolicOrConstant() && !rhsE.isSymbolicOrConstant()) {
// This is not an affine expression, give up.
return {};
}
}
return getAffineBinaryOpExpr(kind, lhsE, rhsE);
}
return {};
}

if (auto dimIx = findInListOrAdd(value, affineSymbols, [](Value v) {
return affine::isValidSymbol(v);
})) {
return getAffineSymbolExpr(*dimIx, value.getContext());
}

if (auto dimIx = findInListOrAdd(
value, affineDims, [](Value v) { return affine::isValidDim(v); })) {

return getAffineDimExpr(*dimIx, value.getContext());
}

return {};
}

static LogicalResult
computeAffineMapAndArgs(MLIRContext *ctx, ValueRange indices, AffineMap &map,
llvm::SmallVectorImpl<Value> &mapArgs) {
SmallVector<AffineExpr> results;
SmallVector<Value> symbols;
SmallVector<Value> dims;

for (Value indexExpr : indices) {
AffineExpr res = toAffineExpr(indexExpr, dims, symbols);
if (!res) {
return failure();
}
results.push_back(res);
}

map = AffineMap::get(dims.size(), symbols.size(), results, ctx);

dims.append(symbols);
mapArgs.swap(dims);
return success();
}

struct RaiseMemrefDialect
: public affine::impl::RaiseMemrefDialectBase<RaiseMemrefDialect> {

void runOnOperation() override {
auto *ctx = &getContext();
Operation *op = getOperation();
IRRewriter rewriter(ctx);
AffineMap map;
SmallVector<Value> mapArgs;
op->walk([&](Operation *op) {
rewriter.setInsertionPoint(op);
if (auto store = llvm::dyn_cast_or_null<memref::StoreOp>(op)) {

if (succeeded(computeAffineMapAndArgs(ctx, store.getIndices(), map,
mapArgs))) {
rewriter.replaceOpWithNewOp<AffineStoreOp>(
op, store.getValueToStore(), store.getMemRef(), map, mapArgs);
return;
}

LLVM_DEBUG(llvm::dbgs()
<< "[affine] Cannot raise memref op: " << op << "\n");

} else if (auto load = llvm::dyn_cast_or_null<memref::LoadOp>(op)) {
if (succeeded(computeAffineMapAndArgs(ctx, load.getIndices(), map,
mapArgs))) {
rewriter.replaceOpWithNewOp<AffineLoadOp>(op, load.getMemRef(), map,
mapArgs);
return;
}
LLVM_DEBUG(llvm::dbgs()
<< "[affine] Cannot raise memref op: " << op << "\n");
}
});
}
};

} // namespace

std::unique_ptr<OperationPass<func::FuncOp>>
mlir::affine::createRaiseMemrefToAffine() {
return std::make_unique<RaiseMemrefDialect>();
}
Loading