Skip to content

Commit c3aa158

Browse files
[mlir][emitc] Add emitc.do op to the dialect (llvm#143008)
This patch adds: - Emission of the corresponding ops in the CppEmitter - Conversion from the SCF dialect to the EmitC dialect for the ops - Corresponding tests
1 parent 7910ed2 commit c3aa158

File tree

8 files changed

+1012
-19
lines changed

8 files changed

+1012
-19
lines changed

mlir/include/mlir/Dialect/EmitC/IR/EmitC.td

Lines changed: 102 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1393,7 +1393,7 @@ def EmitC_AssignOp : EmitC_Op<"assign", []> {
13931393
}
13941394

13951395
def EmitC_YieldOp : EmitC_Op<"yield",
1396-
[Pure, Terminator, ParentOneOf<["ExpressionOp", "IfOp", "ForOp", "SwitchOp"]>]> {
1396+
[Pure, Terminator, ParentOneOf<["DoOp", "ExpressionOp", "ForOp", "IfOp", "SwitchOp"]>]> {
13971397
let summary = "Block termination operation";
13981398
let description = [{
13991399
The `emitc.yield` terminates its parent EmitC op's region, optionally yielding
@@ -1725,4 +1725,105 @@ def EmitC_GetFieldOp
17251725
let hasVerifier = 1;
17261726
}
17271727

1728+
def EmitC_DoOp : EmitC_Op<"do",
1729+
[NoTerminator, OpAsmOpInterface, RecursiveMemoryEffects]> {
1730+
let summary = "Do-while operation";
1731+
let description = [{
1732+
The `emitc.do` operation represents a C/C++ do-while loop construct that
1733+
repeatedly executes a body region as long as a condition region evaluates to
1734+
true. The operation has two regions:
1735+
1736+
1. A body region that contains the loop body
1737+
2. A condition region that must yield a boolean value (i1)
1738+
1739+
The condition is evaluated before each iteration as follows:
1740+
- The condition region must contain exactly one block with:
1741+
1. An `emitc.expression` operation producing an i1 value
1742+
2. An `emitc.yield` passing through the expression result
1743+
- The expression's body contains the actual condition logic
1744+
1745+
The body region is executed before the first evaluation of the
1746+
condition. Thus, there is a guarantee that the loop will be executed
1747+
at least once. The loop terminates when the condition yields false.
1748+
1749+
The canonical structure of `emitc.do` is:
1750+
1751+
```mlir
1752+
emitc.do {
1753+
// Body region (no terminator required).
1754+
// Loop body operations...
1755+
} while {
1756+
// Condition region (must yield i1)
1757+
%condition = emitc.expression : () -> i1 {
1758+
// Condition computation...
1759+
%result = ... : i1 // Last operation must produce i1
1760+
emitc.yield %result : i1
1761+
}
1762+
// Forward expression result
1763+
emitc.yield %condition : i1
1764+
}
1765+
```
1766+
1767+
Example:
1768+
1769+
```mlir
1770+
emitc.func @do_example() {
1771+
%counter = "emitc.variable"() <{value = 0 : i32}> : () -> !emitc.lvalue<i32>
1772+
%end = emitc.literal "10" : i32
1773+
%step = emitc.literal "1" : i32
1774+
1775+
emitc.do {
1776+
// Print current value
1777+
%val = emitc.load %counter : !emitc.lvalue<i32>
1778+
emitc.verbatim "printf(\"%d\\n\", {});" args %val : i32
1779+
1780+
// Increment counter
1781+
%new_val = emitc.add %val, %step : (i32, i32) -> i32
1782+
"emitc.assign"(%counter, %new_val) : (!emitc.lvalue<i32>, i32) -> ()
1783+
} while {
1784+
%condition = emitc.expression %counter, %end : (!emitc.lvalue<i32>, i32) -> i1 {
1785+
%current = emitc.load %counter : !emitc.lvalue<i32>
1786+
%cmp_res = emitc.cmp lt, %current, %end : (i32, i32) -> i1
1787+
emitc.yield %cmp_res : i1
1788+
}
1789+
emitc.yield %condition : i1
1790+
}
1791+
return
1792+
}
1793+
```
1794+
```c++
1795+
// Code emitted for the operation above.
1796+
void do_example() {
1797+
int32_t v1 = 0;
1798+
do {
1799+
int32_t v2 = v1;
1800+
printf("%d\n", v2);
1801+
int32_t v3 = v2 + 1;
1802+
v1 = v3;
1803+
} while (v1 < 10);
1804+
return;
1805+
}
1806+
```
1807+
}];
1808+
1809+
let arguments = (ins);
1810+
let results = (outs);
1811+
let regions = (region SizedRegion<1>:$bodyRegion,
1812+
SizedRegion<1>:$conditionRegion);
1813+
1814+
let hasCustomAssemblyFormat = 1;
1815+
let hasVerifier = 1;
1816+
1817+
let extraClassDeclaration = [{
1818+
//===------------------------------------------------------------------===//
1819+
// OpAsmOpInterface Methods
1820+
//===------------------------------------------------------------------===//
1821+
1822+
/// EmitC ops in the body can omit their 'emitc.' prefix in the assembly.
1823+
static ::llvm::StringRef getDefaultDialect() {
1824+
return "emitc";
1825+
}
1826+
}];
1827+
}
1828+
17281829
#endif // MLIR_DIALECT_EMITC_IR_EMITC

mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp

Lines changed: 172 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include "mlir/IR/PatternMatch.h"
2222
#include "mlir/Transforms/DialectConversion.h"
2323
#include "mlir/Transforms/Passes.h"
24+
#include "llvm/Support/LogicalResult.h"
2425

2526
namespace mlir {
2627
#define GEN_PASS_DEF_SCFTOEMITC
@@ -106,7 +107,7 @@ static void assignValues(ValueRange values, ValueRange variables,
106107
emitc::AssignOp::create(rewriter, loc, var, value);
107108
}
108109

109-
SmallVector<Value> loadValues(const SmallVector<Value> &variables,
110+
SmallVector<Value> loadValues(ArrayRef<Value> variables,
110111
PatternRewriter &rewriter, Location loc) {
111112
return llvm::map_to_vector<>(variables, [&](Value var) {
112113
Type type = cast<emitc::LValueType>(var.getType()).getValueType();
@@ -116,16 +117,15 @@ SmallVector<Value> loadValues(const SmallVector<Value> &variables,
116117

117118
static LogicalResult lowerYield(Operation *op, ValueRange resultVariables,
118119
ConversionPatternRewriter &rewriter,
119-
scf::YieldOp yield) {
120+
scf::YieldOp yield, bool createYield = true) {
120121
Location loc = yield.getLoc();
121122

122123
OpBuilder::InsertionGuard guard(rewriter);
123124
rewriter.setInsertionPoint(yield);
124125

125126
SmallVector<Value> yieldOperands;
126-
if (failed(rewriter.getRemappedValues(yield.getOperands(), yieldOperands))) {
127+
if (failed(rewriter.getRemappedValues(yield.getOperands(), yieldOperands)))
127128
return rewriter.notifyMatchFailure(op, "failed to lower yield operands");
128-
}
129129

130130
assignValues(yieldOperands, resultVariables, rewriter, loc);
131131

@@ -336,11 +336,177 @@ LogicalResult IndexSwitchOpLowering::matchAndRewrite(
336336
return success();
337337
}
338338

339+
// Lower scf::while to emitc::do using mutable variables to maintain loop state
340+
// across iterations. The do-while structure ensures the condition is evaluated
341+
// after each iteration, matching SCF while semantics.
342+
struct WhileLowering : public OpConversionPattern<WhileOp> {
343+
using OpConversionPattern::OpConversionPattern;
344+
345+
LogicalResult
346+
matchAndRewrite(WhileOp whileOp, OpAdaptor adaptor,
347+
ConversionPatternRewriter &rewriter) const override {
348+
Location loc = whileOp.getLoc();
349+
MLIRContext *context = loc.getContext();
350+
351+
// Create an emitc::variable op for each result. These variables will be
352+
// assigned to by emitc::assign ops within the loop body.
353+
SmallVector<Value> resultVariables;
354+
if (failed(createVariablesForResults(whileOp, getTypeConverter(), rewriter,
355+
resultVariables)))
356+
return rewriter.notifyMatchFailure(whileOp,
357+
"Failed to create result variables");
358+
359+
// Create variable storage for loop-carried values to enable imperative
360+
// updates while maintaining SSA semantics at conversion boundaries.
361+
SmallVector<Value> loopVariables;
362+
if (failed(createVariablesForLoopCarriedValues(
363+
whileOp, rewriter, loopVariables, loc, context)))
364+
return failure();
365+
366+
if (failed(lowerDoWhile(whileOp, loopVariables, resultVariables, context,
367+
rewriter, loc)))
368+
return failure();
369+
370+
rewriter.setInsertionPointAfter(whileOp);
371+
372+
// Load the final result values from result variables.
373+
SmallVector<Value> finalResults =
374+
loadValues(resultVariables, rewriter, loc);
375+
rewriter.replaceOp(whileOp, finalResults);
376+
377+
return success();
378+
}
379+
380+
private:
381+
// Initialize variables for loop-carried values to enable state updates
382+
// across iterations without SSA argument passing.
383+
LogicalResult createVariablesForLoopCarriedValues(
384+
WhileOp whileOp, ConversionPatternRewriter &rewriter,
385+
SmallVectorImpl<Value> &loopVars, Location loc,
386+
MLIRContext *context) const {
387+
OpBuilder::InsertionGuard guard(rewriter);
388+
rewriter.setInsertionPoint(whileOp);
389+
390+
emitc::OpaqueAttr noInit = emitc::OpaqueAttr::get(context, "");
391+
392+
for (Value init : whileOp.getInits()) {
393+
Type convertedType = getTypeConverter()->convertType(init.getType());
394+
if (!convertedType)
395+
return rewriter.notifyMatchFailure(whileOp, "type conversion failed");
396+
397+
emitc::VariableOp var = rewriter.create<emitc::VariableOp>(
398+
loc, emitc::LValueType::get(convertedType), noInit);
399+
rewriter.create<emitc::AssignOp>(loc, var.getResult(), init);
400+
loopVars.push_back(var);
401+
}
402+
403+
return success();
404+
}
405+
406+
// Lower scf.while to emitc.do.
407+
LogicalResult lowerDoWhile(WhileOp whileOp, ArrayRef<Value> loopVars,
408+
ArrayRef<Value> resultVars, MLIRContext *context,
409+
ConversionPatternRewriter &rewriter,
410+
Location loc) const {
411+
// Create a global boolean variable to store the loop condition state.
412+
Type i1Type = IntegerType::get(context, 1);
413+
auto globalCondition =
414+
rewriter.create<emitc::VariableOp>(loc, emitc::LValueType::get(i1Type),
415+
emitc::OpaqueAttr::get(context, ""));
416+
Value conditionVal = globalCondition.getResult();
417+
418+
auto loweredDo = rewriter.create<emitc::DoOp>(loc);
419+
420+
// Convert region types to match the target dialect type system.
421+
if (failed(rewriter.convertRegionTypes(&whileOp.getBefore(),
422+
*getTypeConverter(), nullptr)) ||
423+
failed(rewriter.convertRegionTypes(&whileOp.getAfter(),
424+
*getTypeConverter(), nullptr))) {
425+
return rewriter.notifyMatchFailure(whileOp,
426+
"region types conversion failed");
427+
}
428+
429+
// Prepare the before region (condition evaluation) for merging.
430+
Block *beforeBlock = &whileOp.getBefore().front();
431+
Block *bodyBlock = rewriter.createBlock(&loweredDo.getBodyRegion());
432+
rewriter.setInsertionPointToStart(bodyBlock);
433+
434+
// Load current variable values to use as initial arguments for the
435+
// condition block.
436+
SmallVector<Value> replacingValues = loadValues(loopVars, rewriter, loc);
437+
rewriter.mergeBlocks(beforeBlock, bodyBlock, replacingValues);
438+
439+
Operation *condTerminator =
440+
loweredDo.getBodyRegion().back().getTerminator();
441+
scf::ConditionOp condOp = cast<scf::ConditionOp>(condTerminator);
442+
rewriter.setInsertionPoint(condOp);
443+
444+
// Update result variables with values from scf::condition.
445+
SmallVector<Value> conditionArgs;
446+
for (Value arg : condOp.getArgs()) {
447+
conditionArgs.push_back(rewriter.getRemappedValue(arg));
448+
}
449+
assignValues(conditionArgs, resultVars, rewriter, loc);
450+
451+
// Convert scf.condition to condition variable assignment.
452+
Value condition = rewriter.getRemappedValue(condOp.getCondition());
453+
rewriter.create<emitc::AssignOp>(loc, conditionVal, condition);
454+
455+
// Wrap body region in conditional to preserve scf semantics. Only create
456+
// ifOp if after-region is non-empty.
457+
if (whileOp.getAfterBody()->getOperations().size() > 1) {
458+
auto ifOp = rewriter.create<emitc::IfOp>(loc, condition, false, false);
459+
460+
// Prepare the after region (loop body) for merging.
461+
Block *afterBlock = &whileOp.getAfter().front();
462+
Block *ifBodyBlock = rewriter.createBlock(&ifOp.getBodyRegion());
463+
464+
// Replacement values for after block using condition op arguments.
465+
SmallVector<Value> afterReplacingValues;
466+
for (Value arg : condOp.getArgs())
467+
afterReplacingValues.push_back(rewriter.getRemappedValue(arg));
468+
469+
rewriter.mergeBlocks(afterBlock, ifBodyBlock, afterReplacingValues);
470+
471+
if (failed(lowerYield(whileOp, loopVars, rewriter,
472+
cast<scf::YieldOp>(ifBodyBlock->getTerminator()))))
473+
return failure();
474+
}
475+
476+
rewriter.eraseOp(condOp);
477+
478+
// Create condition region that loads from the flag variable.
479+
Region &condRegion = loweredDo.getConditionRegion();
480+
Block *condBlock = rewriter.createBlock(&condRegion);
481+
rewriter.setInsertionPointToStart(condBlock);
482+
483+
auto exprOp = rewriter.create<emitc::ExpressionOp>(
484+
loc, i1Type, conditionVal, /*do_not_inline=*/false);
485+
Block *exprBlock = rewriter.createBlock(&exprOp.getBodyRegion());
486+
487+
// Set up the expression block to load the condition variable.
488+
exprBlock->addArgument(conditionVal.getType(), loc);
489+
rewriter.setInsertionPointToStart(exprBlock);
490+
491+
// Load the condition value and yield it as the expression result.
492+
Value cond =
493+
rewriter.create<emitc::LoadOp>(loc, i1Type, exprBlock->getArgument(0));
494+
rewriter.create<emitc::YieldOp>(loc, cond);
495+
496+
// Yield the expression as the condition region result.
497+
rewriter.setInsertionPointToEnd(condBlock);
498+
rewriter.create<emitc::YieldOp>(loc, exprOp);
499+
500+
return success();
501+
}
502+
};
503+
339504
void mlir::populateSCFToEmitCConversionPatterns(RewritePatternSet &patterns,
340505
TypeConverter &typeConverter) {
341506
patterns.add<ForLowering>(typeConverter, patterns.getContext());
342507
patterns.add<IfLowering>(typeConverter, patterns.getContext());
343508
patterns.add<IndexSwitchOpLowering>(typeConverter, patterns.getContext());
509+
patterns.add<WhileLowering>(typeConverter, patterns.getContext());
344510
}
345511

346512
void SCFToEmitCPass::runOnOperation() {
@@ -357,7 +523,8 @@ void SCFToEmitCPass::runOnOperation() {
357523

358524
// Configure conversion to lower out SCF operations.
359525
ConversionTarget target(getContext());
360-
target.addIllegalOp<scf::ForOp, scf::IfOp, scf::IndexSwitchOp>();
526+
target
527+
.addIllegalOp<scf::ForOp, scf::IfOp, scf::IndexSwitchOp, scf::WhileOp>();
361528
target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
362529
if (failed(
363530
applyPartialConversion(getOperation(), target, std::move(patterns))))

0 commit comments

Comments
 (0)