
Commit 14bea1f

Merge branch 'main' into jm/multirotate_recognize
2 parents 48557ca + 19a4665

8 files changed, +805 -6 lines changed

src/enzyme_ad/jax/Implementations/StableHLOAutoDiffOpInterfaceImpl.cpp

Lines changed: 37 additions & 0 deletions

@@ -3939,6 +3939,42 @@ struct SHLOReverseOpBatchInterface
   }
 };
 
+struct SHLOPadOpBatchInterface
+    : public BatchOpInterface::ExternalModel<SHLOPadOpBatchInterface,
+                                             stablehlo::PadOp> {
+  mlir::LogicalResult createBatch(Operation *src, OpBuilder &builder,
+                                  IRMapping &mapper,
+                                  ArrayRef<int64_t> batchSizes) const {
+    auto op = cast<stablehlo::PadOp>(src);
+
+    auto batchedPadValue = mapper.lookup(op.getPaddingValue());
+    auto scalarPadValue =
+        getScalarValue(batchedPadValue.getDefiningOp(), builder);
+    if (!scalarPadValue) {
+      return genericCreateBatch(src, builder, mapper, batchSizes);
+    }
+
+    int64_t nBatches = batchSizes.size();
+    SmallVector<int64_t> newLow(nBatches, 0);
+    newLow.append(op.getEdgePaddingLow().begin(), op.getEdgePaddingLow().end());
+    SmallVector<int64_t> newHigh(nBatches, 0);
+    newHigh.append(op.getEdgePaddingHigh().begin(),
+                   op.getEdgePaddingHigh().end());
+    SmallVector<int64_t> newInterior(nBatches, 0);
+    newInterior.append(op.getInteriorPadding().begin(),
+                       op.getInteriorPadding().end());
+
+    auto newPadOp = stablehlo::PadOp::create(
+        builder, op.getLoc(), mapper.lookup(op.getOperand()), scalarPadValue,
+        builder.getDenseI64ArrayAttr(newLow),
+        builder.getDenseI64ArrayAttr(newHigh),
+        builder.getDenseI64ArrayAttr(newInterior));
+
+    mapper.map(src->getResult(0), newPadOp.getResult());
+    return success();
+  }
+};
+
 // https://github.com/jax-ml/jax/blob/2a8cb54b82f1b0d17181d43f9be78d2b349df333/jax/_src/lax/convolution.py#L613-L629
 struct SHLOConvolutionOpBatchInterface
     : public BatchOpInterface::ExternalModel<SHLOConvolutionOpBatchInterface,

@@ -4110,6 +4146,7 @@ void mlir::enzyme::registerStableHLODialectAutoDiffInterface(
       *context);
   ReverseOp::attachInterface<SHLOReverseOpBatchInterface>(*context);
   ConvolutionOp::attachInterface<SHLOConvolutionOpBatchInterface>(*context);
+  PadOp::attachInterface<SHLOPadOpBatchInterface>(*context);
 
   ScatterOp::attachInterface<SHLOGenericBatchOpInterface<ScatterOp>>(
       *context); // TODO: simpler version with newly named dims
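
Note (illustration, not part of the commit): the batched pad interface prepends zero edge/interior padding entries for each new batch dimension and reuses the scalar pad value, so every batch element is padded exactly as the unbatched op would be. A minimal JAX sketch of that semantics, with invented shapes:

    import jax.numpy as jnp
    from jax import lax

    x = jnp.ones((8, 8))
    # Unbatched: pad the first dim by 1 on each edge.
    # lax.pad takes one (low, high, interior) triple per dimension.
    padded = lax.pad(x, 0.0, [(1, 1, 0), (0, 0, 0)])               # -> (10, 8)

    xb = jnp.ones((4, 8, 8))                                       # batch of 4
    # Batched: a leading (0, 0, 0) entry covers the batch dim,
    # mirroring the zeros prepended to newLow/newHigh/newInterior above.
    padded_b = lax.pad(xb, 0.0, [(0, 0, 0), (1, 1, 0), (0, 0, 0)]) # -> (4, 10, 8)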

src/enzyme_ad/jax/Passes/AutoBatching.cpp

Lines changed: 249 additions & 5 deletions

@@ -976,8 +976,7 @@ LogicalResult GreedyWhileLoopBatchFission::matchAndRewriteImpl(
   for (auto &[op, slices] : userOpToSlicesMap) {
     bool avoidBatching =
         llvm::TypeSwitch<Operation *, bool>(op)
-            .Case<stablehlo::DynamicSliceOp, stablehlo::ReshapeOp,
-                  stablehlo::SliceOp,
+            .Case<stablehlo::ReshapeOp, stablehlo::SliceOp,
                   // TODO: avoid scatter since that lowers to loop right now
                   stablehlo::ScatterOp>([=](auto op) { return true; })
             .Case<stablehlo::BroadcastInDimOp, stablehlo::TransposeOp>(

@@ -987,9 +986,13 @@ LogicalResult GreedyWhileLoopBatchFission::matchAndRewriteImpl(
       continue;
     }
 
-    if ((dyn_cast<BatchOpInterface>(op) ||
-         stablehlo::hasTraitElementwise(op)) &&
-        op->getNumResults() == 1) {
+    if (auto dsOp = dyn_cast<stablehlo::DynamicSliceOp>(op)) {
+      if (raiseDynamicSliceToGather(rewriter, whileOp, slices, dsOp, info)) {
+        anyOpRewritten = true;
+      }
+    } else if ((dyn_cast<BatchOpInterface>(op) ||
+                stablehlo::hasTraitElementwise(op)) &&
+               op->getNumResults() == 1) {
       if (liftOperationByBatching(rewriter, whileOp, slices, op, info)) {
         anyOpRewritten = true;
       } else if (liftReduceLikeOperation(rewriter, whileOp, slices, op, info)) {

@@ -1440,6 +1443,247 @@ bool liftReduceLikeOperation(
   return true;
 }
 
+bool raiseDynamicSliceToGather(
+    PatternRewriter &rewriter, stablehlo::WhileOp whileOp,
+    ArrayRef<SliceInfo<stablehlo::DynamicSliceOp>> slices,
+    stablehlo::DynamicSliceOp dsOp, WhileLoopInfo info) {
+  // Pattern: x[..., y[idx], z[idx], ...] where idx is an affine function of
+  // the loop induction variable. We need to:
+  // 1. Hoist the computation of y[idx], z[idx], etc. for all loop iterations
+  // 2. Create a gather operation from x using those indices
+  // 3. Replace dsOp uses with a dynamic slice into the gather result
+
+  // Find all start indices that are dependent on the loop and come from inner
+  // dynamic slices (through possible reshape)
+  SmallVector<int64_t> dependentDims;
+  SmallVector<Value> innerSliceOperands;
+  SmallVector<SliceInfo<stablehlo::DynamicSliceOp>> innerSliceInfos;
+
+  for (auto [i, startIndex] : llvm::enumerate(dsOp.getStartIndices())) {
+    // Use traverseOperandsForHoisting to classify this operand
+    SmallVector<BatchLiftingMode> modes;
+    SmallVector<Value> operands;
+    SmallVector<SmallVector<int64_t>> dims;
+    SmallVector<int64_t> hoisted;
+    SmallVector<SliceInfo<stablehlo::DynamicSliceOp>> mapped;
+    DenseMap<Value, SmallVector<Operation *>> hoistMap;
+
+    SmallVector<Value> singleOperand = {startIndex};
+    if (!traverseOperandsForHoisting(singleOperand, whileOp, slices, info,
+                                     modes, operands, dims, hoisted, mapped,
+                                     hoistMap)) {
+      return false;
+    }
+
+    if (modes[0] == BatchLiftingMode::DYNAMIC_SLICE) {
+      dependentDims.push_back(i);
+      innerSliceOperands.push_back(operands[0]);
+      innerSliceInfos.push_back(mapped[0]);
+    }
+  }
+
+  if (dependentDims.empty()) {
+    return false;
+  }
+
+  // Get outer operand - it must be constant across iterations (loop invariant)
+  Value outerOperand;
+  Value dsOperand = dsOp.getOperand();
+  SmallVector<Operation *> canBeHoisted;
+  if (!info.isConstantAcrossIterations(dsOperand, outerOperand, canBeHoisted,
+                                       true)) {
+    return false;
+  }
+  if (!outerOperand) {
+    // The operand is defined inside the loop but is hoistable - hoist it
+    DenseMap<Value, SmallVector<Operation *>> hoistMap;
+    hoistMap[dsOperand] = canBeHoisted;
+    DenseMap<Value, Value> hoistedValues;
+    hoistChainOfOps(hoistMap, rewriter, whileOp, info, hoistedValues);
+    outerOperand = hoistedValues[dsOperand];
+  }
+
+  // Verify all non-dependent start indices are constant across iterations
+  for (auto [i, startIndex] : llvm::enumerate(dsOp.getStartIndices())) {
+    if (llvm::is_contained(dependentDims, i)) {
+      continue;
+    }
+    if (!info.isConstantAcrossIterations(startIndex, true)) {
+      return false;
+    }
+  }
+
+  int64_t numIters = info.getConstantNumIters();
+  Location loc = dsOp.getLoc();
+
+  rewriter.setInsertionPoint(whileOp);
+
+  // Step 1: Hoist each inner slice operand and construct the gather indices
+  // We need to gather all indices for all loop iterations and concatenate them.
+  SmallVector<Value> hoistedIndicesList;
+  Type hoistedIndicesElemTy;
+
+  for (size_t idx = 0; idx < dependentDims.size(); idx++) {
+    Value hoistedIndices;
+    if (!info.hoistOperationFromLoop(
+            rewriter, innerSliceOperands[idx], innerSliceInfos[idx].sliceOp,
+            innerSliceInfos[idx].dimensions, hoistedIndices)) {
+      return false;
+    }
+
+    auto hoistedTy = cast<RankedTensorType>(hoistedIndices.getType());
+    if (idx == 0) {
+      hoistedIndicesElemTy = hoistedTy.getElementType();
+    }
+
+    // Reshape to [numIters, 1] for use as gather indices
+    SmallVector<int64_t> reshapeShape = {numIters, 1};
+    auto reshapeTy = RankedTensorType::get(reshapeShape, hoistedIndicesElemTy);
+
+    // Convert type if needed
+    if (hoistedTy.getElementType() != hoistedIndicesElemTy) {
+      hoistedIndices = stablehlo::ConvertOp::create(
+          rewriter, loc,
+          RankedTensorType::get(hoistedTy.getShape(), hoistedIndicesElemTy),
+          hoistedIndices);
+    }
+
+    Value reshaped =
+        stablehlo::ReshapeOp::create(rewriter, loc, reshapeTy, hoistedIndices);
+    hoistedIndicesList.push_back(reshaped);
+  }
+
+  // Concatenate all hoisted indices along the last dimension
+  Value gatherIndices;
+  if (hoistedIndicesList.size() == 1) {
+    gatherIndices = hoistedIndicesList[0];
+  } else {
+    // Result shape: [numIters, numDependentDims]
+    SmallVector<int64_t> concatShape = {numIters,
+                                        (int64_t)dependentDims.size()};
+    auto concatTy = RankedTensorType::get(concatShape, hoistedIndicesElemTy);
+    gatherIndices = stablehlo::ConcatenateOp::create(
+        rewriter, loc, concatTy, hoistedIndicesList, /*dimension=*/1);
+  }
+
+  // Step 2: Create the gather operation from the outer operand
+  auto outerOperandTy = cast<RankedTensorType>(outerOperand.getType());
+  auto dsSliceSizes = dsOp.getSliceSizes();
+
+  // The gather slice sizes: dependent dimensions get 1, others get original
+  SmallVector<int64_t> gatherSliceSizes;
+  for (size_t i = 0; i < dsSliceSizes.size(); i++) {
+    if (llvm::is_contained(dependentDims, i)) {
+      gatherSliceSizes.push_back(1);
+    } else {
+      gatherSliceSizes.push_back(dsSliceSizes[i]);
+    }
+  }
+
+  // offsetDims: output dimensions corresponding to non-collapsed slice dims
+  // Start at 1 since batch dim is at position 0, then consecutive for each
+  // non-collapsed dimension
+  SmallVector<int64_t> offsetDims;
+  int64_t offsetDimIdx = 1; // Start after the batch dimension
+  for (size_t i = 0; i < outerOperandTy.getRank(); i++) {
+    if (!llvm::is_contained(dependentDims, i)) {
+      offsetDims.push_back(offsetDimIdx);
+      offsetDimIdx++;
+    }
+  }
+
+  // collapsedSliceDims: the dimensions we're indexing into
+  SmallVector<int64_t> collapsedSliceDims = llvm::to_vector(dependentDims);
+
+  // startIndexMap: maps index vector dimensions to operand dimensions
+  SmallVector<int64_t> startIndexMap = llvm::to_vector(dependentDims);
+
+  // Calculate output shape: [numIters, ...sliceSizes for non-dependent dims...]
+  SmallVector<int64_t> gatherOutputShape;
+  gatherOutputShape.push_back(numIters);
+  for (size_t i = 0; i < dsSliceSizes.size(); i++) {
+    if (!llvm::is_contained(dependentDims, i)) {
+      gatherOutputShape.push_back(dsSliceSizes[i]);
+    }
+  }
+
+  auto gatherResultTy =
+      RankedTensorType::get(gatherOutputShape, outerOperandTy.getElementType());
+
+  auto gatherOp = stablehlo::GatherOp::create(
+      rewriter, loc, gatherResultTy, outerOperand, gatherIndices,
+      stablehlo::GatherDimensionNumbersAttr::get(
+          rewriter.getContext(),
+          /*offsetDims=*/offsetDims,
+          /*collapsedSliceDims=*/collapsedSliceDims,
+          /*operandBatchingDims=*/{},
+          /*startIndicesBatchingDims=*/{},
+          /*startIndexMap=*/startIndexMap,
+          /*indexVectorDim=*/1),
+      gatherSliceSizes);
+
+  // Step 3: Replace the dsOp with a dynamic slice into the gather result
+  // The dynamic slice will index using the loop induction variable
+  rewriter.setInsertionPointAfter(dsOp);
+
+  auto inductionVar = info.getInductionVariable();
+  auto inductionVarType = cast<RankedTensorType>(inductionVar.getType());
+
+  // Compute the index for the dynamic slice
+  Value sliceIndex;
+  if (info.isConstantStart() && info.getConstantStart() == 0) {
+    sliceIndex = inductionVar;
+  } else {
+    sliceIndex = stablehlo::SubtractOp::create(rewriter, loc, inductionVar,
+                                               info.getStart());
+  }
+  if (!info.isStepOne()) {
+    sliceIndex = stablehlo::DivOp::create(rewriter, loc, sliceIndex,
+                                          info.getStep(rewriter));
+  }
+
+  // Convert sliceIndex to the same type as gather indices if needed
+  if (inductionVarType.getElementType() != hoistedIndicesElemTy) {
+    auto newIndexTy = RankedTensorType::get({}, hoistedIndicesElemTy);
+    sliceIndex =
+        stablehlo::ConvertOp::create(rewriter, loc, newIndexTy, sliceIndex);
+  }
+
+  // Create constZero with the same type as sliceIndex (after conversion)
+  auto sliceIndexTy = cast<RankedTensorType>(sliceIndex.getType());
+  auto constZero = stablehlo::ConstantOp::create(
+      rewriter, loc, sliceIndexTy,
+      cast<ElementsAttr>(makeAttr(sliceIndexTy, 0)));
+  // Build the start indices for dynamic slice
+  SmallVector<Value> dynSliceStarts;
+  dynSliceStarts.push_back(sliceIndex);
+  for (size_t i = 0; i < dsSliceSizes.size(); i++) {
+    if (!llvm::is_contained(dependentDims, i)) {
+      dynSliceStarts.push_back(constZero);
+    }
+  }
+
+  // Build the slice sizes (1 for the batch dim, original sizes for others)
+  SmallVector<int64_t> dynSliceSizes;
+  dynSliceSizes.push_back(1);
+  for (size_t i = 0; i < dsSliceSizes.size(); i++) {
+    if (!llvm::is_contained(dependentDims, i)) {
+      dynSliceSizes.push_back(dsSliceSizes[i]);
+    }
+  }
+
+  auto dynSlice = stablehlo::DynamicSliceOp::create(
+      rewriter, loc, gatherOp.getResult(), dynSliceStarts, dynSliceSizes);
+
+  // Reshape to match the original dsOp output type
+  auto replacement =
+      stablehlo::ReshapeOp::create(rewriter, loc, dsOp.getType(), dynSlice);
+
+  rewriter.replaceOp(dsOp, replacement.getResult());
+
+  return true;
+}
+
 bool liftOperationByBatching(
     PatternRewriter &rewriter, stablehlo::WhileOp whileOp,
     ArrayRef<SliceInfo<stablehlo::DynamicSliceOp>> slices, Operation *op,
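
Note (illustration, not part of the commit): the net effect of raiseDynamicSliceToGather is easiest to see as a before/after. A hedged JAX sketch, with invented names, shapes, and a trivial Python loop standing in for the while loop, of the pattern x[y[i], :] computed each iteration:

    import jax.numpy as jnp
    from jax import lax

    x = jnp.arange(50.0).reshape(5, 10)   # loop-invariant operand
    y = jnp.array([3, 0, 4, 1])           # per-iteration row indices y[i]
    num_iters = 4

    # Before: each iteration performs a dynamic slice whose start index
    # is itself loaded from a loop-indexed array.
    def body_before(i):
        return lax.dynamic_slice(x, (y[i], 0), (1, 10))

    # After: one gather hoisted above the loop collects all rows ...
    gathered = x[y, :]                     # shape (num_iters, 10)

    # ... and each iteration just slices row i of the gathered result.
    def body_after(i):
        return lax.dynamic_slice(gathered, (i, 0), (1, 10))

    assert all(
        jnp.array_equal(body_before(i), body_after(i)) for i in range(num_iters)
    )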

src/enzyme_ad/jax/Passes/AutoBatching.h

Lines changed: 5 additions & 0 deletions

@@ -144,6 +144,11 @@ struct SliceToBatchElementwise : public SliceToBatchBase {
       : SliceToBatchBase(CheckElementwise, ctx, benefit) {}
 };
 
+bool raiseDynamicSliceToGather(
+    mlir::PatternRewriter &rewriter, mlir::stablehlo::WhileOp whileOp,
+    llvm::ArrayRef<SliceInfo<mlir::stablehlo::DynamicSliceOp>> slices,
+    mlir::stablehlo::DynamicSliceOp dsOp, mlir::enzyme::WhileLoopInfo info);
+
 bool liftOperationByBatching(
     mlir::PatternRewriter &rewriter, mlir::stablehlo::WhileOp whileOp,
     llvm::ArrayRef<SliceInfo<mlir::stablehlo::DynamicSliceOp>> slices,

src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp

Lines changed: 15 additions & 1 deletion

@@ -1316,6 +1316,9 @@ struct LowerWrap
 
 stablehlo::ConcatenateOp lowerExtend(enzymexla::ExtendOp extend,
                                      PatternRewriter &rewriter, bool replace) {
+  OpBuilder::InsertionGuard guard(rewriter);
+  rewriter.setInsertionPoint(extend);
+
   auto loc = extend.getLoc();
   auto operand = extend.getOperand();
 

@@ -1376,6 +1379,9 @@ struct LowerExtend
 
 stablehlo::ConcatenateOp lowerRotate(enzymexla::RotateOp rotate,
                                      PatternRewriter &rewriter, bool replace) {
+  OpBuilder::InsertionGuard guard(rewriter);
+  rewriter.setInsertionPoint(rotate);
+
   // sl0[A:end], sl1[0:A]
   auto shard = sdy::getShardingPerValue(rotate);
   SmallVector<int64_t> strides(rotate.getType().getShape().size(), 1);

@@ -28963,6 +28969,12 @@ LogicalResult DUSDSSimplifyWithSomeUpdateOverlapHelper(
         dusStartIsConstant) {
       int64_t dsStartInt = dsStartAP.getSExtValue();
       int64_t dusStartInt = dusStartAP.getSExtValue();
+
+      if (dusStartInt + updateShape[i] <= dsStartInt) { // no overlap
+        canReplace = false;
+        break;
+      }
+
       sliceStarts[i] = dsStartInt - dusStartInt;
       continue;
     }

@@ -29039,6 +29051,7 @@ LogicalResult DUSDSSimplifyWithSomeUpdateOverlapHelper(
     }
   }
 
+  rewriter.setInsertionPoint(dsOp);
   auto newDS = stablehlo::DynamicSliceOpCreate(
       rewriter, dsOp.getLoc(), dusOp.getUpdate(), dynamicSliceStarts,
       dsOp.getSliceSizes());

@@ -29051,7 +29064,7 @@ LogicalResult DUSDSSimplifyWithSomeUpdateOverlapHelper(
 
   // simple case
   if (allOffsetsZero && updateShape == dsSliceSizes) {
-    rewriter.replaceAllUsesWith(dsOp.getResult(), dusOp.getUpdate());
+    rewriter.replaceOp(dsOp, dusOp.getUpdate());
     return success();
   }
 

@@ -29081,6 +29094,7 @@ LogicalResult DUSDSSimplifyWithSomeUpdateOverlapHelper(
     return failure();
   }
 
+  rewriter.setInsertionPoint(dsOp);
   Value result =
       stablehlo::SliceOp::create(rewriter, dsOp.getLoc(), dusOp.getUpdate(),
                                  sliceStarts, sliceLimits, sliceStrides);
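
Note (illustration, not part of the commit): the new no-overlap guard in DUSDSSimplifyWithSomeUpdateOverlapHelper is easiest to see numerically. A hedged JAX sketch with invented values of the case it now rejects: when the dynamic-slice window starts at or past the end of the update (dusStart + updateShape <= dsStart), the slice reads no updated data, so rewriting it to read from the update would be wrong.

    import jax.numpy as jnp
    from jax import lax

    operand = jnp.zeros(10)
    update = jnp.ones(3)

    # dynamic_update_slice writes `update` into `operand` at dusStart = 2,
    # covering indices [2, 5).
    dus = lax.dynamic_update_slice(operand, update, (2,))

    # A dynamic_slice starting at dsStart = 5 (2 + 3 <= 5: no overlap)
    # sees none of the update, so it cannot be rewritten as a slice of
    # `update`; the guard sets canReplace = false for this case.
    ds = lax.dynamic_slice(dus, (5,), (3,))
    assert jnp.array_equal(ds, jnp.zeros(3))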
