@@ -1327,6 +1327,15 @@ class RewriteRegionBranchOp
1327
1327
llvm::SmallVector<RegionSuccessor> successors;
1328
1328
iface.getSuccessorRegions (RegionBranchPoint::parent (), successors);
1329
1329
1330
+ // SCF::WhileOp has two regions, named before and after respectively.
1331
+ // For parent point, it returns the before region,
1332
+ if (auto whileOp = dyn_cast_or_null<scf::WhileOp>(op)) {
1333
+ iface.getSuccessorRegions (whileOp.getBefore (), successors);
1334
+ }
1335
+
1336
+ if (successors.size () == 0 )
1337
+ return failure ();
1338
+
1330
1339
// the region iter arguments will be used as the anchor if it is a loop,
1331
1340
// otherwise, the op results will be used as the anchor.
1332
1341
// TODO: is it safe to assume that first is always the entry successor?
@@ -1353,62 +1362,77 @@ class RewriteRegionBranchOp
1353
1362
auto defaultIP = rewriter.saveInsertionPoint ();
1354
1363
PatternRewriter::InsertionGuard g (rewriter);
1355
1364
1356
- for (auto s : successors) { // convert the terminator
1365
+ // convert the terminators and arguments of each region
1366
+ for (auto [i, s] : llvm::enumerate (successors)) {
1357
1367
if (s.isParent ())
1358
1368
continue ;
1369
+
1359
1370
Region *r = s.getSuccessor ();
1360
- auto terminator = r->front ().getTerminator ();
1361
- llvm::SmallVector<Value> convertedOperands;
1362
- rewriter.setInsertionPoint (terminator);
1363
- convertOperandsOrResults (
1364
- terminator->getOperands (), blockSZs,
1365
- [&](int64_t i, Value v, ShapedType type,
1366
- llvm::ArrayRef<int64_t > blockSZ) {
1367
- auto newTypes = convertTypes (type, blockSZ, arrayLengthAttrs[i]);
1368
- auto newOprs = addPackOp (v, newTypes, blockSZ, loc, rewriter);
1369
- convertedOperands.append (newOprs.begin (), newOprs.end ());
1370
- },
1371
- [&](int64_t i, Value v) { convertedOperands.push_back (v); });
1372
1371
1373
- terminator->setOperands (convertedOperands);
1372
+ { // convert the terminator
1373
+ auto terminator = r->front ().getTerminator ();
1374
+ rewriter.setInsertionPoint (terminator);
1375
+
1376
+ llvm::SmallVector<Value> convertedOperands;
1377
+ auto operands = terminator->getOpOperands ();
1378
+ // the condition operand of ConditionOp needs no conversions
1379
+ if (isa<scf::ConditionOp>(terminator)) {
1380
+ convertedOperands.push_back (operands[0 ].get ());
1381
+ operands = operands.drop_front ();
1382
+ }
1383
+
1384
+ convertOperandsOrResults (
1385
+ OperandRange (operands.data (), operands.size ()), blockSZs,
1386
+ [&](int64_t i, Value v, ShapedType type,
1387
+ llvm::ArrayRef<int64_t > blockSZ) {
1388
+ auto newTypes = convertTypes (type, blockSZ, arrayLengthAttrs[i]);
1389
+ auto newOprs = addPackOp (v, newTypes, blockSZ, loc, rewriter);
1390
+ convertedOperands.append (newOprs.begin (), newOprs.end ());
1391
+ },
1392
+ [&](int64_t i, Value v) { convertedOperands.push_back (v); });
1393
+
1394
+ terminator->setOperands (convertedOperands);
1395
+ } // end of convert the terminator
1396
+
1397
+ { // convert the region arguments for loops
1398
+ if (iface.hasLoop ()) {
1399
+ rewriter.setInsertionPointToStart (&r->front ());
1400
+ auto arguments = llvm::to_vector (s.getSuccessorInputs ());
1401
+ convertOperandsOrResults (
1402
+ llvm::ArrayRef<Value>(arguments), blockSZs,
1403
+ [&](int64_t i, Value arg, ShapedType type,
1404
+ llvm::ArrayRef<int64_t > blockSZ) {
1405
+ auto newTypes =
1406
+ convertTypes (type, blockSZ, arrayLengthAttrs[i]);
1407
+ llvm::SmallVector<Location> locs (newTypes.size (), arg.getLoc ());
1408
+ llvm::SmallVector<Value> newArgs;
1409
+ llvm::for_each (r->addArguments (newTypes, locs),
1410
+ [&](BlockArgument b) { newArgs.push_back (b); });
1411
+ auto cast = addUnpackOp (newArgs, type, blockSZ, loc, rewriter);
1412
+ arg.replaceAllUsesWith (cast);
1413
+ },
1414
+ [&](int64_t i, Value arg) {
1415
+ auto newArg = r->addArgument (arg.getType (), arg.getLoc ());
1416
+ arg.replaceAllUsesWith (newArg);
1417
+ });
1418
+
1419
+ // cleanup the old arguments, it has to done in reverse order
1420
+ for (auto v : llvm::reverse (arguments)) {
1421
+ auto arg = dyn_cast<BlockArgument>(v);
1422
+ if (arg && arg.use_empty ())
1423
+ r->eraseArgument (arg.getArgNumber ());
1424
+ }
1425
+ } // end of iface.hasLoop()
1426
+ } // end of convert the region arguments
1374
1427
}
1375
1428
1376
- // convert BlockArguments and Inits if it is a loop, otherwise original
1377
- // inputs will used
1429
+ // convert BlockArguments and Inits if it is a loop,
1430
+ // otherwise original inputs will used
1378
1431
llvm::SmallVector<Value> convertedOperands (op->getOperands ());
1379
- if (iface.hasLoop ()) {
1380
- RegionSuccessor s = successors[0 ];
1381
- Region *r = s.getSuccessor ();
1382
- rewriter.setInsertionPointToStart (&r->front ());
1383
- auto arguments = llvm::to_vector (s.getSuccessorInputs ());
1384
- convertOperandsOrResults (
1385
- llvm::ArrayRef<Value>(arguments), blockSZs,
1386
- [&](int64_t i, Value arg, ShapedType type,
1387
- llvm::ArrayRef<int64_t > blockSZ) {
1388
- auto newTypes = convertTypes (type, blockSZ, arrayLengthAttrs[i]);
1389
- llvm::SmallVector<Location> locs (newTypes.size (), arg.getLoc ());
1390
- llvm::SmallVector<Value> newArgs;
1391
- llvm::for_each (r->addArguments (newTypes, locs),
1392
- [&](BlockArgument b) { newArgs.push_back (b); });
1393
- auto cast = addUnpackOp (newArgs, type, blockSZ, loc, rewriter);
1394
- arg.replaceAllUsesWith (cast);
1395
- },
1396
- [&](int64_t i, Value arg) {
1397
- auto newArg = r->addArgument (arg.getType (), arg.getLoc ());
1398
- arg.replaceAllUsesWith (newArg);
1399
- });
1400
-
1401
- // cleanup the old arguments, it has to done in reverse order
1402
- for (auto v : llvm::reverse (arguments)) {
1403
- auto arg = dyn_cast<BlockArgument>(v);
1404
- if (arg && arg.use_empty ())
1405
- r->eraseArgument (arg.getArgNumber ());
1406
- }
1407
-
1408
- // convert the Inits
1432
+ if (auto loop = dyn_cast_or_null<LoopLikeOpInterface>(op)) {
1409
1433
rewriter.setInsertionPoint (op);
1434
+ auto inits = loop.getInits ();
1410
1435
1411
- auto inits = iface.getEntrySuccessorOperands (s);
1412
1436
convertedOperands.pop_back_n (inits.size ());
1413
1437
convertOperandsOrResults (
1414
1438
inits, blockSZs,
0 commit comments