Skip to content

Commit 244ebef

Browse files
authored
Reapply [mlir][vector] Refactor WarpOpScfForOp to support unused or swapped forOp results. (#148313)
Reapply attempt for: #148291. Fix for the build failure reported in: https://lab.llvm.org/buildbot/#/builders/116/builds/15477 ----- This crash is caused by a mismatch between the distributed type returned by `getDistributedType` and the intended distributed type for the forOp results. Solution diff: 20c2cf6 Example: ``` func.func @warp_scf_for_broadcasted_result(%arg0: index) -> vector<1xf32> { %c128 = arith.constant 128 : index %c1 = arith.constant 1 : index %c0 = arith.constant 0 : index %2 = gpu.warp_execute_on_lane_0(%arg0)[32] -> (vector<1xf32>) { %ini = "some_def"() : () -> (vector<1xf32>) %0 = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%arg4 = %ini) -> (vector<1xf32>) { %1 = "some_op"(%arg4) : (vector<1xf32>) -> (vector<1xf32>) scf.yield %1 : vector<1xf32> } gpu.yield %0 : vector<1xf32> } return %2 : vector<1xf32> } ``` In this case the distributed type for the forOp result is `vector<1xf32>` (the result is not distributed; it is broadcast to all lanes instead). However, `getDistributedType` returns a null type here. Therefore, if the distributed type can be recovered from the warpOp, we should always do that first and only otherwise fall back to `getDistributedType`.
1 parent 5277021 commit 244ebef

File tree

3 files changed

+265
-54
lines changed

3 files changed

+265
-54
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp

Lines changed: 137 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -1704,19 +1704,18 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
17041704
: WarpDistributionPattern(ctx, b), distributionMapFn(std::move(fn)) {}
17051705
LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
17061706
PatternRewriter &rewriter) const override {
1707-
auto yield = cast<gpu::YieldOp>(
1707+
auto warpOpYield = cast<gpu::YieldOp>(
17081708
warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
1709-
// Only pick up forOp if it is the last op in the region.
1710-
Operation *lastNode = yield->getPrevNode();
1709+
// Only pick up `ForOp` if it is the last op in the region.
1710+
Operation *lastNode = warpOpYield->getPrevNode();
17111711
auto forOp = dyn_cast_or_null<scf::ForOp>(lastNode);
17121712
if (!forOp)
17131713
return failure();
1714-
// Collect Values that come from the warp op but are outside the forOp.
1715-
// Those Value needs to be returned by the original warpOp and passed to
1716-
// the new op.
1714+
// Collect Values that come from the `WarpOp` but are outside the `ForOp`.
1715+
// Those Values need to be returned by the new warp op.
17171716
llvm::SmallSetVector<Value, 32> escapingValues;
1718-
SmallVector<Type> inputTypes;
1719-
SmallVector<Type> distTypes;
1717+
SmallVector<Type> escapingValueInputTypes;
1718+
SmallVector<Type> escapingValueDistTypes;
17201719
mlir::visitUsedValuesDefinedAbove(
17211720
forOp.getBodyRegion(), [&](OpOperand *operand) {
17221721
Operation *parent = operand->get().getParentRegion()->getParentOp();
@@ -1728,81 +1727,168 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
17281727
AffineMap map = distributionMapFn(operand->get());
17291728
distType = getDistributedType(vecType, map, warpOp.getWarpSize());
17301729
}
1731-
inputTypes.push_back(operand->get().getType());
1732-
distTypes.push_back(distType);
1730+
escapingValueInputTypes.push_back(operand->get().getType());
1731+
escapingValueDistTypes.push_back(distType);
17331732
}
17341733
});
17351734

1736-
if (llvm::is_contained(distTypes, Type{}))
1735+
if (llvm::is_contained(escapingValueDistTypes, Type{}))
17371736
return failure();
1738-
1739-
SmallVector<size_t> newRetIndices;
1740-
WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1741-
rewriter, warpOp, escapingValues.getArrayRef(), distTypes,
1742-
newRetIndices);
1743-
yield = cast<gpu::YieldOp>(
1744-
newWarpOp.getBodyRegion().getBlocks().begin()->getTerminator());
1745-
1746-
SmallVector<Value> newOperands;
1747-
SmallVector<unsigned> resultIdx;
1748-
// Collect all the outputs coming from the forOp.
1749-
for (OpOperand &yieldOperand : yield->getOpOperands()) {
1750-
if (yieldOperand.get().getDefiningOp() != forOp.getOperation())
1737+
// `WarpOp` can yield two types of values:
1738+
// 1. Values that are not results of the `ForOp`:
1739+
// These values must also be yielded by the new `WarpOp`. Also, we need
1740+
// to record the index mapping for these values to replace them later.
1741+
// 2. Values that are results of the `ForOp`:
1742+
// In this case, we record the index mapping between the `WarpOp` result
1743+
// index and matching `ForOp` result index.
1744+
// Additionally, we keep track of the distributed types for all `ForOp`
1745+
// vector results.
1746+
SmallVector<Value> nonForYieldedValues;
1747+
SmallVector<unsigned> nonForResultIndices;
1748+
llvm::SmallDenseMap<unsigned, unsigned> forResultMapping;
1749+
llvm::SmallDenseMap<unsigned, VectorType> forResultDistTypes;
1750+
for (OpOperand &yieldOperand : warpOpYield->getOpOperands()) {
1751+
// Yielded value is not a result of the forOp.
1752+
if (yieldOperand.get().getDefiningOp() != forOp.getOperation()) {
1753+
nonForYieldedValues.push_back(yieldOperand.get());
1754+
nonForResultIndices.push_back(yieldOperand.getOperandNumber());
17511755
continue;
1752-
auto forResult = cast<OpResult>(yieldOperand.get());
1753-
newOperands.push_back(
1754-
newWarpOp.getResult(yieldOperand.getOperandNumber()));
1755-
yieldOperand.set(forOp.getInitArgs()[forResult.getResultNumber()]);
1756-
resultIdx.push_back(yieldOperand.getOperandNumber());
1756+
}
1757+
OpResult forResult = cast<OpResult>(yieldOperand.get());
1758+
unsigned int forResultNumber = forResult.getResultNumber();
1759+
forResultMapping[yieldOperand.getOperandNumber()] = forResultNumber;
1760+
// If this `ForOp` result is vector type and it is yielded by the
1761+
// `WarpOp`, we keep track of the distributed type for this result.
1762+
if (!isa<VectorType>(forResult.getType()))
1763+
continue;
1764+
VectorType distType = cast<VectorType>(
1765+
warpOp.getResult(yieldOperand.getOperandNumber()).getType());
1766+
forResultDistTypes[forResultNumber] = distType;
17571767
}
17581768

1769+
// The newly created `WarpOp` will yield values in the following order:
1770+
// 1. All init args of the `ForOp`.
1771+
// 2. All escaping values.
1772+
// 3. All non-`ForOp` yielded values.
1773+
SmallVector<Value> newWarpOpYieldValues;
1774+
SmallVector<Type> newWarpOpDistTypes;
1775+
for (auto [i, initArg] : llvm::enumerate(forOp.getInitArgs())) {
1776+
newWarpOpYieldValues.push_back(initArg);
1777+
// Compute the distributed type for this init arg.
1778+
Type distType = initArg.getType();
1779+
if (auto vecType = dyn_cast<VectorType>(distType)) {
1780+
// If the `ForOp` result corresponding to this init arg is already yielded,
1781+
// we can get the distributed type from `forResultDistTypes` map.
1782+
// Otherwise, we compute it using distributionMapFn.
1783+
AffineMap map = distributionMapFn(initArg);
1784+
distType = forResultDistTypes.count(i)
1785+
? forResultDistTypes[i]
1786+
: getDistributedType(vecType, map, warpOp.getWarpSize());
1787+
}
1788+
newWarpOpDistTypes.push_back(distType);
1789+
}
1790+
// Insert escaping values and their distributed types.
1791+
newWarpOpYieldValues.insert(newWarpOpYieldValues.end(),
1792+
escapingValues.begin(), escapingValues.end());
1793+
newWarpOpDistTypes.insert(newWarpOpDistTypes.end(),
1794+
escapingValueDistTypes.begin(),
1795+
escapingValueDistTypes.end());
1796+
// Next, we insert all non-`ForOp` yielded values and their distributed
1797+
// types. We also create a mapping between the non-`ForOp` yielded value
1798+
// index and the corresponding new `WarpOp` yield value index (needed to
1799+
// update users later).
1800+
llvm::SmallDenseMap<unsigned, unsigned> nonForResultMapping;
1801+
for (auto [i, v] :
1802+
llvm::zip_equal(nonForResultIndices, nonForYieldedValues)) {
1803+
nonForResultMapping[i] = newWarpOpYieldValues.size();
1804+
newWarpOpYieldValues.push_back(v);
1805+
newWarpOpDistTypes.push_back(warpOp.getResult(i).getType());
1806+
}
1807+
// Create the new `WarpOp` with the updated yield values and types.
1808+
WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
1809+
rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes);
1810+
1811+
// Next, we create a new `ForOp` with the init args yielded by the new
1812+
// `WarpOp`.
1813+
const unsigned escapingValuesStartIdx =
1814+
forOp.getInitArgs().size(); // `ForOp` init args are positioned before
1815+
// escaping values in the new `WarpOp`.
1816+
SmallVector<Value> newForOpOperands;
1817+
for (size_t i = 0; i < escapingValuesStartIdx; ++i)
1818+
newForOpOperands.push_back(newWarpOp.getResult(i));
1819+
1820+
// Create a new `ForOp` outside the new `WarpOp` region.
17591821
OpBuilder::InsertionGuard g(rewriter);
17601822
rewriter.setInsertionPointAfter(newWarpOp);
1761-
1762-
// Create a new for op outside the region with a WarpExecuteOnLane0Op
1763-
// region inside.
17641823
auto newForOp = rewriter.create<scf::ForOp>(
17651824
forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
1766-
forOp.getStep(), newOperands);
1825+
forOp.getStep(), newForOpOperands);
1826+
// Next, we insert a new `WarpOp` (called inner `WarpOp`) inside the
1827+
// newly created `ForOp`. This `WarpOp` will contain all ops that were
1828+
// contained within the original `ForOp` body.
17671829
rewriter.setInsertionPointToStart(newForOp.getBody());
17681830

