Skip to content

Commit 7f007b5

Browse files
authored
[MLIR][Vector] Add warp distribution for scf.if (#157119)
This PR adds `scf.if` op distribution to the existing `VectorDistribute` patterns. The logic mostly follows that of `scf.for`: move the op outside the warp op and wrap each branch with `gpu.warp_execute_on_lane_0`. A notable difference from `scf.for` is that each branch has its own set of escaping values, and `scf.if` itself does not have block arguments.
1 parent 827d775 commit 7f007b5

File tree

3 files changed

+372
-19
lines changed

3 files changed

+372
-19
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp

Lines changed: 246 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -371,6 +371,38 @@ static VectorType getDistributedType(VectorType originalType, AffineMap map,
371371
return targetType;
372372
}
373373

374+
/// Given a warpOp that contains ops with regions, the corresponding op's
375+
/// "inner" region and the distributionMapFn, get all values used by the op's
376+
/// region that are defined within the warpOp, but outside the inner region.
377+
/// Return the set of values, their types and their distributed types.
///
/// The three returned containers are index-aligned: the i-th entry of each
/// type vector describes the i-th value inserted into the set. For a value of
/// non-vector type the distributed type equals the original type; for a value
/// of vector type it is whatever `getDistributedType` returns for the map
/// produced by `distributionMapFn`. Callers in this file check the returned
/// distributed types for a null `Type{}` before using them.
///
/// NOTE(review): this helper is not declared `static`; if it is only used in
/// this file, internal linkage may be preferable -- confirm.
378+
std::tuple<llvm::SmallSetVector<Value, 32>, SmallVector<Type>,
379+
SmallVector<Type>>
380+
getInnerRegionEscapingValues(WarpExecuteOnLane0Op warpOp, Region &innerRegion,
381+
DistributionMapFn distributionMapFn) {
382+
llvm::SmallSetVector<Value, 32> escapingValues;
383+
SmallVector<Type> escapingValueTypes;
384+
SmallVector<Type> escapingValueDistTypes; // to yield from the new warpOp
// An empty region has no uses to visit; return three empty containers.
385+
if (innerRegion.empty())
386+
return {std::move(escapingValues), std::move(escapingValueTypes),
387+
std::move(escapingValueDistTypes)};
388+
mlir::visitUsedValuesDefinedAbove(innerRegion, [&](OpOperand *operand) {
// `parent` is the op owning the region in which the used value is defined.
389+
Operation *parent = operand->get().getParentRegion()->getParentOp();
// Values whose defining region does not belong to `warpOp` (or to an op
// nested in it) are ignored here.
390+
if (warpOp->isAncestor(parent)) {
// A value used several times is recorded once, so that the type vectors
// stay aligned with the set.
391+
if (!escapingValues.insert(operand->get()))
392+
return;
// Non-vector values keep their type; vector values get a distributed type.
393+
Type distType = operand->get().getType();
394+
if (auto vecType = dyn_cast<VectorType>(distType)) {
395+
AffineMap map = distributionMapFn(operand->get());
396+
distType = getDistributedType(vecType, map, warpOp.getWarpSize());
397+
}
398+
escapingValueTypes.push_back(operand->get().getType());
399+
escapingValueDistTypes.push_back(distType);
400+
}
401+
});
402+
return {std::move(escapingValues), std::move(escapingValueTypes),
403+
std::move(escapingValueDistTypes)};
404+
}
405+
374406
/// Distribute transfer_write ops based on the affine map returned by
375407
/// `distributionMapFn`. Writes of size more than `maxNumElementToExtract`
376408
/// will not be distributed (it should be less than the warp size).
@@ -1713,6 +1745,215 @@ struct WarpOpInsert : public WarpDistributionPattern {
17131745
}
17141746
};
17151747

