Skip to content

Commit 784dda1

Browse files
committed
Address feedback
1 parent 7d31574 commit 784dda1

File tree

1 file changed

+125
-89
lines changed

1 file changed

+125
-89
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp

Lines changed: 125 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -371,6 +371,36 @@ static VectorType getDistributedType(VectorType originalType, AffineMap map,
371371
return targetType;
372372
}
373373

374+
/// Given a warpOp that contains ops with regions, the corresponding op's
375+
/// "inner" region and the distributionMapFn, get all values used by the op's
376+
/// region that are defined within the warpOp. Return the set of values, their
377+
/// types and their distributed types.
378+
std::tuple<llvm::SmallSetVector<Value, 32>, SmallVector<Type>,
379+
SmallVector<Type>>
380+
getInnerRegionEscapingValues(WarpExecuteOnLane0Op warpOp, Region &innerRegion,
381+
DistributionMapFn distributionMapFn) {
382+
llvm::SmallSetVector<Value, 32> escapingValues;
383+
SmallVector<Type> escapingValueTypes;
384+
SmallVector<Type> escapingValueDistTypes; // to yield from the new warpOp
385+
if (innerRegion.empty())
386+
return {escapingValues, escapingValueTypes, escapingValueDistTypes};
387+
mlir::visitUsedValuesDefinedAbove(innerRegion, [&](OpOperand *operand) {
388+
Operation *parent = operand->get().getParentRegion()->getParentOp();
389+
if (warpOp->isAncestor(parent)) {
390+
if (!escapingValues.insert(operand->get()))
391+
return;
392+
Type distType = operand->get().getType();
393+
if (auto vecType = dyn_cast<VectorType>(distType)) {
394+
AffineMap map = distributionMapFn(operand->get());
395+
distType = getDistributedType(vecType, map, warpOp.getWarpSize());
396+
}
397+
escapingValueTypes.push_back(operand->get().getType());
398+
escapingValueDistTypes.push_back(distType);
399+
}
400+
});
401+
return {escapingValues, escapingValueTypes, escapingValueDistTypes};
402+
}
403+
374404
/// Distribute transfer_write ops based on the affine map returned by
375405
/// `distributionMapFn`. Writes of size more than `maxNumElementToExtract`
376406
/// will not be distributed (it should be less than the warp size).
@@ -1713,6 +1743,32 @@ struct WarpOpInsert : public WarpDistributionPattern {
17131743
}
17141744
};
17151745