1769-
SmallVector<Value> warpInput(newForOp.getRegionIterArgs().begin(),
1770-
newForOp.getRegionIterArgs().end());
1771-
SmallVector<Type> warpInputType(forOp.getResultTypes().begin(),
1772-
forOp.getResultTypes().end());
1831+
SmallVector<Value> innerWarpInput(newForOp.getRegionIterArgs().begin(),
1832+
newForOp.getRegionIterArgs().end());
1833+
SmallVector<Type> innerWarpInputType(forOp.getResultTypes().begin(),
1834+
forOp.getResultTypes().end());
1835+
// Escaping values are forwarded to the inner `WarpOp` as its (additional)
1836+
// arguments. We keep track of the mapping between these values and their
1837+
// argument index in the inner `WarpOp` (to replace users later).
17731838
llvm::SmallDenseMap<Value, int64_t> argIndexMapping;
1774-
for (auto [i, retIdx] : llvm::enumerate(newRetIndices)) {
1775-
warpInput.push_back(newWarpOp.getResult(retIdx));
1776-
argIndexMapping[escapingValues[i]] = warpInputType.size();
1777-
warpInputType.push_back(inputTypes[i]);
1839+
for (size_t i = escapingValuesStartIdx;
1840+
i < escapingValuesStartIdx + escapingValues.size(); ++i) {
1841+
innerWarpInput.push_back(newWarpOp.getResult(i));
1842+
argIndexMapping[escapingValues[i - escapingValuesStartIdx]] =
1843+
innerWarpInputType.size();
1844+
innerWarpInputType.push_back(
1845+
escapingValueInputTypes[i - escapingValuesStartIdx]);
17781846
}
1847+
// Create the inner `WarpOp` with the new input values and types.
17791848
auto innerWarp = rewriter.create<WarpExecuteOnLane0Op>(
17801849
newWarpOp.getLoc(), newForOp.getResultTypes(), newWarpOp.getLaneid(),
1781-
newWarpOp.getWarpSize(), warpInput, warpInputType);
1850+
newWarpOp.getWarpSize(), innerWarpInput, innerWarpInputType);
17821851

