Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
278 changes: 277 additions & 1 deletion src/enzyme_ad/jax/Passes/AutoBatching.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1576,6 +1576,277 @@ mlir::LogicalResult WhileElementwiseReductionToReduce::matchAndRewriteImpl(
return success(anyRewritten);
}

LogicalResult
RaiseScanLikeOperations::matchAndRewriteImpl(stablehlo::WhileOp whileOp,
PatternRewriter &rewriter) const {
// %iterArg = %outVal [initial value]
// %iterArg[..., i, ...] = op(%iterArg[..., i - 1, ...], %outVal[..., i, ...])
// supported ops: mul, add, min, max, ....

auto &body = whileOp.getBody().front();
auto returnOp = dyn_cast<stablehlo::ReturnOp>(body.getTerminator());
if (!returnOp) {
return failure();
}

WhileLoopInfo info(whileOp);
auto computedInfo = info.computeInfo();
(void)computedInfo;
if (!info.isValid() || !info.isConstant() ||
info.getConstantNumIters() <= 0) {
return failure();
}

auto step = info.getConstantStep().value();
if (step != 1) {
return failure();
}

auto affineIndexInfo = info.getAffineIndexInfo();
bool anyRewritten = false;

for (size_t argIdx = 0; argIdx < whileOp.getNumOperands(); argIdx++) {
auto iterArg = body.getArgument(argIdx);
auto returnVal = returnOp->getOperand(argIdx);

// Look for a DynamicUpdateSlice that updates this loop-carried value
auto dusOp = returnVal.getDefiningOp<stablehlo::DynamicUpdateSliceOp>();
if (!dusOp) {
continue;
}

// The base operand must be the loop-carried argument
if (dusOp.getOperand() != iterArg) {
continue;
}

// Check if the while result is used
if (whileOp->getResult(argIdx).getUses().empty()) {
continue;
}

// Get the update value and check if it comes from a reducible op
auto updateVal = dusOp.getUpdate();
auto updateOp = updateVal.getDefiningOp();
if (!updateOp || !stablehlo::canFuseIntoReduce(updateOp)) {
continue;
}

// Check if one operand is a DynamicSlice of the loop variable at
// [i - offset]
auto lhs = updateOp->getOperand(0);
auto rhs = updateOp->getOperand(1);

auto lhsSlice = lhs.getDefiningOp<stablehlo::DynamicSliceOp>();
auto rhsSlice = rhs.getDefiningOp<stablehlo::DynamicSliceOp>();

if (!lhsSlice || !rhsSlice) {
continue;
}

// One must slice iterArg (the previous value) and the other must slice
// an external input (or be constant across iterations)
stablehlo::DynamicSliceOp prevSlice = nullptr, inputSlice = nullptr;
bool prevIsLhs = false;

if (lhsSlice.getOperand() == iterArg) {
prevSlice = lhsSlice;
prevIsLhs = true;
// Check if rhs slices something that is constant across iterations
if (!info.isConstantAcrossIterations(rhsSlice.getOperand(), false)) {
continue;
}
inputSlice = rhsSlice;
} else if (rhsSlice.getOperand() == iterArg) {
prevSlice = rhsSlice;
prevIsLhs = false;
if (!info.isConstantAcrossIterations(lhsSlice.getOperand(), false)) {
continue;
}
inputSlice = lhsSlice;
} else {
continue;
}

if (!prevSlice || !inputSlice) {
continue;
}

// Find the dimension that varies (the scan dimension)
// The DUS start indices and the previous slice indices should differ by 1
auto dusStarts = dusOp.getStartIndices();
auto prevStarts = prevSlice.getStartIndices();
auto inputStarts = inputSlice.getStartIndices();

if (dusStarts.size() != prevStarts.size() ||
dusStarts.size() != inputStarts.size()) {
continue;
}

// Find the scan dimension: the dimension where prevSlice is at offset -step
// relative to DUS
int64_t scanDim = -1;
bool validScanPattern = true, needsReverse = false;

for (size_t d = 0; d < dusStarts.size(); d++) {
auto dusStart = dusStarts[d];
auto prevStart = prevStarts[d];
auto inputStart = inputStarts[d];

// Check if this dimension has the scan pattern
// DUS writes at position i, prevSlice reads at position i-step
// inputSlice reads at position i
// So: dusStart == inputStart, prevStart == dusStart - step

// For non-scan dimensions, all three indices must match
if (dusStart == prevStart && dusStart == inputStart) {
// Non-scan dimension - all indices match, continue to next dim
continue;
}

// For the scan dimension: dusStart == inputStart, prevStart differs
if (inputStart != dusStart) { // inputStart must always match dusStart
validScanPattern = false;
break;
}

// At this point: dusStart == inputStart, prevStart != dusStart
// This is the scan dimension candidate - check the offset relationship
bool foundOffset = false;
if (affineIndexInfo.contains(dusStart) &&
affineIndexInfo.contains(prevStart)) {
auto dusInfo = affineIndexInfo[dusStart];
auto prevInfo = affineIndexInfo[prevStart];

auto dusInfoScale = dusInfo.scale.getSExtValue();
auto prevInfoScale = prevInfo.scale.getSExtValue();

if (prevInfoScale != dusInfoScale || std::abs(dusInfoScale) != 1) {
validScanPattern = false;
break;
}

if (dusInfoScale < 0) {
needsReverse = true;
}

auto offsetDiff =
dusInfo.offset.getSExtValue() - prevInfo.offset.getSExtValue();
if (offsetDiff == dusInfoScale) { // 1 or -1
if (scanDim != -1) {
// Multiple scan dimensions not supported
validScanPattern = false;
break;
}
scanDim = d;
foundOffset = true;
}
}

if (!foundOffset) {
// Could not verify the offset relationship
validScanPattern = false;
break;
}
}

if (!validScanPattern || scanDim == -1) {
continue;
}

if (dusOp.getUpdate().getType().getDimSize(scanDim) != 1) {
continue;
}

// Check that the scale divides the dimension size evenly
auto whileOperand = whileOp->getOperand(argIdx);
auto operandType = cast<RankedTensorType>(whileOperand.getType());

auto elemType = operandType.getElementType();
auto inputShape = operandType.getShape();
int64_t rank = operandType.getRank();

// Create the ReduceWindow operation
rewriter.setInsertionPoint(whileOp);

// Window dimensions: size 2 in scan dimension, 1 in others
SmallVector<int64_t> windowDims(rank, 1), windowStrides(rank, 1),
baseDilations(rank, 1), windowDilations(rank, 1);
windowDims[scanDim] = inputShape[scanDim]; // cumulative window

// Padding: [scanDim-1, 0] to make cumulative sum
SmallVector<int64_t> paddingFlat(2 * rank, 0);
paddingFlat[2 * scanDim] = inputShape[scanDim] - 1;

int64_t paddingShape[2] = {rank, 2};

// Get identity value for the operation
Value initVal = stablehlo::getIdentityValue(rewriter, whileOp->getLoc(),
elemType, updateOp);
if (!initVal) {
continue;
}

// TODO: if users inside the loop replace with a dymamic slice.

// TODO: correctly handle for partial indexing

if (needsReverse) {
whileOperand = stablehlo::ReverseOp::create(rewriter, whileOp->getLoc(),
whileOperand, {scanDim});
}

auto reduceWindowOp = stablehlo::ReduceWindowOp::create(
rewriter, whileOp->getLoc(), TypeRange{whileOperand.getType()},
ValueRange{whileOperand}, ValueRange{initVal},
rewriter.getDenseI64ArrayAttr(windowDims),
rewriter.getDenseI64ArrayAttr(windowStrides),
rewriter.getDenseI64ArrayAttr(baseDilations),
rewriter.getDenseI64ArrayAttr(windowDilations),
DenseIntElementsAttr::get(
RankedTensorType::get(paddingShape, rewriter.getIntegerType(64)),
paddingFlat));

// Create the reduce body
auto scalarType = RankedTensorType::get({}, elemType);
Block *block = rewriter.createBlock(&reduceWindowOp.getBody());
block->addArgument(scalarType, whileOp->getLoc());
block->addArgument(scalarType, whileOp->getLoc());

{
IRRewriter::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(block);
OperationState state(updateOp->getLoc(), updateOp->getName());
state.addTypes(TypeRange{scalarType});
if (prevIsLhs) {
state.addOperands(
ValueRange{block->getArgument(0), block->getArgument(1)});
} else {
state.addOperands(
ValueRange{block->getArgument(1), block->getArgument(0)});
}
auto *newOp = mlir::Operation::create(state);
rewriter.insert(newOp);
stablehlo::ReturnOp::create(rewriter, updateOp->getLoc(),
newOp->getResults());
}

Value scanResult = reduceWindowOp->getResult(0);

if (needsReverse) {
rewriter.setInsertionPointAfter(reduceWindowOp);
scanResult = stablehlo::ReverseOp::create(rewriter, whileOp->getLoc(),
scanResult, {scanDim});
}

// Replace uses of the while result with the reduce_window result
rewriter.replaceAllUsesWith(whileOp->getResult(argIdx), scanResult);
anyRewritten = true;
}

return success(anyRewritten);
}

