Skip to content

Commit a74c3a4

Browse files
authored
[Blocking] Refine RegionBranchOp pattern for WhileOp (#1065)
1 parent 299f2d2 commit a74c3a4

File tree

3 files changed

+120
-48
lines changed

3 files changed

+120
-48
lines changed

lib/Dialect/XeTile/Transforms/Blocking.cpp

Lines changed: 71 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1327,6 +1327,15 @@ class RewriteRegionBranchOp
13271327
llvm::SmallVector<RegionSuccessor> successors;
13281328
iface.getSuccessorRegions(RegionBranchPoint::parent(), successors);
13291329

1330+
// SCF::WhileOp has two regions, named before and after respectively.
1331+
// For the parent point, it returns the before region, so the successors of the before region are collected as well.
1332+
if (auto whileOp = dyn_cast_or_null<scf::WhileOp>(op)) {
1333+
iface.getSuccessorRegions(whileOp.getBefore(), successors);
1334+
}
1335+
1336+
if (successors.size() == 0)
1337+
return failure();
1338+
13301339
// the region iter arguments will be used as the anchor if it is a loop,
13311340
// otherwise, the op results will be used as the anchor.
13321341
// TODO: is it safe to assume that first is always the entry successor?
@@ -1353,62 +1362,77 @@ class RewriteRegionBranchOp
13531362
auto defaultIP = rewriter.saveInsertionPoint();
13541363
PatternRewriter::InsertionGuard g(rewriter);
13551364

1356-
for (auto s : successors) { // convert the terminator
1365+
// convert the terminators and arguments of each region
1366+
for (auto [i, s] : llvm::enumerate(successors)) {
13571367
if (s.isParent())
13581368
continue;
1369+
13591370
Region *r = s.getSuccessor();
1360-
auto terminator = r->front().getTerminator();
1361-
llvm::SmallVector<Value> convertedOperands;
1362-
rewriter.setInsertionPoint(terminator);
1363-
convertOperandsOrResults(
1364-
terminator->getOperands(), blockSZs,
1365-
[&](int64_t i, Value v, ShapedType type,
1366-
llvm::ArrayRef<int64_t> blockSZ) {
1367-
auto newTypes = convertTypes(type, blockSZ, arrayLengthAttrs[i]);
1368-
auto newOprs = addPackOp(v, newTypes, blockSZ, loc, rewriter);
1369-
convertedOperands.append(newOprs.begin(), newOprs.end());
1370-
},
1371-
[&](int64_t i, Value v) { convertedOperands.push_back(v); });
13721371

1373-
terminator->setOperands(convertedOperands);
1372+
{ // convert the terminator
1373+
auto terminator = r->front().getTerminator();
1374+
rewriter.setInsertionPoint(terminator);
1375+
1376+
llvm::SmallVector<Value> convertedOperands;
1377+
auto operands = terminator->getOpOperands();
1378+
// the condition operand of ConditionOp needs no conversions
1379+
if (isa<scf::ConditionOp>(terminator)) {
1380+
convertedOperands.push_back(operands[0].get());
1381+
operands = operands.drop_front();
1382+
}
1383+
1384+
convertOperandsOrResults(
1385+
OperandRange(operands.data(), operands.size()), blockSZs,
1386+
[&](int64_t i, Value v, ShapedType type,
1387+
llvm::ArrayRef<int64_t> blockSZ) {
1388+
auto newTypes = convertTypes(type, blockSZ, arrayLengthAttrs[i]);
1389+
auto newOprs = addPackOp(v, newTypes, blockSZ, loc, rewriter);
1390+
convertedOperands.append(newOprs.begin(), newOprs.end());
1391+
},
1392+
[&](int64_t i, Value v) { convertedOperands.push_back(v); });
1393+
1394+
terminator->setOperands(convertedOperands);
1395+
} // end of convert the terminator
1396+
1397+
{ // convert the region arguments for loops
1398+
if (iface.hasLoop()) {
1399+
rewriter.setInsertionPointToStart(&r->front());
1400+
auto arguments = llvm::to_vector(s.getSuccessorInputs());
1401+
convertOperandsOrResults(
1402+
llvm::ArrayRef<Value>(arguments), blockSZs,
1403+
[&](int64_t i, Value arg, ShapedType type,
1404+
llvm::ArrayRef<int64_t> blockSZ) {
1405+
auto newTypes =
1406+
convertTypes(type, blockSZ, arrayLengthAttrs[i]);
1407+
llvm::SmallVector<Location> locs(newTypes.size(), arg.getLoc());
1408+
llvm::SmallVector<Value> newArgs;
1409+
llvm::for_each(r->addArguments(newTypes, locs),
1410+
[&](BlockArgument b) { newArgs.push_back(b); });
1411+
auto cast = addUnpackOp(newArgs, type, blockSZ, loc, rewriter);
1412+
arg.replaceAllUsesWith(cast);
1413+
},
1414+
[&](int64_t i, Value arg) {
1415+
auto newArg = r->addArgument(arg.getType(), arg.getLoc());
1416+
arg.replaceAllUsesWith(newArg);
1417+
});
1418+
1419+
// clean up the old arguments; it has to be done in reverse order
1420+
for (auto v : llvm::reverse(arguments)) {
1421+
auto arg = dyn_cast<BlockArgument>(v);
1422+
if (arg && arg.use_empty())
1423+
r->eraseArgument(arg.getArgNumber());
1424+
}
1425+
} // end of iface.hasLoop()
1426+
} // end of convert the region arguments
13741427
}
13751428

1376-
// convert BlockArguments and Inits if it is a loop, otherwise original
1377-
// inputs will used
1429+
// convert BlockArguments and Inits if it is a loop,
1430+
// otherwise the original inputs will be used
13781431
llvm::SmallVector<Value> convertedOperands(op->getOperands());
1379-
if (iface.hasLoop()) {
1380-
RegionSuccessor s = successors[0];
1381-
Region *r = s.getSuccessor();
1382-
rewriter.setInsertionPointToStart(&r->front());
1383-
auto arguments = llvm::to_vector(s.getSuccessorInputs());
1384-
convertOperandsOrResults(
1385-
llvm::ArrayRef<Value>(arguments), blockSZs,
1386-
[&](int64_t i, Value arg, ShapedType type,
1387-
llvm::ArrayRef<int64_t> blockSZ) {
1388-
auto newTypes = convertTypes(type, blockSZ, arrayLengthAttrs[i]);
1389-
llvm::SmallVector<Location> locs(newTypes.size(), arg.getLoc());
1390-
llvm::SmallVector<Value> newArgs;
1391-
llvm::for_each(r->addArguments(newTypes, locs),
1392-
[&](BlockArgument b) { newArgs.push_back(b); });
1393-
auto cast = addUnpackOp(newArgs, type, blockSZ, loc, rewriter);
1394-
arg.replaceAllUsesWith(cast);
1395-
},
1396-
[&](int64_t i, Value arg) {
1397-
auto newArg = r->addArgument(arg.getType(), arg.getLoc());
1398-
arg.replaceAllUsesWith(newArg);
1399-
});
1400-
1401-
// cleanup the old arguments, it has to done in reverse order
1402-
for (auto v : llvm::reverse(arguments)) {
1403-
auto arg = dyn_cast<BlockArgument>(v);
1404-
if (arg && arg.use_empty())
1405-
r->eraseArgument(arg.getArgNumber());
1406-
}
1407-
1408-
// convert the Inits
1432+
if (auto loop = dyn_cast_or_null<LoopLikeOpInterface>(op)) {
14091433
rewriter.setInsertionPoint(op);
1434+
auto inits = loop.getInits();
14101435

1411-
auto inits = iface.getEntrySuccessorOperands(s);
14121436
convertedOperands.pop_back_n(inits.size());
14131437
convertOperandsOrResults(
14141438
inits, blockSZs,

lib/Dialect/XeTile/Transforms/BlockingAnalysis.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -996,6 +996,17 @@ void BlockingAnalysis::printAnalysisResult() {
996996
<< ", arrayLen: " << getArrayLength(inputOpr);
997997
}
998998
llvm::dbgs() << "\n";
999+
} else if (auto WhileOp = mlir::dyn_cast<mlir::scf::WhileOp>(op)) {
1000+
llvm::dbgs() << "\nOp: " << op->getName();
1001+
for (auto [i, arg] : llvm::enumerate(WhileOp.getBefore().getArguments()))
1002+
llvm::dbgs() << "\n before arg[" << i << "]: "
1003+
<< " --> blkSZ: " << getDefBlockSize(arg)
1004+
<< ", arrayLen: " << getArrayLength(arg);
1005+
for (auto [i, arg] : llvm::enumerate(WhileOp.getAfter().getArguments()))
1006+
llvm::dbgs() << "\n after arg[" << i << "]: "
1007+
<< " --> blkSZ: " << getDefBlockSize(arg)
1008+
<< ", arrayLen: " << getArrayLength(arg);
1009+
llvm::dbgs() << "\n";
9991010
}
10001011
});
10011012
}

test/Dialect/XeTile/Transforms/Blocking/unit_tests.mlir

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1665,5 +1665,42 @@ gpu.module @test_kernel {
16651665
gpu.return
16661666
}
16671667

1668-
1668+
//-----
1669+
gpu.func @while_loop_kernel(%arg0: memref<*xf32>, %arg1: memref<*xf32>, %arg2: memref<*xf32>, %arg3: i32) kernel attributes {VectorComputeFunctionINTEL, known_block_size = array<i32: 32, 1, 1>, known_grid_size = array<i32: 1, 1, 1>, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
1670+
%c0_i32 = arith.constant 0 : i32
1671+
%c1_i32 = arith.constant 1 : i32
1672+
%c10_i32 = arith.constant 10 : i32
1673+
%block_id_x = gpu.block_id x
1674+
%thread_id_x = gpu.thread_id x
1675+
%cast = memref.cast %arg2 : memref<*xf32> to memref<?xf32>
1676+
%cast_0 = memref.cast %arg1 : memref<*xf32> to memref<?xf32>
1677+
%cast_1 = memref.cast %arg0 : memref<*xf32> to memref<?xf32>
1678+
%index = arith.addi %block_id_x, %thread_id_x : index
1679+
%tile_indices = vector.splat %index : vector<1x256xindex>
1680+
%tile_indices_i32 = arith.index_cast %tile_indices : vector<1x256xindex> to vector<1x256xi32>
1681+
%arg3_splat = vector.splat %arg3 : vector<1x256xi32>
1682+
%mask = arith.cmpi slt, %tile_indices_i32, %arg3_splat : vector<1x256xi32>
1683+
%tile = xetile.init_tile %cast_1, %tile_indices : memref<?xf32>, vector<1x256xindex> -> !xetile.tile<1x256xf32, #xetile.tile_attr<memory_space = 0 : i32, scattered = true>>
1684+
%load1 = xetile.load %tile, %mask : !xetile.tile<1x256xf32, #xetile.tile_attr<memory_space = 0 : i32, scattered = true>>, vector<1x256xi1> -> vector<1x256xf32>
1685+
%tile_0 = xetile.init_tile %cast_0, %tile_indices : memref<?xf32>, vector<1x256xindex> -> !xetile.tile<1x256xf32, #xetile.tile_attr<memory_space = 0 : i32, scattered = true>>
1686+
%load2 = xetile.load %tile_0, %mask : !xetile.tile<1x256xf32, #xetile.tile_attr<memory_space = 0 : i32, scattered = true>>, vector<1x256xi1> -> vector<1x256xf32>
1687+
%sum = arith.addf %load1, %load2 : vector<1x256xf32>
1688+
%result:2 = scf.while (%arg4 = %sum, %arg5 = %c0_i32) : (vector<1x256xf32>, i32) -> (vector<1x256xf32>, i32) {
1689+
%cond = arith.cmpi slt, %arg5, %c10_i32 : i32
1690+
scf.condition(%cond) %arg4, %arg5 : vector<1x256xf32>, i32
1691+
} do {
1692+
^bb0(%arg4: vector<1x256xf32>, %arg5: i32):
1693+
//CHECK-COUNT-16: arith.addf {{.*}} : vector<1x16xf32>
1694+
%new_sum = arith.addf %arg4, %load1 : vector<1x256xf32>
1695+
//CHECK-COUNT-16: arith.addf {{.*}} : vector<1x16xf32>
1696+
%new_sum2 = arith.addf %new_sum, %load2 : vector<1x256xf32>
1697+
%new_iter = arith.addi %arg5, %c1_i32 : i32
1698+
scf.yield %new_sum2, %new_iter : vector<1x256xf32>, i32
1699+
}
1700+
//CHECK-COUNT-16: xetile.init_tile {{.*}} : memref<?xf32>, vector<1x16xindex> -> !xetile.tile<1x16xf32, #xetile.tile_attr<memory_space = 0 : i32, scattered = true>>
1701+
%tile_out = xetile.init_tile %cast, %tile_indices : memref<?xf32>, vector<1x256xindex> -> !xetile.tile<1x256xf32, #xetile.tile_attr<memory_space = 0 : i32, scattered = true>>
1702+
//CHECK-COUNT-16: xetile.store {{.*}} : vector<1x16xf32>, !xetile.tile<1x16xf32, #xetile.tile_attr<memory_space = 0 : i32, scattered = true>>, vector<1x16xi1>
1703+
xetile.store %result#0, %tile_out, %mask : vector<1x256xf32>, !xetile.tile<1x256xf32, #xetile.tile_attr<memory_space = 0 : i32, scattered = true>>, vector<1x256xi1>
1704+
gpu.return
1705+
}
16691706
}

0 commit comments

Comments
 (0)