1748+
/// Sink scf.if out of WarpExecuteOnLane0Op. This can be done only if
1749+
/// the scf.if is the last operation in the region so that it doesn't
1750+
/// change the order of execution. This creates a new scf.if after the
1751+
/// WarpExecuteOnLane0Op. Each branch of the new scf.if is enclosed in
1752+
/// its own "inner" WarpExecuteOnLane0Op. Example:
1753+
/// ```
1754+
/// gpu.warp_execute_on_lane_0(%laneid)[32] {
1755+
/// %payload = ... : vector<32xindex>
1756+
/// scf.if %pred {
1757+
/// vector.store %payload, %buffer[%idx] : memref<128xindex>,
1758+
/// vector<32xindex>
1759+
/// }
1760+
/// gpu.yield
1761+
/// }
1762+
/// ```
/// To
/// ```
1763+
/// %r = gpu.warp_execute_on_lane_0(%laneid)[32] {
1764+
/// %payload = ... : vector<32xindex>
1765+
/// gpu.yield %payload : vector<32xindex>
1766+
/// }
1767+
/// scf.if %pred {
1768+
/// gpu.warp_execute_on_lane_0(%laneid)[32] args(%r : vector<1xindex>) {
1769+
/// ^bb0(%arg1: vector<32xindex>):
1770+
/// vector.store %arg1, %buffer[%idx] : memref<128xindex>, vector<32xindex>
1771+
/// }
1772+
/// }
1773+
/// ```
///
/// Values defined in the original warp op and used inside a branch are
/// yielded by the new outer warp op and passed to that branch's inner warp op
/// as arguments. The branch condition is also yielded by the new outer warp
/// op (as its first result) and used as the condition of the new scf.if.
1774+
struct WarpOpScfIfOp : public WarpDistributionPattern {
1775+
WarpOpScfIfOp(MLIRContext *ctx, DistributionMapFn fn, PatternBenefit b = 1)
1776+
: WarpDistributionPattern(ctx, b), distributionMapFn(std::move(fn)) {}
1777+
LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1778+
PatternRewriter &rewriter) const override {
1779+
gpu::YieldOp warpOpYield = warpOp.getTerminator();
1780+
// Only pick up `IfOp` if it is the last op in the region.
1781+
Operation *lastNode = warpOpYield->getPrevNode();
1782+
auto ifOp = dyn_cast_or_null<scf::IfOp>(lastNode);
1783+
if (!ifOp)
1784+
return failure();
1785+
1786+
// The current `WarpOp` can yield two types of values:
1787+
// 1. Not results of `IfOp`:
1788+
// Preserve them in the new `WarpOp`.
1789+
// Collect their yield index to remap the usages.
1790+
// 2. Results of `IfOp`:
1791+
// They are not part of the new `WarpOp` results.
1792+
// Map current warp's yield operand index to `IfOp` result idx.
1793+
SmallVector<Value> nonIfYieldValues;
1794+
SmallVector<unsigned> nonIfYieldIndices;
1795+
llvm::SmallDenseMap<unsigned, unsigned> ifResultMapping;
1796+
llvm::SmallDenseMap<unsigned, VectorType> ifResultDistTypes;
1797+
for (OpOperand &yieldOperand : warpOpYield->getOpOperands()) {
1798+
const unsigned yieldOperandIdx = yieldOperand.getOperandNumber();
1799+
if (yieldOperand.get().getDefiningOp() != ifOp.getOperation()) {
1800+
nonIfYieldValues.push_back(yieldOperand.get());
1801+
nonIfYieldIndices.push_back(yieldOperandIdx);
1802+
continue;
1803+
}
1804+
OpResult ifResult = cast<OpResult>(yieldOperand.get());
1805+
const unsigned ifResultIdx = ifResult.getResultNumber();
1806+
ifResultMapping[yieldOperandIdx] = ifResultIdx;
1807+
// If this `ifOp` result is vector type and it is yielded by the
1808+
// `WarpOp`, we keep track of the distributed type for this result.
1809+
if (!isa<VectorType>(ifResult.getType()))
1810+
continue;
1811+
VectorType distType =
1812+
cast<VectorType>(warpOp.getResult(yieldOperandIdx).getType());
1813+
ifResultDistTypes[ifResultIdx] = distType;
1814+
}
1815+
1816+
// Collect `WarpOp`-defined values used in `ifOp`; the new warp op returns
1817+
// them. Each branch gets its own set of escaping values.
1818+
auto [escapingValuesThen, escapingValueInputTypesThen,
1819+
escapingValueDistTypesThen] =
1820+
getInnerRegionEscapingValues(warpOp, ifOp.getThenRegion(),
1821+
distributionMapFn);
1822+
auto [escapingValuesElse, escapingValueInputTypesElse,
1823+
escapingValueDistTypesElse] =
1824+
getInnerRegionEscapingValues(warpOp, ifOp.getElseRegion(),
1825+
distributionMapFn);
// Bail out (before any IR is modified) if a distributed type is null.
1826+
if (llvm::is_contained(escapingValueDistTypesThen, Type{}) ||
1827+
llvm::is_contained(escapingValueDistTypesElse, Type{}))
1828+
return failure();
1829+
1830+
// The new `WarpOp` yields values grouped in the following order:
1831+
// 1. Branch condition
1832+
// 2. Escaping values then branch
1833+
// 3. Escaping values else branch
1834+
// 4. All non-`ifOp` yielded values.
1835+
SmallVector<Value> newWarpOpYieldValues{ifOp.getCondition()};
1836+
newWarpOpYieldValues.append(escapingValuesThen.begin(),
1837+
escapingValuesThen.end());
1838+
newWarpOpYieldValues.append(escapingValuesElse.begin(),
1839+
escapingValuesElse.end());
1840+
SmallVector<Type> newWarpOpDistTypes{ifOp.getCondition().getType()};
1841+
newWarpOpDistTypes.append(escapingValueDistTypesThen.begin(),
1842+
escapingValueDistTypesThen.end());
1843+
newWarpOpDistTypes.append(escapingValueDistTypesElse.begin(),
1844+
escapingValueDistTypesElse.end());
1845+
1846+
// Maps a yield index of the original `WarpOp` (for non-`ifOp` values) to
// the index at which the same value is yielded by the new `WarpOp`.
llvm::SmallDenseMap<unsigned, unsigned> origToNewYieldIdx;
1847+
for (auto [idx, val] :
1848+
llvm::zip_equal(nonIfYieldIndices, nonIfYieldValues)) {
1849+
origToNewYieldIdx[idx] = newWarpOpYieldValues.size();
1850+
newWarpOpYieldValues.push_back(val);
1851+
newWarpOpDistTypes.push_back(warpOp.getResult(idx).getType());
1852+
}
1853+
// Create the new `WarpOp` with the updated yield values and types.
1854+
WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
1855+
rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes);
1856+
// `ifOp` returns the result of the inner warp op.
1857+
SmallVector<Type> newIfOpDistResTypes;
1858+
for (auto [i, res] : llvm::enumerate(ifOp.getResults())) {
1859+
Type distType = cast<Value>(res).getType();
1860+
if (auto vecType = dyn_cast<VectorType>(distType)) {
// Note: the map is computed even when a recorded type is used below.
1861+
AffineMap map = distributionMapFn(cast<Value>(res));
1862+
// Fall back to the affine map if the distributed type of this result
// was not recorded above (i.e. the result is not yielded by `warpOp`).
// NOTE(review): unlike the escaping-value types, the fallback type is
// not checked against a null `Type{}`, and the new `WarpOp` has already
// been created at this point -- confirm this cannot be null here.
1863+
distType = ifResultDistTypes.count(i)
1864+
? ifResultDistTypes[i]
1865+
: getDistributedType(vecType, map, warpOp.getWarpSize());
1866+
}
1867+
newIfOpDistResTypes.push_back(distType);
1868+
}
1869+
// Create a new `IfOp` outside the new `WarpOp` region.
// Its condition is result 0 of the new `WarpOp`. The two trailing boolean
// arguments are true iff the original op has a then/else block
// (presumably they control whether the builder creates those blocks).
1870+
OpBuilder::InsertionGuard g(rewriter);
1871+
rewriter.setInsertionPointAfter(newWarpOp);
1872+
auto newIfOp = scf::IfOp::create(
1873+
rewriter, ifOp.getLoc(), newIfOpDistResTypes, newWarpOp.getResult(0),
1874+
static_cast<bool>(ifOp.thenBlock()),
1875+
static_cast<bool>(ifOp.elseBlock()));
// Wraps the body of one branch of the original `ifOp` in a new inner warp
// op created at the start of the matching branch of `newIfOp`.
// `warpResRangeStart` is the index of the first `newWarpOp` result that
// holds an escaping value of this branch.
1876+
auto encloseRegionInWarpOp =
1877+
[&](Block *oldIfBranch, Block *newIfBranch,
1878+
llvm::SmallSetVector<Value, 32> &escapingValues,
1879+
SmallVector<Type> &escapingValueInputTypes,
1880+
size_t warpResRangeStart) {
1881+
OpBuilder::InsertionGuard g(rewriter);
1882+
if (!newIfBranch)
1883+
return;
1884+
rewriter.setInsertionPointToStart(newIfBranch);
// Forward this branch's escaping values (now results of `newWarpOp`) to
// the inner warp op and remember which block argument replaces each.
1885+
llvm::SmallDenseMap<Value, int64_t> escapeValToBlockArgIndex;
1886+
SmallVector<Value> innerWarpInputVals;
1887+
SmallVector<Type> innerWarpInputTypes;
1888+
for (size_t i = 0; i < escapingValues.size();
1889+
++i, ++warpResRangeStart) {
1890+
innerWarpInputVals.push_back(
1891+
newWarpOp.getResult(warpResRangeStart));
1892+
escapeValToBlockArgIndex[escapingValues[i]] =
1893+
innerWarpInputTypes.size();
1894+
innerWarpInputTypes.push_back(escapingValueInputTypes[i]);
1895+
}
// The inner warp op has the same result types as the new `IfOp`.
1896+
auto innerWarp = WarpExecuteOnLane0Op::create(
1897+
rewriter, newWarpOp.getLoc(), newIfOp.getResultTypes(),
1898+
newWarpOp.getLaneid(), newWarpOp.getWarpSize(),
1899+
innerWarpInputVals, innerWarpInputTypes);
1900+
1901+
// The inner warp op takes over the body of the original branch's region,
// then gets one block argument per forwarded value. `oldIfBranch` is
// still used after `takeBody`, i.e. the code relies on the block being
// moved rather than copied.
innerWarp.getWarpRegion().takeBody(*oldIfBranch->getParent());
1902+
innerWarp.getWarpRegion().addArguments(
1903+
innerWarpInputTypes,
1904+
SmallVector<Location>(innerWarpInputTypes.size(), ifOp.getLoc()));
1905+
1906+
// Replace the branch's original terminator with a `gpu.yield` of the same
// operands.
SmallVector<Value> yieldOperands;
1907+
for (Value operand : oldIfBranch->getTerminator()->getOperands())
1908+
yieldOperands.push_back(operand);
1909+
rewriter.eraseOp(oldIfBranch->getTerminator());
1910+
1911+
rewriter.setInsertionPointToEnd(innerWarp.getBody());
1912+
gpu::YieldOp::create(rewriter, innerWarp.getLoc(), yieldOperands);
// The new `IfOp` branch yields the results of the inner warp op.
1913+
rewriter.setInsertionPointAfter(innerWarp);
1914+
scf::YieldOp::create(rewriter, ifOp.getLoc(), innerWarp.getResults());
1915+
1916+
// Update any users of escaping values that were forwarded to the
1917+
// inner `WarpOp`. These values are arguments of the inner `WarpOp`.
1918+
innerWarp.walk([&](Operation *op) {
1919+
for (OpOperand &operand : op->getOpOperands()) {
1920+
auto it = escapeValToBlockArgIndex.find(operand.get());
1921+
if (it == escapeValToBlockArgIndex.end())
1922+
continue;
1923+
operand.set(innerWarp.getBodyRegion().getArgument(it->second));
1924+
}
1925+
});
1926+
mlir::vector::moveScalarUniformCode(innerWarp);
1927+
};
// Result 0 of `newWarpOp` is the branch condition, so the then-branch
// escaping values start at index 1 and the else-branch values follow them.
1928+
encloseRegionInWarpOp(&ifOp.getThenRegion().front(),
1929+
&newIfOp.getThenRegion().front(), escapingValuesThen,
1930+
escapingValueInputTypesThen, 1);
1931+
if (!ifOp.getElseRegion().empty())
1932+
encloseRegionInWarpOp(&ifOp.getElseRegion().front(),
1933+
&newIfOp.getElseRegion().front(),
1934+
escapingValuesElse, escapingValueInputTypesElse,
1935+
1 + escapingValuesThen.size());
1936+
// Update the users of `<- WarpOp.yield <- IfOp.yield` to use the new `IfOp`
1937+
// result.
1938+
for (auto [origIdx, newIdx] : ifResultMapping)
1939+
rewriter.replaceAllUsesExcept(warpOp.getResult(origIdx),
1940+
newIfOp.getResult(newIdx), newIfOp);
1941+
// Similarly, update any users of the `WarpOp` results that were not
1942+
// results of the `IfOp`.
1943+
for (auto [origIdx, newIdx] : origToNewYieldIdx)
1944+
rewriter.replaceAllUsesWith(warpOp.getResult(origIdx),
1945+
newWarpOp.getResult(newIdx));
1946+
// Remove the original `WarpOp` and `IfOp`; they should not have any uses
1947+
// at this point.
1948+
rewriter.eraseOp(ifOp);
1949+
rewriter.eraseOp(warpOp);
1950+
return success();
1951+
}
1952+
1953+
private:
1954+
// Callback used to obtain the distribution map of a vector value.
DistributionMapFn distributionMapFn;
1955+
};
1956+
17161957
/// Sink scf.for region out of WarpExecuteOnLane0Op. This can be done only if
17171958
/// the scf.ForOp is the last operation in the region so that it doesn't
17181959
/// change the order of execution. This creates a new scf.for region after the
@@ -1759,25 +2000,9 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
17592000
return failure();
17602001
// Collect Values that come from the `WarpOp` but are outside the `ForOp`.
17612002
// Those Values need to be returned by the new warp op.
1762-
llvm::SmallSetVector<Value, 32> escapingValues;
1763-
SmallVector<Type> escapingValueInputTypes;
1764-
SmallVector<Type> escapingValueDistTypes;
1765-
mlir::visitUsedValuesDefinedAbove(
1766-
forOp.getBodyRegion(), [&](OpOperand *operand) {
1767-
Operation *parent = operand->get().getParentRegion()->getParentOp();
1768-
if (warpOp->isAncestor(parent)) {
1769-
if (!escapingValues.insert(operand->get()))
1770-
return;
1771-
Type distType = operand->get().getType();
1772-
if (auto vecType = dyn_cast<VectorType>(distType)) {
1773-
AffineMap map = distributionMapFn(operand->get());
1774-
distType = getDistributedType(vecType, map, warpOp.getWarpSize());
1775-
}
1776-
escapingValueInputTypes.push_back(operand->get().getType());
1777-
escapingValueDistTypes.push_back(distType);
1778-
}
1779-
});
1780-
2003+
auto [escapingValues, escapingValueInputTypes, escapingValueDistTypes] =
2004+
getInnerRegionEscapingValues(warpOp, forOp.getBodyRegion(),
2005+
distributionMapFn);
17812006
if (llvm::is_contained(escapingValueDistTypes, Type{}))
17822007
return failure();
17832008
// `WarpOp` can yield two types of values:
@@ -2068,6 +2293,8 @@ void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
20682293
benefit);
20692294
patterns.add<WarpOpScfForOp>(patterns.getContext(), distributionMapFn,
20702295
benefit);
2296+
patterns.add<WarpOpScfIfOp>(patterns.getContext(), distributionMapFn,
2297+
benefit);
20712298
}
20722299

