@@ -2038,11 +2038,19 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
20382038 }
20392039
20402040 // Newly created `WarpOp` will yield values in following order:
2041- // 1. All init args of the `ForOp`.
2042- // 2. All escaping values.
2043- // 3. All non-`ForOp` yielded values.
2041+ // 1. Loop bounds.
2042+ // 2. All init args of the `ForOp`.
2043+ // 3. All escaping values.
2044+ // 4. All non-`ForOp` yielded values.
20442045 SmallVector<Value> newWarpOpYieldValues;
20452046 SmallVector<Type> newWarpOpDistTypes;
2047+ newWarpOpYieldValues.insert (
2048+ newWarpOpYieldValues.end (),
2049+ {forOp.getLowerBound (), forOp.getUpperBound (), forOp.getStep ()});
2050+ newWarpOpDistTypes.insert (newWarpOpDistTypes.end (),
2051+ {forOp.getLowerBound ().getType (),
2052+ forOp.getUpperBound ().getType (),
2053+ forOp.getStep ().getType ()});
20462054 for (auto [i, initArg] : llvm::enumerate (forOp.getInitArgs ())) {
20472055 newWarpOpYieldValues.push_back (initArg);
20482056 // Compute the distributed type for this init arg.
@@ -2081,20 +2089,23 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
20812089
20822090 // Next, we create a new `ForOp` with the init args yielded by the new
20832091 // `WarpOp`.
2092+ const unsigned initArgsStartIdx = 3 ; // After loop bounds.
20842093 const unsigned escapingValuesStartIdx =
2094+ initArgsStartIdx +
20852095 forOp.getInitArgs ().size (); // `ForOp` init args are positioned before
20862096 // escaping values in the new `WarpOp`.
20872097 SmallVector<Value> newForOpOperands;
2088- for (size_t i = 0 ; i < escapingValuesStartIdx; ++i)
2098+ for (size_t i = initArgsStartIdx ; i < escapingValuesStartIdx; ++i)
20892099 newForOpOperands.push_back (newWarpOp.getResult (i));
20902100
20912101 // Create a new `ForOp` outside the new `WarpOp` region.
20922102 OpBuilder::InsertionGuard g (rewriter);
20932103 rewriter.setInsertionPointAfter (newWarpOp);
20942104 auto newForOp = scf::ForOp::create (
2095- rewriter, forOp.getLoc (), forOp.getLowerBound (), forOp.getUpperBound (),
2096- forOp.getStep (), newForOpOperands, /* bodyBuilder=*/ nullptr ,
2097- forOp.getUnsignedCmp ());
2105+ rewriter, forOp.getLoc (), /* *LowerBound=**/ newWarpOp.getResult (0 ),
2106+ /* *UpperBound=**/ newWarpOp.getResult (1 ),
2107+ /* *Step=**/ newWarpOp.getResult (2 ), newForOpOperands,
2108+ /* bodyBuilder=*/ nullptr , forOp.getUnsignedCmp ());
20982109 // Next, we insert a new `WarpOp` (called inner `WarpOp`) inside the
20992110 // newly created `ForOp`. This `WarpOp` will contain all ops that were
21002111 // contained within the original `ForOp` body.
0 commit comments