mlir::LogicalResult
RemoveLoopCarriedDependenciesFromWhileLoadOperations::matchAndRewriteImpl(
stablehlo::WhileOp whileOp, PatternRewriter &rewriter) const {
Expand Down Expand Up @@ -2101,6 +2372,10 @@ void populateAutoBatchingPassPatterns(RewritePatternSet &patterns,
if (options.enableRemoveLoopCarriedDependenciesFromWhileLoadOperations) {
patterns.add<RemoveLoopCarriedDependenciesFromWhileLoadOperations>(ctx);
}

if (options.enableWhileRaiseScanLikeOperations) {
patterns.add<RaiseScanLikeOperations>(ctx);
}
}

} // namespace enzyme
Expand All @@ -2120,7 +2395,8 @@ struct AutoBatchingPass
while_loop_batching_mode,
while_elementwise_reduction_to_reduce_passes,
while_is_copy_simplify_passes,
while_remove_loop_carried_dependencies_from_load_operations};
while_remove_loop_carried_dependencies_from_load_operations,
while_raise_scan_like_operations};
mlir::enzyme::populateAutoBatchingPassPatterns(patterns, context, options);

GreedyRewriteConfig config;
Expand Down
14 changes: 14 additions & 0 deletions src/enzyme_ad/jax/Passes/AutoBatching.h
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,19 @@ struct WhileIsCopySimplify
mlir::enzyme::WhileLoopInfo &info) const;
};

