Skip to content

Commit 4b202de

Browse files
committed
[MLIR][Vector] Add warp distribution for scf.if
1 parent dc85d0c commit 4b202de

File tree

2 files changed

+270
-0
lines changed

2 files changed

+270
-0
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp

Lines changed: 201 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1713,6 +1713,205 @@ struct WarpOpInsert : public WarpDistributionPattern {
17131713
}
17141714
};
17151715

1716+
struct WarpOpScfIfOp : public WarpDistributionPattern {
1717+
WarpOpScfIfOp(MLIRContext *ctx, DistributionMapFn fn, PatternBenefit b = 1)
1718+
: WarpDistributionPattern(ctx, b), distributionMapFn(std::move(fn)) {}
1719+
LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1720+
PatternRewriter &rewriter) const override {
1721+
gpu::YieldOp warpOpYield = warpOp.getTerminator();
1722+
// Only pick up `IfOp` if it is the last op in the region.
1723+
Operation *lastNode = warpOpYield->getPrevNode();
1724+
auto ifOp = dyn_cast_or_null<scf::IfOp>(lastNode);
1725+
if (!ifOp)
1726+
return failure();
1727+
1728+
// The current `WarpOp` can yield two types of values:
1729+
// 1. Not results of `IfOp`:
1730+
// Preserve them in the new `WarpOp`.
1731+
// Collect their yield index.
1732+
// 2. Results of `IfOp`:
1733+
// They are not part of the new `WarpOp` results.
1734+
// Map current warp's yield operand index to `IfOp` result idx.
1735+
SmallVector<Value> nonIfYieldValues;
1736+
SmallVector<unsigned> nonIfYieldIndices;
1737+
llvm::SmallDenseMap<unsigned, unsigned> ifResultMapping;
1738+
llvm::SmallDenseMap<unsigned, VectorType> ifResultDistTypes;
1739+
for (OpOperand &yieldOperand : warpOpYield->getOpOperands()) {
1740+
const unsigned yieldOperandIdx = yieldOperand.getOperandNumber();
1741+
if (yieldOperand.get().getDefiningOp() != ifOp.getOperation()) {
1742+
nonIfYieldValues.push_back(yieldOperand.get());
1743+
nonIfYieldIndices.push_back(yieldOperandIdx);
1744+
continue;
1745+
}
1746+
OpResult ifResult = cast<OpResult>(yieldOperand.get());
1747+
const unsigned ifResultIdx = ifResult.getResultNumber();
1748+
ifResultMapping[yieldOperandIdx] = ifResultIdx;
1749+
// If this `ifOp` result is vector type and it is yielded by the
1750+
// `WarpOp`, we keep track the distributed type for this result.
1751+
if (!isa<VectorType>(ifResult.getType()))
1752+
continue;
1753+
VectorType distType =
1754+
cast<VectorType>(warpOp.getResult(yieldOperandIdx).getType());
1755+
ifResultDistTypes[ifResultIdx] = distType;
1756+
}
1757+
1758+
// Collect `WarpOp`-defined values used in `ifOp`, the new warp op returns
1759+
// them
1760+
auto getEscapingValues = [&](Region &branch,
1761+
llvm::SmallSetVector<Value, 32> &values,
1762+
SmallVector<Type> &inputTypes,
1763+
SmallVector<Type> &distTypes) {
1764+
if (branch.empty())
1765+
return;
1766+
mlir::visitUsedValuesDefinedAbove(branch, [&](OpOperand *operand) {
1767+
Operation *parent = operand->get().getParentRegion()->getParentOp();
1768+
if (warpOp->isAncestor(parent)) {
1769+
if (!values.insert(operand->get()))
1770+
return;
1771+
Type distType = operand->get().getType();
1772+
if (auto vecType = dyn_cast<VectorType>(distType)) {
1773+
AffineMap map = distributionMapFn(operand->get());
1774+
distType = getDistributedType(vecType, map, warpOp.getWarpSize());
1775+
}
1776+
inputTypes.push_back(operand->get().getType());
1777+
distTypes.push_back(distType);
1778+
}
1779+
});
1780+
};
1781+
llvm::SmallSetVector<Value, 32> escapingValuesThen;
1782+
SmallVector<Type> escapingValueInputTypesThen; // inner warp op block args
1783+
SmallVector<Type> escapingValueDistTypesThen; // new warp returns
1784+
getEscapingValues(ifOp.getThenRegion(), escapingValuesThen,
1785+
escapingValueInputTypesThen, escapingValueDistTypesThen);
1786+
llvm::SmallSetVector<Value, 32> escapingValuesElse;
1787+
SmallVector<Type> escapingValueInputTypesElse; // inner warp op block args
1788+
SmallVector<Type> escapingValueDistTypesElse; // new warp returns
1789+
getEscapingValues(ifOp.getElseRegion(), escapingValuesElse,
1790+
escapingValueInputTypesElse, escapingValueDistTypesElse);
1791+
1792+
if (llvm::is_contained(escapingValueDistTypesThen, Type{}) ||
1793+
llvm::is_contained(escapingValueDistTypesElse, Type{}))
1794+
return failure();
1795+
1796+
// The new `WarpOp` groups yields values in following order:
1797+
// 1. Escaping values then branch
1798+
// 2. Escaping values else branch
1799+
// 3. All non-`ifOp` yielded values.
1800+
SmallVector<Value> newWarpOpYieldValues{escapingValuesThen.begin(),
1801+
escapingValuesThen.end()};
1802+
newWarpOpYieldValues.append(escapingValuesElse.begin(),
1803+
escapingValuesElse.end());
1804+
SmallVector<Type> newWarpOpDistTypes = escapingValueDistTypesThen;
1805+
newWarpOpDistTypes.append(escapingValueDistTypesElse.begin(),
1806+
escapingValueDistTypesElse.end());
1807+
1808+
llvm::SmallDenseMap<unsigned, unsigned> origToNewYieldIdx;
1809+
for (auto [idx, val] :
1810+
llvm::zip_equal(nonIfYieldIndices, nonIfYieldValues)) {
1811+
origToNewYieldIdx[idx] = newWarpOpYieldValues.size();
1812+
newWarpOpYieldValues.push_back(val);
1813+
newWarpOpDistTypes.push_back(warpOp.getResult(idx).getType());
1814+
}
1815+
// Create the new `WarpOp` with the updated yield values and types.
1816+
WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
1817+
rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes);
1818+
1819+
// `ifOp` returns the result of the inner warp op.
1820+
SmallVector<Type> newIfOpDistResTypes;
1821+
for (auto [i, res] : llvm::enumerate(ifOp.getResults())) {
1822+
Type distType = cast<Value>(res).getType();
1823+
if (auto vecType = dyn_cast<VectorType>(distType)) {
1824+
AffineMap map = distributionMapFn(cast<Value>(res));
1825+
distType = ifResultDistTypes.count(i)
1826+
? ifResultDistTypes[i]
1827+
: getDistributedType(vecType, map, warpOp.getWarpSize());
1828+
}
1829+
newIfOpDistResTypes.push_back(distType);
1830+
}
1831+
// Create a new `IfOp` outside the new `WarpOp` region.
1832+
OpBuilder::InsertionGuard g(rewriter);
1833+
rewriter.setInsertionPointAfter(newWarpOp);
1834+
auto newIfOp = scf::IfOp::create(rewriter, ifOp.getLoc(),
1835+
newIfOpDistResTypes, ifOp.getCondition(),
1836+
static_cast<bool>(ifOp.thenBlock()),
1837+
static_cast<bool>(ifOp.elseBlock()));
1838+
1839+
auto processBranch = [&](Block *oldIfBranch, Block *newIfBranch,
1840+
llvm::SmallSetVector<Value, 32> &escapingValues,
1841+
SmallVector<Type> &escapingValueInputTypes) {
1842+
OpBuilder::InsertionGuard g(rewriter);
1843+
if (!newIfBranch)
1844+
return;
1845+
rewriter.setInsertionPointToStart(newIfBranch);
1846+
llvm::SmallDenseMap<Value, int64_t> escapeValToBlockArgIndex;
1847+
SmallVector<Value> innerWarpInputVals;
1848+
SmallVector<Type> innerWarpInputTypes;
1849+
for (size_t i = 0; i < escapingValues.size(); ++i) {
1850+
innerWarpInputVals.push_back(newWarpOp.getResult(i));
1851+
escapeValToBlockArgIndex[escapingValues[i]] =
1852+
innerWarpInputTypes.size();
1853+
innerWarpInputTypes.push_back(escapingValueInputTypes[i]);
1854+
}
1855+
auto innerWarp = WarpExecuteOnLane0Op::create(
1856+
rewriter, newWarpOp.getLoc(), newIfOp.getResultTypes(),
1857+
newWarpOp.getLaneid(), newWarpOp.getWarpSize(), innerWarpInputVals,
1858+
innerWarpInputTypes);
1859+
1860+
innerWarp.getWarpRegion().takeBody(*oldIfBranch->getParent());
1861+
innerWarp.getWarpRegion().addArguments(
1862+
innerWarpInputTypes,
1863+
SmallVector<Location>(innerWarpInputTypes.size(), ifOp.getLoc()));
1864+
1865+
SmallVector<Value> yieldOperands;
1866+
for (Value operand : oldIfBranch->getTerminator()->getOperands())
1867+
yieldOperands.push_back(operand);
1868+
rewriter.eraseOp(oldIfBranch->getTerminator());
1869+
1870+
rewriter.setInsertionPointToEnd(innerWarp.getBody());
1871+
gpu::YieldOp::create(rewriter, innerWarp.getLoc(), yieldOperands);
1872+
rewriter.setInsertionPointAfter(innerWarp);
1873+
scf::YieldOp::create(rewriter, ifOp.getLoc(), innerWarp.getResults());
1874+
1875+
// Update any users of escaping values that were forwarded to the
1876+
// inner `WarpOp`. These values are now arguments of the inner `WarpOp`.
1877+
innerWarp.walk([&](Operation *op) {
1878+
for (OpOperand &operand : op->getOpOperands()) {
1879+
auto it = escapeValToBlockArgIndex.find(operand.get());
1880+
if (it == escapeValToBlockArgIndex.end())
1881+
continue;
1882+
operand.set(innerWarp.getBodyRegion().getArgument(it->second));
1883+
}
1884+
});
1885+
mlir::vector::moveScalarUniformCode(innerWarp);
1886+
};
1887+
processBranch(&ifOp.getThenRegion().front(),
1888+
&newIfOp.getThenRegion().front(), escapingValuesThen,
1889+
escapingValueInputTypesThen);
1890+
if (!ifOp.getElseRegion().empty())
1891+
processBranch(&ifOp.getElseRegion().front(),
1892+
&newIfOp.getElseRegion().front(), escapingValuesElse,
1893+
escapingValueInputTypesElse);
1894+
// Update the users of `<- WarpOp.yield <- IfOp.yield` to use the new `IfOp`
1895+
// result.
1896+
for (auto [origIdx, newIdx] : ifResultMapping)
1897+
rewriter.replaceAllUsesExcept(warpOp.getResult(origIdx),
1898+
newIfOp.getResult(newIdx), newIfOp);
1899+
// Similarly, update any users of the `WarpOp` results that were not
1900+
// results of the `IfOp`.
1901+
for (auto [origIdx, newIdx] : origToNewYieldIdx)
1902+
rewriter.replaceAllUsesWith(warpOp.getResult(origIdx),
1903+
newWarpOp.getResult(newIdx));
1904+
// Remove the original `WarpOp` and `IfOp`, they should not have any uses
1905+
// at this point.
1906+
rewriter.eraseOp(ifOp);
1907+
rewriter.eraseOp(warpOp);
1908+
return success();
1909+
}
1910+
1911+
private:
1912+
DistributionMapFn distributionMapFn;
1913+
};
1914+
17161915
/// Sink scf.for region out of WarpExecuteOnLane0Op. This can be done only if
17171916
/// the scf.ForOp is the last operation in the region so that it doesn't
17181917
/// change the order of execution. This creates a new scf.for region after the
@@ -2068,6 +2267,8 @@ void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
20682267
benefit);
20692268
patterns.add<WarpOpScfForOp>(patterns.getContext(), distributionMapFn,
20702269
benefit);
2270+
patterns.add<WarpOpScfIfOp>(patterns.getContext(), distributionMapFn,
2271+
benefit);
20712272
}
20722273