20732300
void mlir::vector::populateDistributeReduction(

mlir/test/Dialect/Vector/vector-warp-distribute.mlir

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1856,3 +1856,72 @@ func.func @negative_warp_step_more_than_warp_size(%laneid: index, %buffer: memre
18561856
// CHECK-PROP-LABEL: @negative_warp_step_more_than_warp_size
18571857
// CHECK-PROP-NOT: vector.broadcast
18581858
// CHECK-PROP: vector.step : vector<64xindex>
1859+
1860+
// -----
1861+
1862+
// Test input: an scf.if without results is the last op before gpu.yield in
// the warp region, and its body stores a vector (%seq) that is defined in the
// enclosing warp region. The checks that follow expect the scf.if to end up
// outside, wrapping a warp op that receives the vector as an argument.
func.func @warp_scf_if_no_yield_distribute(%buffer: memref<128xindex>, %pred : i1) {
1863+
%laneid = gpu.lane_id
1864+
%c0 = arith.constant 0 : index
1865+
1866+
gpu.warp_execute_on_lane_0(%laneid)[32] {
1867+
%seq = vector.step : vector<32xindex>
1868+
scf.if %pred {
1869+
vector.store %seq, %buffer[%c0] : memref<128xindex>, vector<32xindex>
1870+
}
1871+
gpu.yield
1872+
}
1873+
return
1874+
}
1875+
1876+
// CHECK-PROP-LABEL: func.func @warp_scf_if_no_yield_distribute(
1877+
// CHECK-PROP-SAME: %[[ARG0:.+]]: memref<128xindex>, %[[ARG1:.+]]: i1
1878+
// CHECK-PROP: scf.if %[[ARG1]] {
1879+
// CHECK-PROP: gpu.warp_execute_on_lane_0(%{{.*}})[32] args(%{{.*}} : vector<1xindex>) {
1880+
// CHECK-PROP: ^bb0(%[[ARG2:.+]]: vector<32xindex>):
1881+
// CHECK-PROP: vector.store %[[ARG2]], %[[ARG0]][%{{.*}}] : memref<128xindex>, vector<32xindex>
1882+
1883+
// -----
1884+
1885+
// Test input: an scf.if with a vector result that is yielded by the warp op.
// The then and else branches each use a different vector defined in the
// enclosing warp region (%seq1 and %seq2 respectively), and the warp op
// result is consumed as vector<1xf32> by "some_use".
func.func @warp_scf_if_distribute(%pred : i1) {
1886+
%laneid = gpu.lane_id
1887+
%c0 = arith.constant 0 : index
1888+
1889+
%0 = gpu.warp_execute_on_lane_0(%laneid)[32] -> vector<1xf32> {
1890+
%seq1 = vector.step : vector<32xindex>
1891+
%seq2 = arith.constant dense<2> : vector<32xindex>
1892+
%0 = scf.if %pred -> (vector<32xf32>) {
1893+
%1 = "some_op"(%seq1) : (vector<32xindex>) -> (vector<32xf32>)
1894+
scf.yield %1 : vector<32xf32>
1895+
} else {
1896+
%2 = "other_op"(%seq2) : (vector<32xindex>) -> (vector<32xf32>)
1897+
scf.yield %2 : vector<32xf32>
1898+
}
1899+
gpu.yield %0 : vector<32xf32>
1900+
}
1901+
"some_use"(%0) : (vector<1xf32>) -> ()
1902+
1903+
return
1904+
}
1905+
1906+
// CHECK-PROP-LABEL: func.func @warp_scf_if_distribute(
1907+
// CHECK-PROP-SAME: %[[ARG0:.+]]: i1
1908+
// CHECK-PROP: %[[SEQ2:.+]] = arith.constant dense<2> : vector<32xindex>
1909+
// CHECK-PROP: %[[LANE_ID:.+]] = gpu.lane_id
1910+
// CHECK-PROP: %[[SEQ1:.+]] = vector.broadcast %[[LANE_ID]] : index to vector<1xindex>
1911+
// CHECK-PROP: %[[IF_YIELD_DIST:.+]] = scf.if %[[ARG0]] -> (vector<1xf32>) {
1912+
// CHECK-PROP: %[[THEN_DIST:.+]] = gpu.warp_execute_on_lane_0(%[[LANE_ID]])[32] args(%[[SEQ1]] : vector<1xindex>) -> (vector<1xf32>) {
1913+
// CHECK-PROP: ^bb0(%[[ARG1:.+]]: vector<32xindex>):
1914+
// CHECK-PROP: %{{.*}} = "some_op"(%[[ARG1]]) : (vector<32xindex>) -> vector<32xf32>
1915+
// CHECK-PROP: gpu.yield %{{.*}} : vector<32xf32>
1916+
// CHECK-PROP: }
1917+
// CHECK-PROP: scf.yield %[[THEN_DIST]] : vector<1xf32>
1918+
// CHECK-PROP: } else {
1919+
// CHECK-PROP: %[[ELSE_DIST:.+]] = gpu.warp_execute_on_lane_0(%[[LANE_ID]])[32] -> (vector<1xf32>) {
1920+
// CHECK-PROP: %{{.*}} = "other_op"(%[[SEQ2]]) : (vector<32xindex>) -> vector<32xf32>
1921+
// CHECK-PROP: gpu.yield %{{.*}} : vector<32xf32>
1922+
// CHECK-PROP: }
1923+
// CHECK-PROP: scf.yield %[[ELSE_DIST]] : vector<1xf32>
1924+
// CHECK-PROP: }
1925+
// CHECK-PROP: "some_use"(%[[IF_YIELD_DIST]]) : (vector<1xf32>) -> ()
1926+
// CHECK-PROP: return
1927+
// CHECK-PROP: }

0 commit comments

Comments
 (0)