struct RaiseScanLikeOperations
: public mlir::enzyme::CheckedOpRewritePattern<mlir::stablehlo::WhileOp,
RaiseScanLikeOperations> {
using Base = mlir::enzyme::CheckedOpRewritePattern<mlir::stablehlo::WhileOp,
RaiseScanLikeOperations>;

using Base::Base;

mlir::LogicalResult
matchAndRewriteImpl(mlir::stablehlo::WhileOp whileOp,
mlir::PatternRewriter &rewriter) const;
};

namespace mlir {
namespace enzyme {

Expand All @@ -294,6 +307,7 @@ struct AutoBatchingPassPipelineOptions {
bool enableWhileElementwiseReductionToReduce;
bool enableWhileIsCopySimplify;
bool enableRemoveLoopCarriedDependenciesFromWhileLoadOperations;
bool enableWhileRaiseScanLikeOperations;
};

void populateAutoBatchingPassPatterns(RewritePatternSet &patterns,
Expand Down
2 changes: 1 addition & 1 deletion src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27925,7 +27925,7 @@ struct EnzymeHLOOptPass

if (enable_auto_batching_passes) {
mlir::enzyme::AutoBatchingPassPipelineOptions options{
true, true, "greedy", true, true, true};
true, true, "greedy", true, true, true, true};
mlir::enzyme::populateAutoBatchingPassPatterns(patterns, context,
options);
}
Expand Down
6 changes: 6 additions & 0 deletions src/enzyme_ad/jax/Passes/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -1081,6 +1081,12 @@ def AutoBatchingPass : Pass<"auto-batching"> {
/*type=*/"bool",
/*default=*/"true",
/*description=*/"remove loop carried deps from load operations">,
Option<
/*C++ variable name=*/"while_raise_scan_like_operations",
/*CLI argument=*/"while_raise_scan_like_operations",
/*type=*/"bool",
/*default=*/"true",
/*description=*/"remove loop carried deps from load operations">,
Option<
/*C++ variable name=*/"max_iterations",
/*CLI argument=*/"max_iterations",
Expand Down
5 changes: 5 additions & 0 deletions src/enzyme_ad/jax/TransformOps/TransformOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -2542,6 +2542,11 @@ def ApplyRemoveLoopCarriedDependenciesFromWhileLoadOperationsPatterns : EnzymeHL
let patterns = ["RemoveLoopCarriedDependenciesFromWhileLoadOperations"];
}

def ApplyRaiseScanLikeOperationsPatterns : EnzymeHLOPatternOp<
"raise_scan_like_operations"> {
let patterns = ["RaiseScanLikeOperations"];
}

def EnzymeHLOUnroll : EnzymeHLOParameterizedPatternOp<
"enzyme_hlo_unroll"> {
let arguments = (ins OptionalAttr<I64Attr>:$benefit, I64Attr:$parameter);
Expand Down
1 change: 1 addition & 0 deletions src/enzyme_ad/jax/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,6 +428,7 @@ def optimization_passes(
"greedy_while_loop_batch_fission",
"while_elementwise_reduction_to_reduce",
"remove_loop_carried_dependencies_from_while_load_operations",
"raise_scan_like_operations",
]

if enable_licm_optimization_passes:
Expand Down
Loading
Loading