1746+
/// Sink scf.if out of WarpExecuteOnLane0Op. This can be done only if
1747+
/// the scf.if is the last operation in the region so that it doesn't
1748+
/// change the order of execution. This creates a new scf.if after the
1749+
/// WarpExecuteOnLane0Op. Each branch of the new scf.if is enclosed in
1750+
/// the "inner" WarpExecuteOnLane0Op. Example:
1751+
/// ```
1752+
/// gpu.warp_execute_on_lane_0(%laneid)[32] {
1753+
/// %payload = ... : vector<32xindex>
1754+
/// scf.if %pred {
1755+
/// vector.store %payload, %buffer[%idx] : memref<128xindex>,
1756+
/// vector<32xindex>
1757+
/// }
1758+
/// gpu.yield
1759+
/// }
1760+
/// ```
/// is rewritten into:
/// ```
1761+
/// %r = gpu.warp_execute_on_lane_0(%laneid)[32] {
1762+
/// %payload = ... : vector<32xindex>
1763+
/// gpu.yield %payload : vector<32xindex>
1764+
/// }
1765+
/// scf.if %pred {
1766+
/// gpu.warp_execute_on_lane_0(%laneid)[32] args(%r : vector<1xindex>) {
1767+
/// ^bb0(%arg1: vector<32xindex>):
1768+
/// vector.store %arg1, %buffer[%idx] : memref<128xindex>, vector<32xindex>
1769+
/// }
1770+
/// }
1771+
/// ```
17161772
struct WarpOpScfIfOp : public WarpDistributionPattern {
17171773
WarpOpScfIfOp(MLIRContext *ctx, DistributionMapFn fn, PatternBenefit b = 1)
17181774
: WarpDistributionPattern(ctx, b), distributionMapFn(std::move(fn)) {}
@@ -1728,7 +1784,7 @@ struct WarpOpScfIfOp : public WarpDistributionPattern {
17281784
// The current `WarpOp` can yield two types of values:
17291785
// 1. Not results of `IfOp`:
17301786
// Preserve them in the new `WarpOp`.
1731-
// Collect their yield index.
1787+
// Collect their yield index to remap the usages.
17321788
// 2. Results of `IfOp`:
17331789
// They are not part of the new `WarpOp` results.
17341790
// Map current warp's yield operand index to `IfOp` result idx.
@@ -1757,38 +1813,14 @@ struct WarpOpScfIfOp : public WarpDistributionPattern {
17571813

17581814
// Collect `WarpOp`-defined values used in `ifOp`, the new warp op returns
17591815
// them
1760-
auto getEscapingValues = [&](Region &branch,
1761-
llvm::SmallSetVector<Value, 32> &values,
1762-
SmallVector<Type> &inputTypes,
1763-
SmallVector<Type> &distTypes) {
1764-
if (branch.empty())
1765-
return;
1766-
mlir::visitUsedValuesDefinedAbove(branch, [&](OpOperand *operand) {
1767-
Operation *parent = operand->get().getParentRegion()->getParentOp();
1768-
if (warpOp->isAncestor(parent)) {
1769-
if (!values.insert(operand->get()))
1770-
return;
1771-
Type distType = operand->get().getType();
1772-
if (auto vecType = dyn_cast<VectorType>(distType)) {
1773-
AffineMap map = distributionMapFn(operand->get());
1774-
distType = getDistributedType(vecType, map, warpOp.getWarpSize());
1775-
}
1776-
inputTypes.push_back(operand->get().getType());
1777-
distTypes.push_back(distType);
1778-
}
1779-
});
1780-
};
1781-
llvm::SmallSetVector<Value, 32> escapingValuesThen;
1782-
SmallVector<Type> escapingValueInputTypesThen; // inner warp op block args
1783-
SmallVector<Type> escapingValueDistTypesThen; // new warp returns
1784-
getEscapingValues(ifOp.getThenRegion(), escapingValuesThen,
1785-
escapingValueInputTypesThen, escapingValueDistTypesThen);
1786-
llvm::SmallSetVector<Value, 32> escapingValuesElse;
1787-
SmallVector<Type> escapingValueInputTypesElse; // inner warp op block args
1788-
SmallVector<Type> escapingValueDistTypesElse; // new warp returns
1789-
getEscapingValues(ifOp.getElseRegion(), escapingValuesElse,
1790-
escapingValueInputTypesElse, escapingValueDistTypesElse);
1791-
1816+
auto [escapingValuesThen, escapingValueInputTypesThen,
1817+
escapingValueDistTypesThen] =
1818+
getInnerRegionEscapingValues(warpOp, ifOp.getThenRegion(),
1819+
distributionMapFn);
1820+
auto [escapingValuesElse, escapingValueInputTypesElse,
1821+
escapingValueDistTypesElse] =
1822+
getInnerRegionEscapingValues(warpOp, ifOp.getElseRegion(),
1823+
distributionMapFn);
17921824
if (llvm::is_contained(escapingValueDistTypesThen, Type{}) ||
17931825
llvm::is_contained(escapingValueDistTypesElse, Type{}))
17941826
return failure();
@@ -1825,6 +1857,7 @@ struct WarpOpScfIfOp : public WarpDistributionPattern {
18251857
Type distType = cast<Value>(res).getType();
18261858
if (auto vecType = dyn_cast<VectorType>(distType)) {
18271859
AffineMap map = distributionMapFn(cast<Value>(res));
1860+
// Fall back to the affine map if the distributed type of this result was
// not previously recorded.
18281861
distType = ifResultDistTypes.count(i)
18291862
? ifResultDistTypes[i]
18301863
: getDistributedType(vecType, map, warpOp.getWarpSize());
@@ -1838,63 +1871,66 @@ struct WarpOpScfIfOp : public WarpDistributionPattern {
18381871
rewriter, ifOp.getLoc(), newIfOpDistResTypes, newWarpOp.getResult(0),
18391872
static_cast<bool>(ifOp.thenBlock()),
18401873
static_cast<bool>(ifOp.elseBlock()));
1841-
1842-
auto processBranch = [&](Block *oldIfBranch, Block *newIfBranch,
1843-
llvm::SmallSetVector<Value, 32> &escapingValues,
1844-
SmallVector<Type> &escapingValueInputTypes,
1845-
size_t warpResRangeStart) {
1846-
OpBuilder::InsertionGuard g(rewriter);
1847-
if (!newIfBranch)
1848-
return;
1849-
rewriter.setInsertionPointToStart(newIfBranch);
1850-
llvm::SmallDenseMap<Value, int64_t> escapeValToBlockArgIndex;
1851-
SmallVector<Value> innerWarpInputVals;
1852-
SmallVector<Type> innerWarpInputTypes;
1853-
for (size_t i = 0; i < escapingValues.size(); ++i, ++warpResRangeStart) {
1854-
innerWarpInputVals.push_back(newWarpOp.getResult(warpResRangeStart));
1855-
escapeValToBlockArgIndex[escapingValues[i]] =
1856-
innerWarpInputTypes.size();
1857-
innerWarpInputTypes.push_back(escapingValueInputTypes[i]);
1858-
}
1859-
auto innerWarp = WarpExecuteOnLane0Op::create(
1860-
rewriter, newWarpOp.getLoc(), newIfOp.getResultTypes(),
1861-
newWarpOp.getLaneid(), newWarpOp.getWarpSize(), innerWarpInputVals,
1862-
innerWarpInputTypes);
1863-
1864-
innerWarp.getWarpRegion().takeBody(*oldIfBranch->getParent());
1865-
innerWarp.getWarpRegion().addArguments(
1866-
innerWarpInputTypes,
1867-
SmallVector<Location>(innerWarpInputTypes.size(), ifOp.getLoc()));
1868-
1869-
SmallVector<Value> yieldOperands;
1870-
for (Value operand : oldIfBranch->getTerminator()->getOperands())
1871-
yieldOperands.push_back(operand);
1872-
rewriter.eraseOp(oldIfBranch->getTerminator());
1873-
1874-
rewriter.setInsertionPointToEnd(innerWarp.getBody());
1875-
gpu::YieldOp::create(rewriter, innerWarp.getLoc(), yieldOperands);
1876-
rewriter.setInsertionPointAfter(innerWarp);
1877-
scf::YieldOp::create(rewriter, ifOp.getLoc(), innerWarp.getResults());
1878-
1879-
// Update any users of escaping values that were forwarded to the
1880-
// inner `WarpOp`. These values are now arguments of the inner `WarpOp`.
1881-
innerWarp.walk([&](Operation *op) {
1882-
for (OpOperand &operand : op->getOpOperands()) {
1883-
auto it = escapeValToBlockArgIndex.find(operand.get());
1884-
if (it == escapeValToBlockArgIndex.end())
1885-
continue;
1886-
operand.set(innerWarp.getBodyRegion().getArgument(it->second));
1887-
}
1888-
});
1889-
mlir::vector::moveScalarUniformCode(innerWarp);
1890-
};
1891-
processBranch(&ifOp.getThenRegion().front(),
1892-
&newIfOp.getThenRegion().front(), escapingValuesThen,
1893-
escapingValueInputTypesThen, 1);
1874+
auto encloseRegionInWarpOp =
1875+
[&](Block *oldIfBranch, Block *newIfBranch,
1876+
llvm::SmallSetVector<Value, 32> &escapingValues,
1877+
SmallVector<Type> &escapingValueInputTypes,
1878+
size_t warpResRangeStart) {
1879+
OpBuilder::InsertionGuard g(rewriter);
1880+
if (!newIfBranch)
1881+
return;
1882+
rewriter.setInsertionPointToStart(newIfBranch);
1883+
llvm::SmallDenseMap<Value, int64_t> escapeValToBlockArgIndex;
1884+
SmallVector<Value> innerWarpInputVals;
1885+
SmallVector<Type> innerWarpInputTypes;
1886+
for (size_t i = 0; i < escapingValues.size();
1887+
++i, ++warpResRangeStart) {
1888+
innerWarpInputVals.push_back(
1889+
newWarpOp.getResult(warpResRangeStart));
1890+
escapeValToBlockArgIndex[escapingValues[i]] =
1891+
innerWarpInputTypes.size();
1892+
innerWarpInputTypes.push_back(escapingValueInputTypes[i]);
1893+
}
1894+
auto innerWarp = WarpExecuteOnLane0Op::create(
1895+
rewriter, newWarpOp.getLoc(), newIfOp.getResultTypes(),
1896+
newWarpOp.getLaneid(), newWarpOp.getWarpSize(),
1897+
innerWarpInputVals, innerWarpInputTypes);
1898+
1899+
innerWarp.getWarpRegion().takeBody(*oldIfBranch->getParent());
1900+
innerWarp.getWarpRegion().addArguments(
1901+
innerWarpInputTypes,
1902+
SmallVector<Location>(innerWarpInputTypes.size(), ifOp.getLoc()));
1903+
1904+
SmallVector<Value> yieldOperands;
1905+
for (Value operand : oldIfBranch->getTerminator()->getOperands())
1906+
yieldOperands.push_back(operand);
1907+
rewriter.eraseOp(oldIfBranch->getTerminator());
1908+
1909+
rewriter.setInsertionPointToEnd(innerWarp.getBody());
1910+
gpu::YieldOp::create(rewriter, innerWarp.getLoc(), yieldOperands);
1911+
rewriter.setInsertionPointAfter(innerWarp);
1912+
scf::YieldOp::create(rewriter, ifOp.getLoc(), innerWarp.getResults());
1913+
1914+
// Update any users of escaping values that were forwarded to the
1915+
// inner `WarpOp`. These values are arguments of the inner `WarpOp`.
1916+
innerWarp.walk([&](Operation *op) {
1917+
for (OpOperand &operand : op->getOpOperands()) {
1918+
auto it = escapeValToBlockArgIndex.find(operand.get());
1919+
if (it == escapeValToBlockArgIndex.end())
1920+
continue;
1921+
operand.set(innerWarp.getBodyRegion().getArgument(it->second));
1922+
}
1923+
});
1924+
mlir::vector::moveScalarUniformCode(innerWarp);
1925+
};
1926+
encloseRegionInWarpOp(&ifOp.getThenRegion().front(),
1927+
&newIfOp.getThenRegion().front(), escapingValuesThen,
1928+
escapingValueInputTypesThen, 1);
18941929
if (!ifOp.getElseRegion().empty())
1895-
processBranch(&ifOp.getElseRegion().front(),
1896-
&newIfOp.getElseRegion().front(), escapingValuesElse,
1897-
escapingValueInputTypesElse, 1 + escapingValuesThen.size());
1930+
encloseRegionInWarpOp(&ifOp.getElseRegion().front(),
1931+
&newIfOp.getElseRegion().front(),
1932+
escapingValuesElse, escapingValueInputTypesElse,
1933+
1 + escapingValuesThen.size());
18981934
// Update the users of `<- WarpOp.yield <- IfOp.yield` to use the new `IfOp`
18991935
// result.
19001936
for (auto [origIdx, newIdx] : ifResultMapping)

0 commit comments

Comments
 (0)