1852+
// Inline the `ForOp` body into the inner `WarpOp` body.
17831853
SmallVector<Value> argMapping;
17841854
argMapping.push_back(newForOp.getInductionVar());
1785-
for (Value args : innerWarp.getBody()->getArguments()) {
1855+
for (Value args : innerWarp.getBody()->getArguments())
17861856
argMapping.push_back(args);
1787-
}
1857+
17881858
argMapping.resize(forOp.getBody()->getNumArguments());
17891859
SmallVector<Value> yieldOperands;
17901860
for (Value operand : forOp.getBody()->getTerminator()->getOperands())
17911861
yieldOperands.push_back(operand);
1862+
17921863
rewriter.eraseOp(forOp.getBody()->getTerminator());
17931864
rewriter.mergeBlocks(forOp.getBody(), innerWarp.getBody(), argMapping);
1865+
1866+
// Insert a gpu `YieldOp` at the end of the inner `WarpOp` body that yields
1867+
// original `ForOp` results.
17941868
rewriter.setInsertionPointToEnd(innerWarp.getBody());
17951869
rewriter.create<gpu::YieldOp>(innerWarp.getLoc(), yieldOperands);
17961870
rewriter.setInsertionPointAfter(innerWarp);
1871+
// Insert a scf.yield op at the end of the new `ForOp` body that yields
1872+
// the inner `WarpOp` results.
17971873
if (!innerWarp.getResults().empty())
17981874
rewriter.create<scf::YieldOp>(forOp.getLoc(), innerWarp.getResults());
1875+
1876+
// Update the users of original `WarpOp` results that were coming from the
1877+
// original `ForOp` to the corresponding new `ForOp` result.
1878+
for (auto [origIdx, newIdx] : forResultMapping)
1879+
rewriter.replaceAllUsesExcept(warpOp.getResult(origIdx),
1880+
newForOp.getResult(newIdx), newForOp);
1881+
// Similarly, update any users of the `WarpOp` results that were not
1882+
// results of the `ForOp`.
1883+
for (auto [origIdx, newIdx] : nonForResultMapping)
1884+
rewriter.replaceAllUsesWith(warpOp.getResult(origIdx),
1885+
newWarpOp.getResult(newIdx));
1886+
// Remove the original `WarpOp` and `ForOp`, they should not have any uses
1887+
// at this point.
17991888
rewriter.eraseOp(forOp);
1800-
// Replace the warpOp result coming from the original ForOp.
1801-
for (const auto &res : llvm::enumerate(resultIdx)) {
1802-
rewriter.replaceAllUsesWith(newWarpOp.getResult(res.value()),
1803-
newForOp.getResult(res.index()));
1804-
newForOp->setOperand(res.index() + 3, newWarpOp.getResult(res.value()));
1805-
}
1889+
rewriter.eraseOp(warpOp);
1890+
// Update any users of escaping values that were forwarded to the
1891+
// inner `WarpOp`. These values are now arguments of the inner `WarpOp`.
18061892
newForOp.walk([&](Operation *op) {
18071893
for (OpOperand &operand : op->getOpOperands()) {
18081894
auto it = argIndexMapping.find(operand.get());
@@ -1812,7 +1898,7 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
18121898
}
18131899
});
18141900

