Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion mlir/include/mlir/Conversion/SCFToEmitC/SCFToEmitC.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#ifndef MLIR_CONVERSION_SCFTOEMITC_SCFTOEMITC_H
#define MLIR_CONVERSION_SCFTOEMITC_SCFTOEMITC_H

#include "mlir/Transforms/DialectConversion.h"
#include <memory>

namespace mlir {
Expand All @@ -19,7 +20,8 @@ class RewritePatternSet;
#include "mlir/Conversion/Passes.h.inc"

/// Collect a set of patterns to convert SCF operations to the EmitC dialect.
void populateSCFToEmitCConversionPatterns(RewritePatternSet &patterns);
void populateSCFToEmitCConversionPatterns(RewritePatternSet &patterns,
TypeConverter &typeConverter);
} // namespace mlir

#endif // MLIR_CONVERSION_SCFTOEMITC_SCFTOEMITC_H
207 changes: 141 additions & 66 deletions mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,15 @@

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/EmitC/IR/EmitC.h"
#include "mlir/Dialect/EmitC/Transforms/TypeConversions.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/OneToNTypeConversion.h"
#include "mlir/Transforms/Passes.h"

namespace mlir {
Expand All @@ -39,21 +41,22 @@ struct SCFToEmitCPass : public impl::SCFToEmitCBase<SCFToEmitCPass> {

// Lower scf::for to emitc::for, implementing result values using
// emitc::variable's updated within the loop body.
struct ForLowering : public OpRewritePattern<ForOp> {
using OpRewritePattern<ForOp>::OpRewritePattern;
struct ForLowering : public OpConversionPattern<ForOp> {
using OpConversionPattern<ForOp>::OpConversionPattern;

LogicalResult matchAndRewrite(ForOp forOp,
PatternRewriter &rewriter) const override;
LogicalResult
matchAndRewrite(ForOp forOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};

// Create an uninitialized emitc::variable op for each result of the given op.
template <typename T>
static SmallVector<Value> createVariablesForResults(T op,
PatternRewriter &rewriter) {
SmallVector<Value> resultVariables;

static LogicalResult
createVariablesForResults(T op, const TypeConverter *typeConverter,
ConversionPatternRewriter &rewriter,
SmallVector<Value> &resultVariables) {
if (!op.getNumResults())
return resultVariables;
return success();

Location loc = op->getLoc();
MLIRContext *context = op.getContext();
Expand All @@ -62,21 +65,23 @@ static SmallVector<Value> createVariablesForResults(T op,
rewriter.setInsertionPoint(op);

for (OpResult result : op.getResults()) {
Type resultType = result.getType();
Type resultType = typeConverter->convertType(result.getType());
if (!resultType)
return rewriter.notifyMatchFailure(op, "result type conversion failed");
Type varType = emitc::LValueType::get(resultType);
emitc::OpaqueAttr noInit = emitc::OpaqueAttr::get(context, "");
emitc::VariableOp var =
rewriter.create<emitc::VariableOp>(loc, varType, noInit);
resultVariables.push_back(var);
}

return resultVariables;
return success();
}

// Create a series of assign ops assigning given values to given variables at
// the current insertion point of given rewriter.
static void assignValues(ValueRange values, SmallVector<Value> &variables,
PatternRewriter &rewriter, Location loc) {
static void assignValues(ValueRange values, ValueRange variables,
ConversionPatternRewriter &rewriter, Location loc) {
for (auto [value, var] : llvm::zip(values, variables))
rewriter.create<emitc::AssignOp>(loc, var, value);
}
Expand All @@ -89,46 +94,58 @@ SmallVector<Value> loadValues(const SmallVector<Value> &variables,
});
}

static void lowerYield(SmallVector<Value> &resultVariables,
PatternRewriter &rewriter, scf::YieldOp yield) {
static LogicalResult lowerYield(Operation *op, ValueRange resultVariables,
ConversionPatternRewriter &rewriter,
scf::YieldOp yield) {
Location loc = yield.getLoc();
ValueRange operands = yield.getOperands();

OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPoint(yield);

assignValues(operands, resultVariables, rewriter, loc);
SmallVector<Value> yieldOperands;
if (failed(rewriter.getRemappedValues(yield.getOperands(), yieldOperands))) {
return rewriter.notifyMatchFailure(op, "failed to lower yield operands");
}

assignValues(yieldOperands, resultVariables, rewriter, loc);

rewriter.create<emitc::YieldOp>(loc);
rewriter.eraseOp(yield);

return success();
}

// Lower the contents of an scf::if/scf::index_switch regions to an
// emitc::if/emitc::switch region. The contents of the lowering region is
// moved into the respective lowered region, but the scf::yield is replaced not
// only with an emitc::yield, but also with a sequence of emitc::assign ops that
// set the yielded values into the result variables.
static void lowerRegion(SmallVector<Value> &resultVariables,
PatternRewriter &rewriter, Region &region,
Region &loweredRegion) {
static LogicalResult lowerRegion(Operation *op, ValueRange resultVariables,
ConversionPatternRewriter &rewriter,
Region &region, Region &loweredRegion) {
rewriter.inlineRegionBefore(region, loweredRegion, loweredRegion.end());
Operation *terminator = loweredRegion.back().getTerminator();
lowerYield(resultVariables, rewriter, cast<scf::YieldOp>(terminator));
return lowerYield(op, resultVariables, rewriter,
cast<scf::YieldOp>(terminator));
}

LogicalResult ForLowering::matchAndRewrite(ForOp forOp,
PatternRewriter &rewriter) const {
LogicalResult
ForLowering::matchAndRewrite(ForOp forOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Location loc = forOp.getLoc();

// Create an emitc::variable op for each result. These variables will be
// assigned to by emitc::assign ops within the loop body.
SmallVector<Value> resultVariables =
createVariablesForResults(forOp, rewriter);
SmallVector<Value> resultVariables;
if (failed(createVariablesForResults(forOp, getTypeConverter(), rewriter,
resultVariables)))
return rewriter.notifyMatchFailure(forOp,
"create variables for results failed");

assignValues(forOp.getInits(), resultVariables, rewriter, loc);
assignValues(adaptor.getInitArgs(), resultVariables, rewriter, loc);

emitc::ForOp loweredFor = rewriter.create<emitc::ForOp>(
loc, forOp.getLowerBound(), forOp.getUpperBound(), forOp.getStep());
loc, adaptor.getLowerBound(), adaptor.getUpperBound(), adaptor.getStep());

Block *loweredBody = loweredFor.getBody();

Expand All @@ -143,13 +160,27 @@ LogicalResult ForLowering::matchAndRewrite(ForOp forOp,

rewriter.restoreInsertionPoint(ip);

// Convert the original region types into the new types by adding unrealized
// casts in the beginning of the loop. This performs the conversion in place.
if (failed(rewriter.convertRegionTypes(&forOp.getRegion(),
*getTypeConverter(), nullptr))) {
return rewriter.notifyMatchFailure(forOp, "region types conversion failed");
}

// Register the replacements for the block arguments and inline the body of
// the scf.for loop into the body of the emitc::for loop.
Block *scfBody = &(forOp.getRegion().front());
SmallVector<Value> replacingValues;
replacingValues.push_back(loweredFor.getInductionVar());
replacingValues.append(iterArgsValues.begin(), iterArgsValues.end());
rewriter.mergeBlocks(scfBody, loweredBody, replacingValues);

rewriter.mergeBlocks(forOp.getBody(), loweredBody, replacingValues);
lowerYield(resultVariables, rewriter,
cast<scf::YieldOp>(loweredBody->getTerminator()));
auto result = lowerYield(forOp, resultVariables, rewriter,
cast<scf::YieldOp>(loweredBody->getTerminator()));

if (failed(result)) {
return result;
}

// Load variables into SSA values after the for loop.
SmallVector<Value> resultValues = loadValues(resultVariables, rewriter, loc);
Expand All @@ -160,38 +191,66 @@ LogicalResult ForLowering::matchAndRewrite(ForOp forOp,

// Lower scf::if to emitc::if, implementing result values as emitc::variable's
// updated within the then and else regions.
struct IfLowering : public OpRewritePattern<IfOp> {
using OpRewritePattern<IfOp>::OpRewritePattern;
struct IfLowering : public OpConversionPattern<IfOp> {
using OpConversionPattern<IfOp>::OpConversionPattern;

LogicalResult matchAndRewrite(IfOp ifOp,
PatternRewriter &rewriter) const override;
LogicalResult
matchAndRewrite(IfOp ifOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};

} // namespace

LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,
PatternRewriter &rewriter) const {
LogicalResult
IfLowering::matchAndRewrite(IfOp ifOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Location loc = ifOp.getLoc();

// Create an emitc::variable op for each result. These variables will be
// assigned to by emitc::assign ops within the then & else regions.
SmallVector<Value> resultVariables =
createVariablesForResults(ifOp, rewriter);

Region &thenRegion = ifOp.getThenRegion();
Region &elseRegion = ifOp.getElseRegion();
SmallVector<Value> resultVariables;
if (failed(createVariablesForResults(ifOp, getTypeConverter(), rewriter,
resultVariables)))
return rewriter.notifyMatchFailure(ifOp,
"create variables for results failed");

// Utility function to lower the contents of an scf::if region to an emitc::if
// region. The contents of the scf::if regions is moved into the respective
// emitc::if regions, but the scf::yield is replaced not only with an
// emitc::yield, but also with a sequence of emitc::assign ops that set the
// yielded values into the result variables.
auto lowerRegion = [&resultVariables, &rewriter,
&ifOp](Region &region, Region &loweredRegion) {
rewriter.inlineRegionBefore(region, loweredRegion, loweredRegion.end());
Operation *terminator = loweredRegion.back().getTerminator();
auto result = lowerYield(ifOp, resultVariables, rewriter,
cast<scf::YieldOp>(terminator));
if (failed(result)) {
return result;
}
return success();
};

Region &thenRegion = adaptor.getThenRegion();
Region &elseRegion = adaptor.getElseRegion();

bool hasElseBlock = !elseRegion.empty();

auto loweredIf =
rewriter.create<emitc::IfOp>(loc, ifOp.getCondition(), false, false);
rewriter.create<emitc::IfOp>(loc, adaptor.getCondition(), false, false);

Region &loweredThenRegion = loweredIf.getThenRegion();
lowerRegion(resultVariables, rewriter, thenRegion, loweredThenRegion);
auto result = lowerRegion(thenRegion, loweredThenRegion);
if (failed(result)) {
return result;
}

if (hasElseBlock) {
Region &loweredElseRegion = loweredIf.getElseRegion();
lowerRegion(resultVariables, rewriter, elseRegion, loweredElseRegion);
auto result = lowerRegion(elseRegion, loweredElseRegion);
if (failed(result)) {
return result;
}
}

rewriter.setInsertionPointAfter(ifOp);
Expand All @@ -203,37 +262,46 @@ LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,

// Lower scf::index_switch to emitc::switch, implementing result values as
// emitc::variable's updated within the case and default regions.
struct IndexSwitchOpLowering : public OpRewritePattern<IndexSwitchOp> {
using OpRewritePattern<IndexSwitchOp>::OpRewritePattern;
struct IndexSwitchOpLowering : public OpConversionPattern<IndexSwitchOp> {
using OpConversionPattern::OpConversionPattern;

LogicalResult matchAndRewrite(IndexSwitchOp indexSwitchOp,
PatternRewriter &rewriter) const override;
LogicalResult
matchAndRewrite(IndexSwitchOp indexSwitchOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};

LogicalResult
IndexSwitchOpLowering::matchAndRewrite(IndexSwitchOp indexSwitchOp,
PatternRewriter &rewriter) const {
LogicalResult IndexSwitchOpLowering::matchAndRewrite(
IndexSwitchOp indexSwitchOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Location loc = indexSwitchOp.getLoc();

// Create an emitc::variable op for each result. These variables will be
// assigned to by emitc::assign ops within the case and default regions.
SmallVector<Value> resultVariables =
createVariablesForResults(indexSwitchOp, rewriter);
SmallVector<Value> resultVariables;
if (failed(createVariablesForResults(indexSwitchOp, getTypeConverter(),
rewriter, resultVariables))) {
return rewriter.notifyMatchFailure(indexSwitchOp,
"create variables for results failed");
}

auto loweredSwitch = rewriter.create<emitc::SwitchOp>(
loc, indexSwitchOp.getArg(), indexSwitchOp.getCases(),
indexSwitchOp.getNumCases());
loc, adaptor.getArg(), adaptor.getCases(), indexSwitchOp.getNumCases());

// Lowering all case regions.
for (auto pair : llvm::zip(indexSwitchOp.getCaseRegions(),
loweredSwitch.getCaseRegions())) {
lowerRegion(resultVariables, rewriter, std::get<0>(pair),
std::get<1>(pair));
for (auto pair :
llvm::zip(adaptor.getCaseRegions(), loweredSwitch.getCaseRegions())) {
if (failed(lowerRegion(indexSwitchOp, resultVariables, rewriter,
*std::get<0>(pair), std::get<1>(pair)))) {
return failure();
}
}

// Lowering default region.
lowerRegion(resultVariables, rewriter, indexSwitchOp.getDefaultRegion(),
loweredSwitch.getDefaultRegion());
if (failed(lowerRegion(indexSwitchOp, resultVariables, rewriter,
adaptor.getDefaultRegion(),
loweredSwitch.getDefaultRegion()))) {
return failure();
}

rewriter.setInsertionPointAfter(indexSwitchOp);
SmallVector<Value> results = loadValues(resultVariables, rewriter, loc);
Expand All @@ -242,15 +310,22 @@ IndexSwitchOpLowering::matchAndRewrite(IndexSwitchOp indexSwitchOp,
return success();
}

void mlir::populateSCFToEmitCConversionPatterns(RewritePatternSet &patterns) {
patterns.add<ForLowering>(patterns.getContext());
patterns.add<IfLowering>(patterns.getContext());
patterns.add<IndexSwitchOpLowering>(patterns.getContext());
void mlir::populateSCFToEmitCConversionPatterns(RewritePatternSet &patterns,
TypeConverter &typeConverter) {
patterns.add<ForLowering>(typeConverter, patterns.getContext());
patterns.add<IfLowering>(typeConverter, patterns.getContext());
patterns.add<IndexSwitchOpLowering>(typeConverter, patterns.getContext());
}

void SCFToEmitCPass::runOnOperation() {
RewritePatternSet patterns(&getContext());
populateSCFToEmitCConversionPatterns(patterns);
TypeConverter typeConverter;
// Fallback converter
// See note https://mlir.llvm.org/docs/DialectConversion/#type-converter
// Type converters are called most to least recently inserted
typeConverter.addConversion([](Type t) { return t; });
populateEmitCSizeTTypeConversions(typeConverter);
populateSCFToEmitCConversionPatterns(patterns, typeConverter);

// Configure conversion to lower out SCF operations.
ConversionTarget target(getContext());
Expand Down
Loading
Loading