21
21
#include " mlir/IR/PatternMatch.h"
22
22
#include " mlir/Transforms/DialectConversion.h"
23
23
#include " mlir/Transforms/Passes.h"
24
+ #include " llvm/Support/LogicalResult.h"
24
25
25
26
namespace mlir {
26
27
#define GEN_PASS_DEF_SCFTOEMITC
@@ -106,7 +107,7 @@ static void assignValues(ValueRange values, ValueRange variables,
106
107
emitc::AssignOp::create (rewriter, loc, var, value);
107
108
}
108
109
109
- SmallVector<Value> loadValues (const SmallVector <Value> & variables,
110
+ SmallVector<Value> loadValues (ArrayRef <Value> variables,
110
111
PatternRewriter &rewriter, Location loc) {
111
112
return llvm::map_to_vector<>(variables, [&](Value var) {
112
113
Type type = cast<emitc::LValueType>(var.getType ()).getValueType ();
@@ -116,16 +117,15 @@ SmallVector<Value> loadValues(const SmallVector<Value> &variables,
116
117
117
118
static LogicalResult lowerYield (Operation *op, ValueRange resultVariables,
118
119
ConversionPatternRewriter &rewriter,
119
- scf::YieldOp yield) {
120
+ scf::YieldOp yield, bool createYield = true ) {
120
121
Location loc = yield.getLoc ();
121
122
122
123
OpBuilder::InsertionGuard guard (rewriter);
123
124
rewriter.setInsertionPoint (yield);
124
125
125
126
SmallVector<Value> yieldOperands;
126
- if (failed (rewriter.getRemappedValues (yield.getOperands (), yieldOperands))) {
127
+ if (failed (rewriter.getRemappedValues (yield.getOperands (), yieldOperands)))
127
128
return rewriter.notifyMatchFailure (op, " failed to lower yield operands" );
128
- }
129
129
130
130
assignValues (yieldOperands, resultVariables, rewriter, loc);
131
131
@@ -336,11 +336,177 @@ LogicalResult IndexSwitchOpLowering::matchAndRewrite(
336
336
return success ();
337
337
}
338
338
339
+ // Lower scf::while to emitc::do using mutable variables to maintain loop state
340
+ // across iterations. The do-while structure ensures the condition is evaluated
341
+ // after each iteration, matching SCF while semantics.
342
+ struct WhileLowering : public OpConversionPattern <WhileOp> {
343
+ using OpConversionPattern::OpConversionPattern;
344
+
345
+ LogicalResult
346
+ matchAndRewrite (WhileOp whileOp, OpAdaptor adaptor,
347
+ ConversionPatternRewriter &rewriter) const override {
348
+ Location loc = whileOp.getLoc ();
349
+ MLIRContext *context = loc.getContext ();
350
+
351
+ // Create an emitc::variable op for each result. These variables will be
352
+ // assigned to by emitc::assign ops within the loop body.
353
+ SmallVector<Value> resultVariables;
354
+ if (failed (createVariablesForResults (whileOp, getTypeConverter (), rewriter,
355
+ resultVariables)))
356
+ return rewriter.notifyMatchFailure (whileOp,
357
+ " Failed to create result variables" );
358
+
359
+ // Create variable storage for loop-carried values to enable imperative
360
+ // updates while maintaining SSA semantics at conversion boundaries.
361
+ SmallVector<Value> loopVariables;
362
+ if (failed (createVariablesForLoopCarriedValues (
363
+ whileOp, rewriter, loopVariables, loc, context)))
364
+ return failure ();
365
+
366
+ if (failed (lowerDoWhile (whileOp, loopVariables, resultVariables, context,
367
+ rewriter, loc)))
368
+ return failure ();
369
+
370
+ rewriter.setInsertionPointAfter (whileOp);
371
+
372
+ // Load the final result values from result variables.
373
+ SmallVector<Value> finalResults =
374
+ loadValues (resultVariables, rewriter, loc);
375
+ rewriter.replaceOp (whileOp, finalResults);
376
+
377
+ return success ();
378
+ }
379
+
380
+ private:
381
+ // Initialize variables for loop-carried values to enable state updates
382
+ // across iterations without SSA argument passing.
383
+ LogicalResult createVariablesForLoopCarriedValues (
384
+ WhileOp whileOp, ConversionPatternRewriter &rewriter,
385
+ SmallVectorImpl<Value> &loopVars, Location loc,
386
+ MLIRContext *context) const {
387
+ OpBuilder::InsertionGuard guard (rewriter);
388
+ rewriter.setInsertionPoint (whileOp);
389
+
390
+ emitc::OpaqueAttr noInit = emitc::OpaqueAttr::get (context, " " );
391
+
392
+ for (Value init : whileOp.getInits ()) {
393
+ Type convertedType = getTypeConverter ()->convertType (init.getType ());
394
+ if (!convertedType)
395
+ return rewriter.notifyMatchFailure (whileOp, " type conversion failed" );
396
+
397
+ emitc::VariableOp var = rewriter.create <emitc::VariableOp>(
398
+ loc, emitc::LValueType::get (convertedType), noInit);
399
+ rewriter.create <emitc::AssignOp>(loc, var.getResult (), init);
400
+ loopVars.push_back (var);
401
+ }
402
+
403
+ return success ();
404
+ }
405
+
406
+ // Lower scf.while to emitc.do.
407
+ LogicalResult lowerDoWhile (WhileOp whileOp, ArrayRef<Value> loopVars,
408
+ ArrayRef<Value> resultVars, MLIRContext *context,
409
+ ConversionPatternRewriter &rewriter,
410
+ Location loc) const {
411
+ // Create a global boolean variable to store the loop condition state.
412
+ Type i1Type = IntegerType::get (context, 1 );
413
+ auto globalCondition =
414
+ rewriter.create <emitc::VariableOp>(loc, emitc::LValueType::get (i1Type),
415
+ emitc::OpaqueAttr::get (context, " " ));
416
+ Value conditionVal = globalCondition.getResult ();
417
+
418
+ auto loweredDo = rewriter.create <emitc::DoOp>(loc);
419
+
420
+ // Convert region types to match the target dialect type system.
421
+ if (failed (rewriter.convertRegionTypes (&whileOp.getBefore (),
422
+ *getTypeConverter (), nullptr )) ||
423
+ failed (rewriter.convertRegionTypes (&whileOp.getAfter (),
424
+ *getTypeConverter (), nullptr ))) {
425
+ return rewriter.notifyMatchFailure (whileOp,
426
+ " region types conversion failed" );
427
+ }
428
+
429
+ // Prepare the before region (condition evaluation) for merging.
430
+ Block *beforeBlock = &whileOp.getBefore ().front ();
431
+ Block *bodyBlock = rewriter.createBlock (&loweredDo.getBodyRegion ());
432
+ rewriter.setInsertionPointToStart (bodyBlock);
433
+
434
+ // Load current variable values to use as initial arguments for the
435
+ // condition block.
436
+ SmallVector<Value> replacingValues = loadValues (loopVars, rewriter, loc);
437
+ rewriter.mergeBlocks (beforeBlock, bodyBlock, replacingValues);
438
+
439
+ Operation *condTerminator =
440
+ loweredDo.getBodyRegion ().back ().getTerminator ();
441
+ scf::ConditionOp condOp = cast<scf::ConditionOp>(condTerminator);
442
+ rewriter.setInsertionPoint (condOp);
443
+
444
+ // Update result variables with values from scf::condition.
445
+ SmallVector<Value> conditionArgs;
446
+ for (Value arg : condOp.getArgs ()) {
447
+ conditionArgs.push_back (rewriter.getRemappedValue (arg));
448
+ }
449
+ assignValues (conditionArgs, resultVars, rewriter, loc);
450
+
451
+ // Convert scf.condition to condition variable assignment.
452
+ Value condition = rewriter.getRemappedValue (condOp.getCondition ());
453
+ rewriter.create <emitc::AssignOp>(loc, conditionVal, condition);
454
+
455
+ // Wrap body region in conditional to preserve scf semantics. Only create
456
+ // ifOp if after-region is non-empty.
457
+ if (whileOp.getAfterBody ()->getOperations ().size () > 1 ) {
458
+ auto ifOp = rewriter.create <emitc::IfOp>(loc, condition, false , false );
459
+
460
+ // Prepare the after region (loop body) for merging.
461
+ Block *afterBlock = &whileOp.getAfter ().front ();
462
+ Block *ifBodyBlock = rewriter.createBlock (&ifOp.getBodyRegion ());
463
+
464
+ // Replacement values for after block using condition op arguments.
465
+ SmallVector<Value> afterReplacingValues;
466
+ for (Value arg : condOp.getArgs ())
467
+ afterReplacingValues.push_back (rewriter.getRemappedValue (arg));
468
+
469
+ rewriter.mergeBlocks (afterBlock, ifBodyBlock, afterReplacingValues);
470
+
471
+ if (failed (lowerYield (whileOp, loopVars, rewriter,
472
+ cast<scf::YieldOp>(ifBodyBlock->getTerminator ()))))
473
+ return failure ();
474
+ }
475
+
476
+ rewriter.eraseOp (condOp);
477
+
478
+ // Create condition region that loads from the flag variable.
479
+ Region &condRegion = loweredDo.getConditionRegion ();
480
+ Block *condBlock = rewriter.createBlock (&condRegion);
481
+ rewriter.setInsertionPointToStart (condBlock);
482
+
483
+ auto exprOp = rewriter.create <emitc::ExpressionOp>(
484
+ loc, i1Type, conditionVal, /* do_not_inline=*/ false );
485
+ Block *exprBlock = rewriter.createBlock (&exprOp.getBodyRegion ());
486
+
487
+ // Set up the expression block to load the condition variable.
488
+ exprBlock->addArgument (conditionVal.getType (), loc);
489
+ rewriter.setInsertionPointToStart (exprBlock);
490
+
491
+ // Load the condition value and yield it as the expression result.
492
+ Value cond =
493
+ rewriter.create <emitc::LoadOp>(loc, i1Type, exprBlock->getArgument (0 ));
494
+ rewriter.create <emitc::YieldOp>(loc, cond);
495
+
496
+ // Yield the expression as the condition region result.
497
+ rewriter.setInsertionPointToEnd (condBlock);
498
+ rewriter.create <emitc::YieldOp>(loc, exprOp);
499
+
500
+ return success ();
501
+ }
502
+ };
503
+
339
504
void mlir::populateSCFToEmitCConversionPatterns (RewritePatternSet &patterns,
340
505
TypeConverter &typeConverter) {
341
506
patterns.add <ForLowering>(typeConverter, patterns.getContext ());
342
507
patterns.add <IfLowering>(typeConverter, patterns.getContext ());
343
508
patterns.add <IndexSwitchOpLowering>(typeConverter, patterns.getContext ());
509
+ patterns.add <WhileLowering>(typeConverter, patterns.getContext ());
344
510
}
345
511
346
512
void SCFToEmitCPass::runOnOperation () {
@@ -357,7 +523,8 @@ void SCFToEmitCPass::runOnOperation() {
357
523
358
524
// Configure conversion to lower out SCF operations.
359
525
ConversionTarget target (getContext ());
360
- target.addIllegalOp <scf::ForOp, scf::IfOp, scf::IndexSwitchOp>();
526
+ target
527
+ .addIllegalOp <scf::ForOp, scf::IfOp, scf::IndexSwitchOp, scf::WhileOp>();
361
528
target.markUnknownOpDynamicallyLegal ([](Operation *) { return true ; });
362
529
if (failed (
363
530
applyPartialConversion (getOperation (), target, std::move (patterns))))
0 commit comments