Skip to content

Commit 7d31574

Browse files
committed
Yield the `scf.if` condition from the new warp op; index escaping values by result range for the inner warp ops
1 parent b356d11 commit 7d31574

File tree

2 files changed

+23
-20
lines changed

2 files changed

+23
-20
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1794,14 +1794,18 @@ struct WarpOpScfIfOp : public WarpDistributionPattern {
17941794
return failure();
17951795

17961796
// The new `WarpOp` groups its yielded values in the following order:
1797-
// 1. Escaping values then branch
1798-
// 2. Escaping values else branch
1799-
// 3. All non-`ifOp` yielded values.
1800-
SmallVector<Value> newWarpOpYieldValues{escapingValuesThen.begin(),
1801-
escapingValuesThen.end()};
1797+
// 1. Branch condition
1798+
// 2. Escaping values of the then branch
1799+
// 3. Escaping values of the else branch
1800+
// 4. All non-`ifOp` yielded values.
1801+
SmallVector<Value> newWarpOpYieldValues{ifOp.getCondition()};
1802+
newWarpOpYieldValues.append(escapingValuesThen.begin(),
1803+
escapingValuesThen.end());
18021804
newWarpOpYieldValues.append(escapingValuesElse.begin(),
18031805
escapingValuesElse.end());
1804-
SmallVector<Type> newWarpOpDistTypes = escapingValueDistTypesThen;
1806+
SmallVector<Type> newWarpOpDistTypes{ifOp.getCondition().getType()};
1807+
newWarpOpDistTypes.append(escapingValueDistTypesThen.begin(),
1808+
escapingValueDistTypesThen.end());
18051809
newWarpOpDistTypes.append(escapingValueDistTypesElse.begin(),
18061810
escapingValueDistTypesElse.end());
18071811

@@ -1815,7 +1819,6 @@ struct WarpOpScfIfOp : public WarpDistributionPattern {
18151819
// Create the new `WarpOp` with the updated yield values and types.
18161820
WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
18171821
rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes);
1818-
18191822
// `ifOp` returns the result of the inner warp op.
18201823
SmallVector<Type> newIfOpDistResTypes;
18211824
for (auto [i, res] : llvm::enumerate(ifOp.getResults())) {
@@ -1831,23 +1834,24 @@ struct WarpOpScfIfOp : public WarpDistributionPattern {
18311834
// Create a new `IfOp` outside the new `WarpOp` region.
18321835
OpBuilder::InsertionGuard g(rewriter);
18331836
rewriter.setInsertionPointAfter(newWarpOp);
1834-
auto newIfOp = scf::IfOp::create(rewriter, ifOp.getLoc(),
1835-
newIfOpDistResTypes, ifOp.getCondition(),
1836-
static_cast<bool>(ifOp.thenBlock()),
1837-
static_cast<bool>(ifOp.elseBlock()));
1837+
auto newIfOp = scf::IfOp::create(
1838+
rewriter, ifOp.getLoc(), newIfOpDistResTypes, newWarpOp.getResult(0),
1839+
static_cast<bool>(ifOp.thenBlock()),
1840+
static_cast<bool>(ifOp.elseBlock()));
18381841

18391842
auto processBranch = [&](Block *oldIfBranch, Block *newIfBranch,
18401843
llvm::SmallSetVector<Value, 32> &escapingValues,
1841-
SmallVector<Type> &escapingValueInputTypes) {
1844+
SmallVector<Type> &escapingValueInputTypes,
1845+
size_t warpResRangeStart) {
18421846
OpBuilder::InsertionGuard g(rewriter);
18431847
if (!newIfBranch)
18441848
return;
18451849
rewriter.setInsertionPointToStart(newIfBranch);
18461850
llvm::SmallDenseMap<Value, int64_t> escapeValToBlockArgIndex;
18471851
SmallVector<Value> innerWarpInputVals;
18481852
SmallVector<Type> innerWarpInputTypes;
1849-
for (size_t i = 0; i < escapingValues.size(); ++i) {
1850-
innerWarpInputVals.push_back(newWarpOp.getResult(i));
1853+
for (size_t i = 0; i < escapingValues.size(); ++i, ++warpResRangeStart) {
1854+
innerWarpInputVals.push_back(newWarpOp.getResult(warpResRangeStart));
18511855
escapeValToBlockArgIndex[escapingValues[i]] =
18521856
innerWarpInputTypes.size();
18531857
innerWarpInputTypes.push_back(escapingValueInputTypes[i]);
@@ -1886,11 +1890,11 @@ struct WarpOpScfIfOp : public WarpDistributionPattern {
18861890
};
18871891
processBranch(&ifOp.getThenRegion().front(),
18881892
&newIfOp.getThenRegion().front(), escapingValuesThen,
1889-
escapingValueInputTypesThen);
1893+
escapingValueInputTypesThen, 1);
18901894
if (!ifOp.getElseRegion().empty())
18911895
processBranch(&ifOp.getElseRegion().front(),
18921896
&newIfOp.getElseRegion().front(), escapingValuesElse,
1893-
escapingValueInputTypesElse);
1897+
escapingValueInputTypesElse, 1 + escapingValuesThen.size());
18941898
// Update the users of `<- WarpOp.yield <- IfOp.yield` to use the new `IfOp`
18951899
// result.
18961900
for (auto [origIdx, newIdx] : ifResultMapping)

mlir/test/Dialect/XeGPU/subgroup-distribute.mlir

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -339,11 +339,11 @@ gpu.module @test {
339339
}
340340

341341
// -----
342-
// CHECK-LABEL: gpu.func @scatter_ops_scf_yield({{.*}}) {
342+
// CHECK-LABEL: gpu.func @scatter_ops_scf_yield({{.*}},
343+
// CHECK-SAME: %[[PREDICATE:.*]]: i1) {
343344
// CHECK: %[[DEFAULT:.*]] = arith.constant dense<1.200000e+01> : vector<8xf16>
344345
// CHECK: %[[OFFSET:.*]] = arith.constant dense<12> : vector<1xindex>
345346
// CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<1xi1>
346-
// CHECK: %[[PREDICATE:.*]] = llvm.mlir.poison : i1
347347
// CHECK: %[[PREDICATED_LOAD:.*]] = scf.if %[[PREDICATE]] -> (vector<8xf16>) {
348348
// CHECK-NEXT: %[[LOADED:.*]] = xegpu.load %arg0[%[[OFFSET]]], %[[MASK]] <{chunk_size = 8 : i64}> : memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<8xf16>
349349
// CHECK-NEXT: scf.yield %[[LOADED]] : vector<8xf16>
@@ -352,8 +352,7 @@ gpu.module @test {
352352
// CHECK-NEXT: }
353353
// CHECK-NEXT: xegpu.store %[[PREDICATED_LOAD]], %arg0[%[[OFFSET]]], %[[MASK]] <{chunk_size = 8 : i64}> : vector<8xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
354354
gpu.module @test {
355-
gpu.func @scatter_ops_scf_yield(%src: memref<256xf16>) {
356-
%pred = llvm.mlir.poison : i1
355+
gpu.func @scatter_ops_scf_yield(%src: memref<256xf16>, %pred : i1) {
357356
%1 = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<1>: vector<16xi1>
358357
%offset = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<12> : vector<16xindex>
359358
%loaded = scf.if %pred -> (vector<16x8xf16>) {

0 commit comments

Comments
 (0)