@@ -1794,14 +1794,18 @@ struct WarpOpScfIfOp : public WarpDistributionPattern {
17941794 return failure ();
17951795
17961796 // The new `WarpOp` groups yields values in following order:
1797- // 1. Escaping values then branch
1798- // 2. Escaping values else branch
1799- // 3. All non-`ifOp` yielded values.
1800- SmallVector<Value> newWarpOpYieldValues{escapingValuesThen.begin (),
1801- escapingValuesThen.end ()};
1797+ // 1. Branch condition
1798+ // 2. Escaping values then branch
1799+ // 3. Escaping values else branch
1800+ // 4. All non-`ifOp` yielded values.
1801+ SmallVector<Value> newWarpOpYieldValues{ifOp.getCondition ()};
1802+ newWarpOpYieldValues.append (escapingValuesThen.begin (),
1803+ escapingValuesThen.end ());
18021804 newWarpOpYieldValues.append (escapingValuesElse.begin (),
18031805 escapingValuesElse.end ());
1804- SmallVector<Type> newWarpOpDistTypes = escapingValueDistTypesThen;
1806+ SmallVector<Type> newWarpOpDistTypes{ifOp.getCondition ().getType ()};
1807+ newWarpOpDistTypes.append (escapingValueDistTypesThen.begin (),
1808+ escapingValueDistTypesThen.end ());
18051809 newWarpOpDistTypes.append (escapingValueDistTypesElse.begin (),
18061810 escapingValueDistTypesElse.end ());
18071811
@@ -1815,7 +1819,6 @@ struct WarpOpScfIfOp : public WarpDistributionPattern {
18151819 // Create the new `WarpOp` with the updated yield values and types.
18161820 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns (
18171821 rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes);
1818-
18191822 // `ifOp` returns the result of the inner warp op.
18201823 SmallVector<Type> newIfOpDistResTypes;
18211824 for (auto [i, res] : llvm::enumerate (ifOp.getResults ())) {
@@ -1831,23 +1834,24 @@ struct WarpOpScfIfOp : public WarpDistributionPattern {
18311834 // Create a new `IfOp` outside the new `WarpOp` region.
18321835 OpBuilder::InsertionGuard g (rewriter);
18331836 rewriter.setInsertionPointAfter (newWarpOp);
1834- auto newIfOp = scf::IfOp::create (rewriter, ifOp. getLoc (),
1835- newIfOpDistResTypes, ifOp. getCondition ( ),
1836- static_cast <bool >(ifOp.thenBlock ()),
1837- static_cast <bool >(ifOp.elseBlock ()));
1837+ auto newIfOp = scf::IfOp::create (
1838+ rewriter, ifOp. getLoc (), newIfOpDistResTypes, newWarpOp. getResult ( 0 ),
1839+ static_cast <bool >(ifOp.thenBlock ()),
1840+ static_cast <bool >(ifOp.elseBlock ()));
18381841
18391842 auto processBranch = [&](Block *oldIfBranch, Block *newIfBranch,
18401843 llvm::SmallSetVector<Value, 32 > &escapingValues,
1841- SmallVector<Type> &escapingValueInputTypes) {
1844+ SmallVector<Type> &escapingValueInputTypes,
1845+ size_t warpResRangeStart) {
18421846 OpBuilder::InsertionGuard g (rewriter);
18431847 if (!newIfBranch)
18441848 return ;
18451849 rewriter.setInsertionPointToStart (newIfBranch);
18461850 llvm::SmallDenseMap<Value, int64_t > escapeValToBlockArgIndex;
18471851 SmallVector<Value> innerWarpInputVals;
18481852 SmallVector<Type> innerWarpInputTypes;
1849- for (size_t i = 0 ; i < escapingValues.size (); ++i) {
1850- innerWarpInputVals.push_back (newWarpOp.getResult (i ));
1853+ for (size_t i = 0 ; i < escapingValues.size (); ++i, ++warpResRangeStart ) {
1854+ innerWarpInputVals.push_back (newWarpOp.getResult (warpResRangeStart ));
18511855 escapeValToBlockArgIndex[escapingValues[i]] =
18521856 innerWarpInputTypes.size ();
18531857 innerWarpInputTypes.push_back (escapingValueInputTypes[i]);
@@ -1886,11 +1890,11 @@ struct WarpOpScfIfOp : public WarpDistributionPattern {
18861890 };
18871891 processBranch (&ifOp.getThenRegion ().front (),
18881892 &newIfOp.getThenRegion ().front (), escapingValuesThen,
1889- escapingValueInputTypesThen);
1893+ escapingValueInputTypesThen, 1 );
18901894 if (!ifOp.getElseRegion ().empty ())
18911895 processBranch (&ifOp.getElseRegion ().front (),
18921896 &newIfOp.getElseRegion ().front (), escapingValuesElse,
1893- escapingValueInputTypesElse);
1897+ escapingValueInputTypesElse, 1 + escapingValuesThen. size () );
18941898 // Update the users of `<- WarpOp.yield <- IfOp.yield` to use the new `IfOp`
18951899 // result.
18961900 for (auto [origIdx, newIdx] : ifResultMapping)
0 commit comments