1815-
// Finally, hoist out any now uniform code from the inner warp op.
1901+
// Finally, hoist out any now uniform code from the inner `WarpOp`.
18161902
mlir::vector::moveScalarUniformCode(innerWarp);
18171903
return success();
18181904
}

mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -876,15 +876,32 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
876876
// Step 3: Apply subgroup to workitem distribution patterns.
877877
RewritePatternSet patterns(&getContext());
878878
xegpu::populateXeGPUSubgroupDistributePatterns(patterns);
879-
// TODO: distributionFn and shuffleFn are not used at this point.
879+
// distributionFn is used by vector distribution patterns to determine the
880+
// distributed vector type for a given vector value. In XeGPU subgroup
881+
// distribution context, we compute this based on lane layout.
880882
auto distributionFn = [](Value val) {
881883
VectorType vecType = dyn_cast<VectorType>(val.getType());
882884
int64_t vecRank = vecType ? vecType.getRank() : 0;
883-
OpBuilder builder(val.getContext());
884885
if (vecRank == 0)
885886
return AffineMap::get(val.getContext());
886-
return AffineMap::getMultiDimIdentityMap(vecRank, val.getContext());
887+
// Get the layout of the vector type.
888+
xegpu::LayoutAttr layout = xegpu::getLayoutAttr(val);
889+
// If no layout is specified, assume the innermost dimension is distributed
890+
// for now.
891+
if (!layout)
892+
return AffineMap::getMultiDimMapWithTargets(
893+
vecRank, {static_cast<unsigned int>(vecRank - 1)}, val.getContext());
894+
SmallVector<unsigned int> distributedDims;
895+
// Get the distributed dimensions based on the layout.
896+
ArrayRef<int> laneLayout = layout.getLaneLayout().asArrayRef();
897+
for (unsigned i = 0; i < laneLayout.size(); ++i) {
898+
if (laneLayout[i] > 1)
899+
distributedDims.push_back(i);
900+
}
901+
return AffineMap::getMultiDimMapWithTargets(vecRank, distributedDims,
902+
val.getContext());
887903
};
904+
// TODO: shuffleFn is not used.
888905
auto shuffleFn = [](Location loc, OpBuilder &builder, Value val, Value srcIdx,
889906
int64_t warpSz) { return Value(); };
890907
vector::populatePropagateWarpVectorDistributionPatterns(

0 commit comments

Comments
 (0)