20732274
void mlir::vector::populateDistributeReduction(

mlir/test/Dialect/Vector/vector-warp-distribute.mlir

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1856,3 +1856,72 @@ func.func @negative_warp_step_more_than_warp_size(%laneid: index, %buffer: memre
18561856
// CHECK-PROP-LABEL: @negative_warp_step_more_than_warp_size
18571857
// CHECK-PROP-NOT: vector.broadcast
18581858
// CHECK-PROP: vector.step : vector<64xindex>
1859+
1860+
// -----
1861+
1862+
// Test input: an `scf.if` without results is the last op before `gpu.yield`
// in the warp region, and the warp op itself yields nothing.
func.func @warp_scf_if_no_yield_distribute(%buffer: memref<128xindex>, %pred : i1) {
1863+
%laneid = gpu.lane_id
1864+
%c0 = arith.constant 0 : index
1865+
1866+
gpu.warp_execute_on_lane_0(%laneid)[32] {
1867+
// `%seq` is defined in the warp region and used inside the `scf.if` body.
%seq = vector.step : vector<32xindex>
1868+
scf.if %pred {
1869+
vector.store %seq, %buffer[%c0] : memref<128xindex>, vector<32xindex>
1870+
}
1871+
gpu.yield
1872+
}
1873+
return
1874+
}
1875+
1876+
// CHECK-PROP-LABEL: func.func @warp_scf_if_no_yield_distribute(
1877+
// CHECK-PROP-SAME: %[[ARG0:.+]]: memref<128xindex>, %[[ARG1:.+]]: i1
1878+
// CHECK-PROP: scf.if %[[ARG1]] {
1879+
// CHECK-PROP: gpu.warp_execute_on_lane_0(%{{.*}})[32] args(%{{.*}} : vector<1xindex>) {
1880+
// CHECK-PROP: ^bb0(%[[ARG2:.+]]: vector<32xindex>):
1881+
// CHECK-PROP: vector.store %[[ARG2]], %[[ARG0]][%{{.*}}] : memref<128xindex>, vector<32xindex>
1882+
1883+
// -----
1884+
1885+
// Test input: an `scf.if` with a then and an else branch produces a
// vector<32xf32> that the warp op yields as vector<1xf32>.
func.func @warp_scf_if_distribute(%pred : i1) {
1886+
%laneid = gpu.lane_id
1887+
%c0 = arith.constant 0 : index
1888+
1889+
%0 = gpu.warp_execute_on_lane_0(%laneid)[32] -> vector<1xf32> {
1890+
// `%seq1` is used only by the then branch, `%seq2` only by the else branch.
%seq1 = vector.step : vector<32xindex>
1891+
%seq2 = arith.constant dense<2> : vector<32xindex>
1892+
%0 = scf.if %pred -> (vector<32xf32>) {
1893+
%1 = "some_op"(%seq1) : (vector<32xindex>) -> (vector<32xf32>)
1894+
scf.yield %1 : vector<32xf32>
1895+
} else {
1896+
%2 = "other_op"(%seq2) : (vector<32xindex>) -> (vector<32xf32>)
1897+
scf.yield %2 : vector<32xf32>
1898+
}
1899+
gpu.yield %0 : vector<32xf32>
1900+
}
1901+
// Use of the distributed warp op result outside the warp region.
"some_use"(%0) : (vector<1xf32>) -> ()
1902+
1903+
return
1904+
}
1905+
1906+
// CHECK-PROP-LABEL: func.func @warp_scf_if_distribute(
1907+
// CHECK-PROP-SAME: %[[ARG0:.+]]: i1
1908+
// CHECK-PROP: %[[SEQ2:.+]] = arith.constant dense<2> : vector<32xindex>
1909+
// CHECK-PROP: %[[LANE_ID:.+]] = gpu.lane_id
1910+
// CHECK-PROP: %[[SEQ1:.+]] = vector.broadcast %[[LANE_ID]] : index to vector<1xindex>
1911+
// CHECK-PROP: %[[IF_YIELD_DIST:.+]] = scf.if %[[ARG0]] -> (vector<1xf32>) {
1912+
// CHECK-PROP: %[[THEN_DIST:.+]] = gpu.warp_execute_on_lane_0(%[[LANE_ID]])[32] args(%[[SEQ1]] : vector<1xindex>) -> (vector<1xf32>) {
1913+
// CHECK-PROP: ^bb0(%[[ARG1:.+]]: vector<32xindex>):
1914+
// CHECK-PROP: %{{.*}} = "some_op"(%[[ARG1]]) : (vector<32xindex>) -> vector<32xf32>
1915+
// CHECK-PROP: gpu.yield %{{.*}} : vector<32xf32>
1916+
// CHECK-PROP: }
1917+
// CHECK-PROP: scf.yield %[[THEN_DIST]] : vector<1xf32>
1918+
// CHECK-PROP: } else {
1919+
// CHECK-PROP: %[[ELSE_DIST:.+]] = gpu.warp_execute_on_lane_0(%[[LANE_ID]])[32] -> (vector<1xf32>) {
1920+
// CHECK-PROP: %{{.*}} = "other_op"(%[[SEQ2]]) : (vector<32xindex>) -> vector<32xf32>
1921+
// CHECK-PROP: gpu.yield %{{.*}} : vector<32xf32>
1922+
// CHECK-PROP: }
1923+
// CHECK-PROP: scf.yield %[[ELSE_DIST]] : vector<1xf32>
1924+
// CHECK-PROP: }
1925+
// CHECK-PROP: "some_use"(%[[IF_YIELD_DIST]]) : (vector<1xf32>) -> ()
1926+
// CHECK-PROP: return
1927+
// CHECK-PROP: }

0 commit comments

Comments
 (0)