@@ -371,6 +371,38 @@ static VectorType getDistributedType(VectorType originalType, AffineMap map,
371
371
return targetType;
372
372
}
373
373
374
+ // / Given a warpOp that contains ops with regions, the corresponding op's
375
+ // / "inner" region and the distributionMapFn, get all values used by the op's
376
+ // / region that are defined within the warpOp, but outside the inner region.
377
+ // / Return the set of values, their types and their distributed types.
378
+ std::tuple<llvm::SmallSetVector<Value, 32 >, SmallVector<Type>,
379
+ SmallVector<Type>>
380
+ getInnerRegionEscapingValues (WarpExecuteOnLane0Op warpOp, Region &innerRegion,
381
+ DistributionMapFn distributionMapFn) {
382
+ llvm::SmallSetVector<Value, 32 > escapingValues;
383
+ SmallVector<Type> escapingValueTypes;
384
+ SmallVector<Type> escapingValueDistTypes; // to yield from the new warpOp
385
+ if (innerRegion.empty ())
386
+ return {std::move (escapingValues), std::move (escapingValueTypes),
387
+ std::move (escapingValueDistTypes)};
388
+ mlir::visitUsedValuesDefinedAbove (innerRegion, [&](OpOperand *operand) {
389
+ Operation *parent = operand->get ().getParentRegion ()->getParentOp ();
390
+ if (warpOp->isAncestor (parent)) {
391
+ if (!escapingValues.insert (operand->get ()))
392
+ return ;
393
+ Type distType = operand->get ().getType ();
394
+ if (auto vecType = dyn_cast<VectorType>(distType)) {
395
+ AffineMap map = distributionMapFn (operand->get ());
396
+ distType = getDistributedType (vecType, map, warpOp.getWarpSize ());
397
+ }
398
+ escapingValueTypes.push_back (operand->get ().getType ());
399
+ escapingValueDistTypes.push_back (distType);
400
+ }
401
+ });
402
+ return {std::move (escapingValues), std::move (escapingValueTypes),
403
+ std::move (escapingValueDistTypes)};
404
+ }
405
+
374
406
// / Distribute transfer_write ops based on the affine map returned by
375
407
// / `distributionMapFn`. Writes of size more than `maxNumElementToExtract`
376
408
// / will not be distributed (it should be less than the warp size).
@@ -1713,6 +1745,215 @@ struct WarpOpInsert : public WarpDistributionPattern {
1713
1745
}
1714
1746
};
1715
1747
1748
+ // / Sink scf.if out of WarpExecuteOnLane0Op. This can be done only if
1749
+ // / the scf.if is the last operation in the region so that it doesn't
1750
+ // / change the order of execution. This creates a new scf.if after the
1751
+ // / WarpExecuteOnLane0Op. Each branch of the new scf.if is enclosed in
1752
+ // / the "inner" WarpExecuteOnLane0Op. Example:
1753
+ // / ```
1754
+ // / gpu.warp_execute_on_lane_0(%laneid)[32] {
1755
+ // / %payload = ... : vector<32xindex>
1756
+ // / scf.if %pred {
1757
+ // / vector.store %payload, %buffer[%idx] : memref<128xindex>,
1758
+ // / vector<32xindex>
1759
+ // / }
1760
+ // / gpu.yield
1761
+ // / }
1762
+ // / ```
1763
+ // / %r = gpu.warp_execute_on_lane_0(%laneid)[32] {
1764
+ // / %payload = ... : vector<32xindex>
1765
+ // / gpu.yield %payload : vector<32xindex>
1766
+ // / }
1767
+ // / scf.if %pred {
1768
+ // / gpu.warp_execute_on_lane_0(%laneid)[32] args(%r : vector<1xindex>) {
1769
+ // / ^bb0(%arg1: vector<32xindex>):
1770
+ // / vector.store %arg1, %buffer[%idx] : memref<128xindex>, vector<32xindex>
1771
+ // / }
1772
+ // / }
1773
+ // / ```
1774
+ struct WarpOpScfIfOp : public WarpDistributionPattern {
1775
+ WarpOpScfIfOp (MLIRContext *ctx, DistributionMapFn fn, PatternBenefit b = 1 )
1776
+ : WarpDistributionPattern(ctx, b), distributionMapFn(std::move(fn)) {}
1777
+ LogicalResult matchAndRewrite (WarpExecuteOnLane0Op warpOp,
1778
+ PatternRewriter &rewriter) const override {
1779
+ gpu::YieldOp warpOpYield = warpOp.getTerminator ();
1780
+ // Only pick up `IfOp` if it is the last op in the region.
1781
+ Operation *lastNode = warpOpYield->getPrevNode ();
1782
+ auto ifOp = dyn_cast_or_null<scf::IfOp>(lastNode);
1783
+ if (!ifOp)
1784
+ return failure ();
1785
+
1786
+ // The current `WarpOp` can yield two types of values:
1787
+ // 1. Not results of `IfOp`:
1788
+ // Preserve them in the new `WarpOp`.
1789
+ // Collect their yield index to remap the usages.
1790
+ // 2. Results of `IfOp`:
1791
+ // They are not part of the new `WarpOp` results.
1792
+ // Map current warp's yield operand index to `IfOp` result idx.
1793
+ SmallVector<Value> nonIfYieldValues;
1794
+ SmallVector<unsigned > nonIfYieldIndices;
1795
+ llvm::SmallDenseMap<unsigned , unsigned > ifResultMapping;
1796
+ llvm::SmallDenseMap<unsigned , VectorType> ifResultDistTypes;
1797
+ for (OpOperand &yieldOperand : warpOpYield->getOpOperands ()) {
1798
+ const unsigned yieldOperandIdx = yieldOperand.getOperandNumber ();
1799
+ if (yieldOperand.get ().getDefiningOp () != ifOp.getOperation ()) {
1800
+ nonIfYieldValues.push_back (yieldOperand.get ());
1801
+ nonIfYieldIndices.push_back (yieldOperandIdx);
1802
+ continue ;
1803
+ }
1804
+ OpResult ifResult = cast<OpResult>(yieldOperand.get ());
1805
+ const unsigned ifResultIdx = ifResult.getResultNumber ();
1806
+ ifResultMapping[yieldOperandIdx] = ifResultIdx;
1807
+ // If this `ifOp` result is vector type and it is yielded by the
1808
+ // `WarpOp`, we keep track the distributed type for this result.
1809
+ if (!isa<VectorType>(ifResult.getType ()))
1810
+ continue ;
1811
+ VectorType distType =
1812
+ cast<VectorType>(warpOp.getResult (yieldOperandIdx).getType ());
1813
+ ifResultDistTypes[ifResultIdx] = distType;
1814
+ }
1815
+
1816
+ // Collect `WarpOp`-defined values used in `ifOp`, the new warp op returns
1817
+ // them
1818
+ auto [escapingValuesThen, escapingValueInputTypesThen,
1819
+ escapingValueDistTypesThen] =
1820
+ getInnerRegionEscapingValues (warpOp, ifOp.getThenRegion (),
1821
+ distributionMapFn);
1822
+ auto [escapingValuesElse, escapingValueInputTypesElse,
1823
+ escapingValueDistTypesElse] =
1824
+ getInnerRegionEscapingValues (warpOp, ifOp.getElseRegion (),
1825
+ distributionMapFn);
1826
+ if (llvm::is_contained (escapingValueDistTypesThen, Type{}) ||
1827
+ llvm::is_contained (escapingValueDistTypesElse, Type{}))
1828
+ return failure ();
1829
+
1830
+ // The new `WarpOp` groups yields values in following order:
1831
+ // 1. Branch condition
1832
+ // 2. Escaping values then branch
1833
+ // 3. Escaping values else branch
1834
+ // 4. All non-`ifOp` yielded values.
1835
+ SmallVector<Value> newWarpOpYieldValues{ifOp.getCondition ()};
1836
+ newWarpOpYieldValues.append (escapingValuesThen.begin (),
1837
+ escapingValuesThen.end ());
1838
+ newWarpOpYieldValues.append (escapingValuesElse.begin (),
1839
+ escapingValuesElse.end ());
1840
+ SmallVector<Type> newWarpOpDistTypes{ifOp.getCondition ().getType ()};
1841
+ newWarpOpDistTypes.append (escapingValueDistTypesThen.begin (),
1842
+ escapingValueDistTypesThen.end ());
1843
+ newWarpOpDistTypes.append (escapingValueDistTypesElse.begin (),
1844
+ escapingValueDistTypesElse.end ());
1845
+
1846
+ llvm::SmallDenseMap<unsigned , unsigned > origToNewYieldIdx;
1847
+ for (auto [idx, val] :
1848
+ llvm::zip_equal (nonIfYieldIndices, nonIfYieldValues)) {
1849
+ origToNewYieldIdx[idx] = newWarpOpYieldValues.size ();
1850
+ newWarpOpYieldValues.push_back (val);
1851
+ newWarpOpDistTypes.push_back (warpOp.getResult (idx).getType ());
1852
+ }
1853
+ // Create the new `WarpOp` with the updated yield values and types.
1854
+ WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns (
1855
+ rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes);
1856
+ // `ifOp` returns the result of the inner warp op.
1857
+ SmallVector<Type> newIfOpDistResTypes;
1858
+ for (auto [i, res] : llvm::enumerate (ifOp.getResults ())) {
1859
+ Type distType = cast<Value>(res).getType ();
1860
+ if (auto vecType = dyn_cast<VectorType>(distType)) {
1861
+ AffineMap map = distributionMapFn (cast<Value>(res));
1862
+ // Fallback to affine map if the dist result was not previously recorded
1863
+ distType = ifResultDistTypes.count (i)
1864
+ ? ifResultDistTypes[i]
1865
+ : getDistributedType (vecType, map, warpOp.getWarpSize ());
1866
+ }
1867
+ newIfOpDistResTypes.push_back (distType);
1868
+ }
1869
+ // Create a new `IfOp` outside the new `WarpOp` region.
1870
+ OpBuilder::InsertionGuard g (rewriter);
1871
+ rewriter.setInsertionPointAfter (newWarpOp);
1872
+ auto newIfOp = scf::IfOp::create (
1873
+ rewriter, ifOp.getLoc (), newIfOpDistResTypes, newWarpOp.getResult (0 ),
1874
+ static_cast <bool >(ifOp.thenBlock ()),
1875
+ static_cast <bool >(ifOp.elseBlock ()));
1876
+ auto encloseRegionInWarpOp =
1877
+ [&](Block *oldIfBranch, Block *newIfBranch,
1878
+ llvm::SmallSetVector<Value, 32 > &escapingValues,
1879
+ SmallVector<Type> &escapingValueInputTypes,
1880
+ size_t warpResRangeStart) {
1881
+ OpBuilder::InsertionGuard g (rewriter);
1882
+ if (!newIfBranch)
1883
+ return ;
1884
+ rewriter.setInsertionPointToStart (newIfBranch);
1885
+ llvm::SmallDenseMap<Value, int64_t > escapeValToBlockArgIndex;
1886
+ SmallVector<Value> innerWarpInputVals;
1887
+ SmallVector<Type> innerWarpInputTypes;
1888
+ for (size_t i = 0 ; i < escapingValues.size ();
1889
+ ++i, ++warpResRangeStart) {
1890
+ innerWarpInputVals.push_back (
1891
+ newWarpOp.getResult (warpResRangeStart));
1892
+ escapeValToBlockArgIndex[escapingValues[i]] =
1893
+ innerWarpInputTypes.size ();
1894
+ innerWarpInputTypes.push_back (escapingValueInputTypes[i]);
1895
+ }
1896
+ auto innerWarp = WarpExecuteOnLane0Op::create (
1897
+ rewriter, newWarpOp.getLoc (), newIfOp.getResultTypes (),
1898
+ newWarpOp.getLaneid (), newWarpOp.getWarpSize (),
1899
+ innerWarpInputVals, innerWarpInputTypes);
1900
+
1901
+ innerWarp.getWarpRegion ().takeBody (*oldIfBranch->getParent ());
1902
+ innerWarp.getWarpRegion ().addArguments (
1903
+ innerWarpInputTypes,
1904
+ SmallVector<Location>(innerWarpInputTypes.size (), ifOp.getLoc ()));
1905
+
1906
+ SmallVector<Value> yieldOperands;
1907
+ for (Value operand : oldIfBranch->getTerminator ()->getOperands ())
1908
+ yieldOperands.push_back (operand);
1909
+ rewriter.eraseOp (oldIfBranch->getTerminator ());
1910
+
1911
+ rewriter.setInsertionPointToEnd (innerWarp.getBody ());
1912
+ gpu::YieldOp::create (rewriter, innerWarp.getLoc (), yieldOperands);
1913
+ rewriter.setInsertionPointAfter (innerWarp);
1914
+ scf::YieldOp::create (rewriter, ifOp.getLoc (), innerWarp.getResults ());
1915
+
1916
+ // Update any users of escaping values that were forwarded to the
1917
+ // inner `WarpOp`. These values are arguments of the inner `WarpOp`.
1918
+ innerWarp.walk ([&](Operation *op) {
1919
+ for (OpOperand &operand : op->getOpOperands ()) {
1920
+ auto it = escapeValToBlockArgIndex.find (operand.get ());
1921
+ if (it == escapeValToBlockArgIndex.end ())
1922
+ continue ;
1923
+ operand.set (innerWarp.getBodyRegion ().getArgument (it->second ));
1924
+ }
1925
+ });
1926
+ mlir::vector::moveScalarUniformCode (innerWarp);
1927
+ };
1928
+ encloseRegionInWarpOp (&ifOp.getThenRegion ().front (),
1929
+ &newIfOp.getThenRegion ().front (), escapingValuesThen,
1930
+ escapingValueInputTypesThen, 1 );
1931
+ if (!ifOp.getElseRegion ().empty ())
1932
+ encloseRegionInWarpOp (&ifOp.getElseRegion ().front (),
1933
+ &newIfOp.getElseRegion ().front (),
1934
+ escapingValuesElse, escapingValueInputTypesElse,
1935
+ 1 + escapingValuesThen.size ());
1936
+ // Update the users of `<- WarpOp.yield <- IfOp.yield` to use the new `IfOp`
1937
+ // result.
1938
+ for (auto [origIdx, newIdx] : ifResultMapping)
1939
+ rewriter.replaceAllUsesExcept (warpOp.getResult (origIdx),
1940
+ newIfOp.getResult (newIdx), newIfOp);
1941
+ // Similarly, update any users of the `WarpOp` results that were not
1942
+ // results of the `IfOp`.
1943
+ for (auto [origIdx, newIdx] : origToNewYieldIdx)
1944
+ rewriter.replaceAllUsesWith (warpOp.getResult (origIdx),
1945
+ newWarpOp.getResult (newIdx));
1946
+ // Remove the original `WarpOp` and `IfOp`, they should not have any uses
1947
+ // at this point.
1948
+ rewriter.eraseOp (ifOp);
1949
+ rewriter.eraseOp (warpOp);
1950
+ return success ();
1951
+ }
1952
+
1953
+ private:
1954
+ DistributionMapFn distributionMapFn;
1955
+ };
1956
+
1716
1957
// / Sink scf.for region out of WarpExecuteOnLane0Op. This can be done only if
1717
1958
// / the scf.ForOp is the last operation in the region so that it doesn't
1718
1959
// / change the order of execution. This creates a new scf.for region after the
@@ -1759,25 +2000,9 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
1759
2000
return failure ();
1760
2001
// Collect Values that come from the `WarpOp` but are outside the `ForOp`.
1761
2002
// Those Values need to be returned by the new warp op.
1762
- llvm::SmallSetVector<Value, 32 > escapingValues;
1763
- SmallVector<Type> escapingValueInputTypes;
1764
- SmallVector<Type> escapingValueDistTypes;
1765
- mlir::visitUsedValuesDefinedAbove (
1766
- forOp.getBodyRegion (), [&](OpOperand *operand) {
1767
- Operation *parent = operand->get ().getParentRegion ()->getParentOp ();
1768
- if (warpOp->isAncestor (parent)) {
1769
- if (!escapingValues.insert (operand->get ()))
1770
- return ;
1771
- Type distType = operand->get ().getType ();
1772
- if (auto vecType = dyn_cast<VectorType>(distType)) {
1773
- AffineMap map = distributionMapFn (operand->get ());
1774
- distType = getDistributedType (vecType, map, warpOp.getWarpSize ());
1775
- }
1776
- escapingValueInputTypes.push_back (operand->get ().getType ());
1777
- escapingValueDistTypes.push_back (distType);
1778
- }
1779
- });
1780
-
2003
+ auto [escapingValues, escapingValueInputTypes, escapingValueDistTypes] =
2004
+ getInnerRegionEscapingValues (warpOp, forOp.getBodyRegion (),
2005
+ distributionMapFn);
1781
2006
if (llvm::is_contained (escapingValueDistTypes, Type{}))
1782
2007
return failure ();
1783
2008
// `WarpOp` can yield two types of values:
@@ -2068,6 +2293,8 @@ void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
2068
2293
benefit);
2069
2294
patterns.add <WarpOpScfForOp>(patterns.getContext (), distributionMapFn,
2070
2295
benefit);
2296
+ patterns.add <WarpOpScfIfOp>(patterns.getContext (), distributionMapFn,
2297
+ benefit);
2071
2298
}
2072
2299
2073
2300
void mlir::vector::populateDistributeReduction (
0 commit comments