@@ -976,8 +976,7 @@ LogicalResult GreedyWhileLoopBatchFission::matchAndRewriteImpl(
976976 for (auto &[op, slices] : userOpToSlicesMap) {
977977 bool avoidBatching =
978978 llvm::TypeSwitch<Operation *, bool >(op)
979- .Case <stablehlo::DynamicSliceOp, stablehlo::ReshapeOp,
980- stablehlo::SliceOp,
979+ .Case <stablehlo::ReshapeOp, stablehlo::SliceOp,
981980 // TODO: avoid scatter since that lowers to loop right now
982981 stablehlo::ScatterOp>([=](auto op) { return true ; })
983982 .Case <stablehlo::BroadcastInDimOp, stablehlo::TransposeOp>(
@@ -987,9 +986,13 @@ LogicalResult GreedyWhileLoopBatchFission::matchAndRewriteImpl(
987986 continue ;
988987 }
989988
990- if ((dyn_cast<BatchOpInterface>(op) ||
991- stablehlo::hasTraitElementwise (op)) &&
992- op->getNumResults () == 1 ) {
989+ if (auto dsOp = dyn_cast<stablehlo::DynamicSliceOp>(op)) {
990+ if (raiseDynamicSliceToGather (rewriter, whileOp, slices, dsOp, info)) {
991+ anyOpRewritten = true ;
992+ }
993+ } else if ((dyn_cast<BatchOpInterface>(op) ||
994+ stablehlo::hasTraitElementwise (op)) &&
995+ op->getNumResults () == 1 ) {
993996 if (liftOperationByBatching (rewriter, whileOp, slices, op, info)) {
994997 anyOpRewritten = true ;
995998 } else if (liftReduceLikeOperation (rewriter, whileOp, slices, op, info)) {
@@ -1440,6 +1443,247 @@ bool liftReduceLikeOperation(
14401443 return true ;
14411444}
14421445
/// Rewrites a loop-dependent stablehlo.DynamicSlice into a stablehlo.Gather
/// hoisted above the while loop.
///
/// Pattern: x[..., y[idx], z[idx], ...] where idx is an affine function of
/// the loop induction variable. We need to:
///   1. Hoist the computation of y[idx], z[idx], etc. for all loop iterations
///   2. Create a gather operation from x using those indices
///   3. Replace dsOp uses with a dynamic slice into the gather result
///
/// \param rewriter  Pattern rewriter used for all IR mutations.
/// \param whileOp   The enclosing while loop; hoisted ops are inserted before it.
/// \param slices    Slice infos for the loop, forwarded to the hoisting helpers.
/// \param dsOp      The dynamic-slice op inside the loop body to be raised.
/// \param info      Analysis results for the loop (trip count, start, step, IV).
/// \returns true if the rewrite was performed (dsOp is replaced), false if the
///          pattern does not apply.
///
/// NOTE(review): later bail-out paths (after hoistChainOfOps /
/// hoistOperationFromLoop have run) return false without undoing the hoists —
/// presumably the hoisted ops are side-effect free and cleaned up as dead code
/// later; confirm.
bool raiseDynamicSliceToGather(
    PatternRewriter &rewriter, stablehlo::WhileOp whileOp,
    ArrayRef<SliceInfo<stablehlo::DynamicSliceOp>> slices,
    stablehlo::DynamicSliceOp dsOp, WhileLoopInfo info) {
  // Find all start indices that are dependent on the loop and come from inner
  // dynamic slices (through possible reshape).
  SmallVector<int64_t> dependentDims;    // operand dims indexed per-iteration
  SmallVector<Value> innerSliceOperands; // producers of those per-iter indices
  SmallVector<SliceInfo<stablehlo::DynamicSliceOp>> innerSliceInfos;

  for (auto [i, startIndex] : llvm::enumerate(dsOp.getStartIndices())) {
    // Use traverseOperandsForHoisting to classify this operand.
    SmallVector<BatchLiftingMode> modes;
    SmallVector<Value> operands;
    SmallVector<SmallVector<int64_t>> dims;
    SmallVector<int64_t> hoisted;
    SmallVector<SliceInfo<stablehlo::DynamicSliceOp>> mapped;
    DenseMap<Value, SmallVector<Operation *>> hoistMap;

    SmallVector<Value> singleOperand = {startIndex};
    // Bail out entirely if any start index cannot be classified; a partially
    // classified dynamic slice cannot be raised.
    if (!traverseOperandsForHoisting(singleOperand, whileOp, slices, info,
                                     modes, operands, dims, hoisted, mapped,
                                     hoistMap)) {
      return false;
    }

    // Only indices fed by an inner dynamic slice (over the IV) become gather
    // dimensions; all other start indices must be loop-invariant (checked
    // further below).
    if (modes[0] == BatchLiftingMode::DYNAMIC_SLICE) {
      dependentDims.push_back(i);
      innerSliceOperands.push_back(operands[0]);
      innerSliceInfos.push_back(mapped[0]);
    }
  }

  // No loop-dependent start index: nothing to gather.
  if (dependentDims.empty()) {
    return false;
  }

  // Get outer operand - it must be constant across iterations (loop invariant).
  Value outerOperand;
  Value dsOperand = dsOp.getOperand();
  SmallVector<Operation *> canBeHoisted;
  if (!info.isConstantAcrossIterations(dsOperand, outerOperand, canBeHoisted,
                                       true)) {
    return false;
  }
  if (!outerOperand) {
    // The operand is defined inside the loop but is hoistable - hoist it.
    DenseMap<Value, SmallVector<Operation *>> hoistMap;
    hoistMap[dsOperand] = canBeHoisted;
    DenseMap<Value, Value> hoistedValues;
    hoistChainOfOps(hoistMap, rewriter, whileOp, info, hoistedValues);
    outerOperand = hoistedValues[dsOperand];
  }

  // Verify all non-dependent start indices are constant across iterations.
  for (auto [i, startIndex] : llvm::enumerate(dsOp.getStartIndices())) {
    if (llvm::is_contained(dependentDims, i)) {
      continue;
    }
    if (!info.isConstantAcrossIterations(startIndex, true)) {
      return false;
    }
  }

  // NOTE(review): assumes the trip count is a known constant here — presumably
  // guaranteed by the caller / earlier checks on `info`; confirm that
  // getConstantNumIters() cannot return an invalid sentinel at this point.
  int64_t numIters = info.getConstantNumIters();
  Location loc = dsOp.getLoc();

  // All hoisted index computations and the gather go before the loop.
  rewriter.setInsertionPoint(whileOp);

  // Step 1: Hoist each inner slice operand and construct the gather indices.
  // We need to gather all indices for all loop iterations and concatenate them.
  SmallVector<Value> hoistedIndicesList; // one [numIters, 1] tensor per dim
  Type hoistedIndicesElemTy;             // element type of the first hoisted op

  for (size_t idx = 0; idx < dependentDims.size(); idx++) {
    Value hoistedIndices;
    if (!info.hoistOperationFromLoop(
            rewriter, innerSliceOperands[idx], innerSliceInfos[idx].sliceOp,
            innerSliceInfos[idx].dimensions, hoistedIndices)) {
      return false;
    }

    auto hoistedTy = cast<RankedTensorType>(hoistedIndices.getType());
    // The first hoisted tensor fixes the common index element type; all later
    // ones are converted to it below.
    if (idx == 0) {
      hoistedIndicesElemTy = hoistedTy.getElementType();
    }

    // Reshape to [numIters, 1] for use as gather indices.
    SmallVector<int64_t> reshapeShape = {numIters, 1};
    auto reshapeTy = RankedTensorType::get(reshapeShape, hoistedIndicesElemTy);

    // Convert type if needed (no-op for idx == 0 by construction).
    if (hoistedTy.getElementType() != hoistedIndicesElemTy) {
      hoistedIndices = stablehlo::ConvertOp::create(
          rewriter, loc,
          RankedTensorType::get(hoistedTy.getShape(), hoistedIndicesElemTy),
          hoistedIndices);
    }

    Value reshaped =
        stablehlo::ReshapeOp::create(rewriter, loc, reshapeTy, hoistedIndices);
    hoistedIndicesList.push_back(reshaped);
  }

  // Concatenate all hoisted indices along the last dimension so that row i
  // holds the full index vector for loop iteration i.
  Value gatherIndices;
  if (hoistedIndicesList.size() == 1) {
    gatherIndices = hoistedIndicesList[0];
  } else {
    // Result shape: [numIters, numDependentDims]
    SmallVector<int64_t> concatShape = {numIters,
                                        (int64_t)dependentDims.size()};
    auto concatTy = RankedTensorType::get(concatShape, hoistedIndicesElemTy);
    gatherIndices = stablehlo::ConcatenateOp::create(
        rewriter, loc, concatTy, hoistedIndicesList, /*dimension=*/1);
  }

  // Step 2: Create the gather operation from the outer operand.
  auto outerOperandTy = cast<RankedTensorType>(outerOperand.getType());
  auto dsSliceSizes = dsOp.getSliceSizes();

  // The gather slice sizes: dependent dimensions get 1, others get original.
  SmallVector<int64_t> gatherSliceSizes;
  for (size_t i = 0; i < dsSliceSizes.size(); i++) {
    if (llvm::is_contained(dependentDims, i)) {
      gatherSliceSizes.push_back(1);
    } else {
      gatherSliceSizes.push_back(dsSliceSizes[i]);
    }
  }

  // offsetDims: output dimensions corresponding to non-collapsed slice dims.
  // Start at 1 since batch dim is at position 0, then consecutive for each
  // non-collapsed dimension.
  SmallVector<int64_t> offsetDims;
  int64_t offsetDimIdx = 1; // Start after the batch dimension
  for (size_t i = 0; i < outerOperandTy.getRank(); i++) {
    if (!llvm::is_contained(dependentDims, i)) {
      offsetDims.push_back(offsetDimIdx);
      offsetDimIdx++;
    }
  }

  // collapsedSliceDims: the dimensions we're indexing into.
  SmallVector<int64_t> collapsedSliceDims = llvm::to_vector(dependentDims);

  // startIndexMap: maps index vector dimensions to operand dimensions. The
  // index vector columns were concatenated in dependentDims order above, so
  // the identity mapping onto dependentDims is correct.
  SmallVector<int64_t> startIndexMap = llvm::to_vector(dependentDims);

  // Calculate output shape: [numIters, ...sliceSizes for non-dependent dims...]
  SmallVector<int64_t> gatherOutputShape;
  gatherOutputShape.push_back(numIters);
  for (size_t i = 0; i < dsSliceSizes.size(); i++) {
    if (!llvm::is_contained(dependentDims, i)) {
      gatherOutputShape.push_back(dsSliceSizes[i]);
    }
  }

  auto gatherResultTy =
      RankedTensorType::get(gatherOutputShape, outerOperandTy.getElementType());

  auto gatherOp = stablehlo::GatherOp::create(
      rewriter, loc, gatherResultTy, outerOperand, gatherIndices,
      stablehlo::GatherDimensionNumbersAttr::get(
          rewriter.getContext(),
          /*offsetDims=*/offsetDims,
          /*collapsedSliceDims=*/collapsedSliceDims,
          /*operandBatchingDims=*/{},
          /*startIndicesBatchingDims=*/{},
          /*startIndexMap=*/startIndexMap,
          /*indexVectorDim=*/1),
      gatherSliceSizes);

  // Step 3: Replace the dsOp with a dynamic slice into the gather result.
  // The dynamic slice will index using the loop induction variable.
  rewriter.setInsertionPointAfter(dsOp);

  auto inductionVar = info.getInductionVariable();
  auto inductionVarType = cast<RankedTensorType>(inductionVar.getType());

  // Compute the index for the dynamic slice: normalize the IV to the
  // zero-based iteration number ((iv - start) / step), skipping the ops when
  // start is 0 / step is 1.
  Value sliceIndex;
  if (info.isConstantStart() && info.getConstantStart() == 0) {
    sliceIndex = inductionVar;
  } else {
    sliceIndex = stablehlo::SubtractOp::create(rewriter, loc, inductionVar,
                                               info.getStart());
  }
  if (!info.isStepOne()) {
    sliceIndex = stablehlo::DivOp::create(rewriter, loc, sliceIndex,
                                          info.getStep(rewriter));
  }

  // Convert sliceIndex to the same type as gather indices if needed.
  if (inductionVarType.getElementType() != hoistedIndicesElemTy) {
    auto newIndexTy = RankedTensorType::get({}, hoistedIndicesElemTy);
    sliceIndex =
        stablehlo::ConvertOp::create(rewriter, loc, newIndexTy, sliceIndex);
  }

  // Create constZero with the same type as sliceIndex (after conversion).
  auto sliceIndexTy = cast<RankedTensorType>(sliceIndex.getType());
  auto constZero = stablehlo::ConstantOp::create(
      rewriter, loc, sliceIndexTy,
      cast<ElementsAttr>(makeAttr(sliceIndexTy, 0)));
  // Build the start indices for dynamic slice: the iteration index selects the
  // batch row; remaining (non-collapsed) dims start at 0.
  SmallVector<Value> dynSliceStarts;
  dynSliceStarts.push_back(sliceIndex);
  for (size_t i = 0; i < dsSliceSizes.size(); i++) {
    if (!llvm::is_contained(dependentDims, i)) {
      dynSliceStarts.push_back(constZero);
    }
  }

  // Build the slice sizes (1 for the batch dim, original sizes for others).
  SmallVector<int64_t> dynSliceSizes;
  dynSliceSizes.push_back(1);
  for (size_t i = 0; i < dsSliceSizes.size(); i++) {
    if (!llvm::is_contained(dependentDims, i)) {
      dynSliceSizes.push_back(dsSliceSizes[i]);
    }
  }

  auto dynSlice = stablehlo::DynamicSliceOp::create(
      rewriter, loc, gatherOp.getResult(), dynSliceStarts, dynSliceSizes);

  // Reshape to match the original dsOp output type (drops the size-1 batch
  // dim and restores the collapsed dimensions' shape).
  auto replacement =
      stablehlo::ReshapeOp::create(rewriter, loc, dsOp.getType(), dynSlice);

  rewriter.replaceOp(dsOp, replacement.getResult());

  return true;
}
1686+
14431687bool liftOperationByBatching (
14441688 PatternRewriter &rewriter, stablehlo::WhileOp whileOp,
14451689 ArrayRef<SliceInfo<stablehlo::DynamicSliceOp>> slices, Operation *op,
0 commit comments