Skip to content
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 104 additions & 1 deletion mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
Original file line number Diff line number Diff line change
Expand Up @@ -1393,7 +1393,7 @@ def EmitC_AssignOp : EmitC_Op<"assign", []> {
}

def EmitC_YieldOp : EmitC_Op<"yield",
[Pure, Terminator, ParentOneOf<["ExpressionOp", "IfOp", "ForOp", "SwitchOp"]>]> {
[Pure, Terminator, ParentOneOf<["DoOp", "ExpressionOp", "ForOp", "IfOp", "SwitchOp"]>]> {
let summary = "Block termination operation";
let description = [{
The `emitc.yield` terminates its parent EmitC op's region, optionally yielding
Expand Down Expand Up @@ -1727,4 +1727,107 @@ def EmitC_GetFieldOp
let hasVerifier = 1;
}

def EmitC_DoOp : EmitC_Op<"do",
[NoTerminator, OpAsmOpInterface, RecursiveMemoryEffects]> {
let summary = "Do-while operation";
let description = [{
The `emitc.do` operation represents a C/C++ do-while loop construct that
repeatedly executes a body region as long as a condition region evaluates to
true. The operation has two regions:

1. A body region that contains the loop body
2. A condition region that must yield a boolean value (i1)

The condition is evaluated before each iteration as follows:
- The condition region must contain exactly one block with:
1. An `emitc.expression` operation producing an i1 value
2. An `emitc.yield` passing through the expression result
- The expression's body contains the actual condition logic

The body region is executed before the first evaluation of the
condition. Thus, there is a guarantee that the loop will be executed
at least once. The loop terminates when the condition yields false.

The canonical structure of `emitc.do` is:

```mlir
emitc.do {
// Body region (no terminator required).
// Loop body operations...
} while {
// Condition region (must yield i1)
%condition = emitc.expression : () -> i1 {
// Condition computation...
%result = ... : i1 // Last operation must produce i1
emitc.yield %result : i1
}
// Forward expression result
emitc.yield %condition : i1
}
```

Example:

```mlir
emitc.func @do_example() {
%counter = "emitc.variable"() <{value = 0 : i32}> : () -> !emitc.lvalue<i32>
%end = emitc.literal "10" : i32
%step = emitc.literal "1" : i32

emitc.do {
// Print current value
%val = emitc.load %counter : !emitc.lvalue<i32>
emitc.verbatim "printf(\"%d\\n\", {});" args %val : i32

// Increment counter
%new_val = emitc.add %val, %step : (i32, i32) -> i32
"emitc.assign"(%counter, %new_val) : (!emitc.lvalue<i32>, i32) -> ()
} while {
%condition = emitc.expression %counter, %end : (!emitc.lvalue<i32>, i32) -> i1 {
%current = emitc.load %counter : !emitc.lvalue<i32>
%cmp_res = emitc.cmp lt, %current, %end : (i32, i32) -> i1
emitc.yield %cmp_res : i1
}
emitc.yield %condition : i1
}
return
}
```
```c++
// Code emitted for the operation above.
void do_example() {
int32_t v1 = 0;
do {
int32_t v2 = v1;
printf("%d\n", v2);
int32_t v3 = v2 + 1;
v1 = v3;
} while (v1 < 10);
return;
}
```
}];

let arguments = (ins);
let results = (outs);
let regions = (region SizedRegion<1>:$bodyRegion,
SizedRegion<1>:$conditionRegion);

let hasCustomAssemblyFormat = 1;
let hasVerifier = 1;

let extraClassDeclaration = [{
Operation *getRootOp();

//===------------------------------------------------------------------===//
// OpAsmOpInterface Methods
//===------------------------------------------------------------------===//

/// EmitC ops in the body can omit their 'emitc.' prefix in the assembly.
static ::llvm::StringRef getDefaultDialect() {
return "emitc";
}
}];
}

#endif // MLIR_DIALECT_EMITC_IR_EMITC
174 changes: 169 additions & 5 deletions mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/Support/LogicalResult.h"

namespace mlir {
#define GEN_PASS_DEF_SCFTOEMITC
Expand Down Expand Up @@ -106,7 +107,7 @@ static void assignValues(ValueRange values, ValueRange variables,
emitc::AssignOp::create(rewriter, loc, var, value);
}

SmallVector<Value> loadValues(const SmallVector<Value> &variables,
SmallVector<Value> loadValues(ArrayRef<Value> variables,
PatternRewriter &rewriter, Location loc) {
return llvm::map_to_vector<>(variables, [&](Value var) {
Type type = cast<emitc::LValueType>(var.getType()).getValueType();
Expand All @@ -116,16 +117,15 @@ SmallVector<Value> loadValues(const SmallVector<Value> &variables,

static LogicalResult lowerYield(Operation *op, ValueRange resultVariables,
ConversionPatternRewriter &rewriter,
scf::YieldOp yield) {
scf::YieldOp yield, bool createYield = true) {
Location loc = yield.getLoc();

OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPoint(yield);

SmallVector<Value> yieldOperands;
if (failed(rewriter.getRemappedValues(yield.getOperands(), yieldOperands))) {
if (failed(rewriter.getRemappedValues(yield.getOperands(), yieldOperands)))
return rewriter.notifyMatchFailure(op, "failed to lower yield operands");
}

assignValues(yieldOperands, resultVariables, rewriter, loc);

Expand Down Expand Up @@ -336,11 +336,174 @@ LogicalResult IndexSwitchOpLowering::matchAndRewrite(
return success();
}

// Lower scf::while to emitc::do using mutable variables to maintain loop state
// across iterations. The do-while structure ensures the condition is evaluated
// after each iteration, matching SCF while semantics.
struct WhileLowering : public OpConversionPattern<WhileOp> {
using OpConversionPattern::OpConversionPattern;

LogicalResult
matchAndRewrite(WhileOp whileOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = whileOp.getLoc();
MLIRContext *context = loc.getContext();

// Create an emitc::variable op for each result. These variables will be
// assigned to by emitc::assign ops within the loop body.
SmallVector<Value> resultVariables;
if (failed(createVariablesForResults(whileOp, getTypeConverter(), rewriter,
resultVariables)))
return rewriter.notifyMatchFailure(whileOp,
"Failed to create result variables");

// Create variable storage for loop-carried values to enable imperative
// updates while maintaining SSA semantics at conversion boundaries.
SmallVector<Value> loopVariables;
if (failed(createVariablesForLoopCarriedValues(
whileOp, rewriter, loopVariables, loc, context)))
return failure();

if (failed(lowerDoWhile(whileOp, loopVariables, resultVariables, context,
rewriter, loc)))
return failure();

rewriter.setInsertionPointAfter(whileOp);

// Load the final result values from result variables.
SmallVector<Value> finalResults =
loadValues(resultVariables, rewriter, loc);
rewriter.replaceOp(whileOp, finalResults);

return success();
}

private:
// Initialize variables for loop-carried values to enable state updates
// across iterations without SSA argument passing.
LogicalResult createVariablesForLoopCarriedValues(
WhileOp whileOp, ConversionPatternRewriter &rewriter,
SmallVectorImpl<Value> &loopVars, Location loc,
MLIRContext *context) const {
emitc::OpaqueAttr noInit = emitc::OpaqueAttr::get(context, "");

for (Value init : whileOp.getInits()) {
Type convertedType = getTypeConverter()->convertType(init.getType());
if (!convertedType)
return rewriter.notifyMatchFailure(whileOp, "type conversion failed");

emitc::VariableOp var = rewriter.create<emitc::VariableOp>(
loc, emitc::LValueType::get(convertedType), noInit);
rewriter.create<emitc::AssignOp>(loc, var.getResult(), init);
loopVars.push_back(var.getResult());
}

return success();
}

// Lower scf.while to emitc.do.
LogicalResult lowerDoWhile(WhileOp whileOp, ArrayRef<Value> loopVars,
ArrayRef<Value> resultVars, MLIRContext *context,
ConversionPatternRewriter &rewriter,
Location loc) const {
// Create a global boolean variable to store the loop condition state.
Type i1Type = IntegerType::get(context, 1);
auto globalCondition =
rewriter.create<emitc::VariableOp>(loc, emitc::LValueType::get(i1Type),
emitc::OpaqueAttr::get(context, ""));
Value conditionVal = globalCondition.getResult();

auto loweredDo = rewriter.create<emitc::DoOp>(loc);

// Convert region types to match the target dialect type system.
if (failed(rewriter.convertRegionTypes(&whileOp.getBefore(),
*getTypeConverter(), nullptr)) ||
failed(rewriter.convertRegionTypes(&whileOp.getAfter(),
*getTypeConverter(), nullptr))) {
return rewriter.notifyMatchFailure(whileOp,
"region types conversion failed");
}

// Prepare the before region (condition evaluation) for merging.
Block *beforeBlock = &whileOp.getBefore().front();
Block *bodyBlock = rewriter.createBlock(&loweredDo.getBodyRegion());
rewriter.setInsertionPointToStart(bodyBlock);

// Load current variable values to use as initial arguments for the
// condition block.
SmallVector<Value> replacingValues = loadValues(loopVars, rewriter, loc);
rewriter.mergeBlocks(beforeBlock, bodyBlock, replacingValues);

Operation *condTerminator =
loweredDo.getBodyRegion().back().getTerminator();
scf::ConditionOp condOp = cast<scf::ConditionOp>(condTerminator);
rewriter.setInsertionPoint(condOp);

// Update result variables with values from scf::condition.
SmallVector<Value> conditionArgs;
for (Value arg : condOp.getArgs()) {
conditionArgs.push_back(rewriter.getRemappedValue(arg));
}
assignValues(conditionArgs, resultVars, rewriter, loc);

// Convert scf.condition to condition variable assignment.
Value condition = rewriter.getRemappedValue(condOp.getCondition());
rewriter.create<emitc::AssignOp>(loc, conditionVal, condition);

// Wrap body region in conditional to preserve scf semantics. Only create
// ifOp if after-region is non-empty.
if (whileOp.getAfterBody()->getOperations().size() > 1) {
auto ifOp = rewriter.create<emitc::IfOp>(loc, condition, false, false);

// Prepare the after region (loop body) for merging.
Block *afterBlock = &whileOp.getAfter().front();
Block *ifBodyBlock = rewriter.createBlock(&ifOp.getBodyRegion());

// Replacement values for after block using condition op arguments.
SmallVector<Value> afterReplacingValues;
for (Value arg : condOp.getArgs())
afterReplacingValues.push_back(rewriter.getRemappedValue(arg));

rewriter.mergeBlocks(afterBlock, ifBodyBlock, afterReplacingValues);

if (failed(lowerYield(whileOp, loopVars, rewriter,
cast<scf::YieldOp>(ifBodyBlock->getTerminator()))))
return failure();
}

rewriter.eraseOp(condOp);

// Create condition region that loads from the flag variable.
Region &condRegion = loweredDo.getConditionRegion();
Block *condBlock = rewriter.createBlock(&condRegion);
rewriter.setInsertionPointToStart(condBlock);

auto exprOp = rewriter.create<emitc::ExpressionOp>(
loc, i1Type, conditionVal, /*do_not_inline=*/false);
Block *exprBlock = rewriter.createBlock(&exprOp.getBodyRegion());

// Set up the expression block to load the condition variable.
exprBlock->addArgument(conditionVal.getType(), loc);
rewriter.setInsertionPointToStart(exprBlock);

// Load the condition value and yield it as the expression result.
Value cond =
rewriter.create<emitc::LoadOp>(loc, i1Type, exprBlock->getArgument(0));
rewriter.create<emitc::YieldOp>(loc, cond);

// Yield the expression as the condition region result.
rewriter.setInsertionPointToEnd(condBlock);
rewriter.create<emitc::YieldOp>(loc, exprOp);

return success();
}
};

void mlir::populateSCFToEmitCConversionPatterns(RewritePatternSet &patterns,
TypeConverter &typeConverter) {
patterns.add<ForLowering>(typeConverter, patterns.getContext());
patterns.add<IfLowering>(typeConverter, patterns.getContext());
patterns.add<IndexSwitchOpLowering>(typeConverter, patterns.getContext());
patterns.add<WhileLowering>(typeConverter, patterns.getContext());
}

void SCFToEmitCPass::runOnOperation() {
Expand All @@ -357,7 +520,8 @@ void SCFToEmitCPass::runOnOperation() {

// Configure conversion to lower out SCF operations.
ConversionTarget target(getContext());
target.addIllegalOp<scf::ForOp, scf::IfOp, scf::IndexSwitchOp>();
target
.addIllegalOp<scf::ForOp, scf::IfOp, scf::IndexSwitchOp, scf::WhileOp>();
target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
if (failed(
applyPartialConversion(getOperation(), target, std::move(patterns))))
Expand Down
Loading