diff --git a/src/enzyme_ad/jax/Passes/AutoBatching.cpp b/src/enzyme_ad/jax/Passes/AutoBatching.cpp
index 95555bf2d0..82fba83f84 100644
--- a/src/enzyme_ad/jax/Passes/AutoBatching.cpp
+++ b/src/enzyme_ad/jax/Passes/AutoBatching.cpp
@@ -1576,6 +1576,277 @@ mlir::LogicalResult WhileElementwiseReductionToReduce::matchAndRewriteImpl(
   return success(anyRewritten);
 }
 
+LogicalResult
+RaiseScanLikeOperations::matchAndRewriteImpl(stablehlo::WhileOp whileOp,
+                                             PatternRewriter &rewriter) const {
+  // %iterArg = %outVal [initial value]
+  // %iterArg[..., i, ...] = op(%iterArg[..., i - 1, ...], %outVal[..., i, ...])
+  // supported ops: mul, add, min, max, ...
+
+  auto &body = whileOp.getBody().front();
+  auto returnOp = dyn_cast<stablehlo::ReturnOp>(body.getTerminator());
+  if (!returnOp) {
+    return failure();
+  }
+
+  WhileLoopInfo info(whileOp);
+  auto computedInfo = info.computeInfo();
+  (void)computedInfo;
+  if (!info.isValid() || !info.isConstant() ||
+      info.getConstantNumIters() <= 0) {
+    return failure();
+  }
+
+  auto step = info.getConstantStep().value();
+  if (step != 1) {
+    return failure();
+  }
+
+  auto affineIndexInfo = info.getAffineIndexInfo();
+  bool anyRewritten = false;
+
+  for (size_t argIdx = 0; argIdx < whileOp.getNumOperands(); argIdx++) {
+    auto iterArg = body.getArgument(argIdx);
+    auto returnVal = returnOp->getOperand(argIdx);
+
+    // Look for a DynamicUpdateSlice that updates this loop-carried value.
+    auto dusOp = returnVal.getDefiningOp<stablehlo::DynamicUpdateSliceOp>();
+    if (!dusOp) {
+      continue;
+    }
+
+    // The base operand must be the loop-carried argument.
+    if (dusOp.getOperand() != iterArg) {
+      continue;
+    }
+
+    // Skip if the while result is never used.
+    if (whileOp->getResult(argIdx).getUses().empty()) {
+      continue;
+    }
+
+    // Get the update value and check whether it comes from a reducible op.
+    auto updateVal = dusOp.getUpdate();
+    auto updateOp = updateVal.getDefiningOp();
+    if (!updateOp || !stablehlo::canFuseIntoReduce(updateOp)) {
+      continue;
+    }
+
+    // Check whether one operand is a DynamicSlice of the loop variable at
+    // [i - offset].
+    auto lhs = updateOp->getOperand(0);
+    auto rhs = updateOp->getOperand(1);
+
+    auto lhsSlice = lhs.getDefiningOp<stablehlo::DynamicSliceOp>();
+    auto rhsSlice = rhs.getDefiningOp<stablehlo::DynamicSliceOp>();
+
+    if (!lhsSlice || !rhsSlice) {
+      continue;
+    }
+
+    // One must slice iterArg (the previous value) and the other must slice
+    // an external input (or be constant across iterations).
+    stablehlo::DynamicSliceOp prevSlice = nullptr, inputSlice = nullptr;
+    bool prevIsLhs = false;
+
+    if (lhsSlice.getOperand() == iterArg) {
+      prevSlice = lhsSlice;
+      prevIsLhs = true;
+      // Check that rhs slices something that is constant across iterations.
+      if (!info.isConstantAcrossIterations(rhsSlice.getOperand(), false)) {
+        continue;
+      }
+      inputSlice = rhsSlice;
+    } else if (rhsSlice.getOperand() == iterArg) {
+      prevSlice = rhsSlice;
+      prevIsLhs = false;
+      if (!info.isConstantAcrossIterations(lhsSlice.getOperand(), false)) {
+        continue;
+      }
+      inputSlice = lhsSlice;
+    } else {
+      continue;
+    }
+
+    if (!prevSlice || !inputSlice) {
+      continue;
+    }
+
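+    // At this point the body matches a scan-like recurrence. Schematically
+    // (an illustrative MLIR sketch; value names are invented for exposition):
+    //   %prev = stablehlo.dynamic_slice %iterArg, %i_minus_1, sizes = [1]
+    //   %in   = stablehlo.dynamic_slice %input, %i, sizes = [1]
+    //   %upd  = stablehlo.add %prev, %in : tensor<1xf32>
+    //   %next = stablehlo.dynamic_update_slice %iterArg, %upd, %i
+    // i.e. out[i] = op(out[i - 1], in[i]).
+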
+    // Find the dimension that varies (the scan dimension).
+    // The DUS start indices and the previous slice indices should differ by 1.
+    auto dusStarts = dusOp.getStartIndices();
+    auto prevStarts = prevSlice.getStartIndices();
+    auto inputStarts = inputSlice.getStartIndices();
+
+    if (dusStarts.size() != prevStarts.size() ||
+        dusStarts.size() != inputStarts.size()) {
+      continue;
+    }
+
+    // Find the scan dimension: the dimension where prevSlice is at offset
+    // -step relative to the DUS.
+    int64_t scanDim = -1;
+    bool validScanPattern = true, needsReverse = false;
+
+    for (size_t d = 0; d < dusStarts.size(); d++) {
+      auto dusStart = dusStarts[d];
+      auto prevStart = prevStarts[d];
+      auto inputStart = inputStarts[d];
+
+      // Check whether this dimension has the scan pattern:
+      // the DUS writes at position i, prevSlice reads at position i - step,
+      // and inputSlice reads at position i.
+      // So: dusStart == inputStart and prevStart == dusStart - step.
+
+      // For non-scan dimensions, all three indices must match.
+      if (dusStart == prevStart && dusStart == inputStart) {
+        // Non-scan dimension - all indices match, continue to the next dim.
+        continue;
+      }
+
+      // For the scan dimension: dusStart == inputStart, prevStart differs.
+      if (inputStart != dusStart) { // inputStart must always match dusStart
+        validScanPattern = false;
+        break;
+      }
+
+      // At this point: dusStart == inputStart, prevStart != dusStart.
+      // This is the scan dimension candidate - check the offset relationship.
+      bool foundOffset = false;
+      if (affineIndexInfo.contains(dusStart) &&
+          affineIndexInfo.contains(prevStart)) {
+        auto dusInfo = affineIndexInfo[dusStart];
+        auto prevInfo = affineIndexInfo[prevStart];
+
+        auto dusInfoScale = dusInfo.scale.getSExtValue();
+        auto prevInfoScale = prevInfo.scale.getSExtValue();
+
+        if (prevInfoScale != dusInfoScale || std::abs(dusInfoScale) != 1) {
+          validScanPattern = false;
+          break;
+        }
+
+        if (dusInfoScale < 0) {
+          needsReverse = true;
+        }
+
+        auto offsetDiff =
+            dusInfo.offset.getSExtValue() - prevInfo.offset.getSExtValue();
+        if (offsetDiff == dusInfoScale) { // 1 or -1
+          if (scanDim != -1) {
+            // Multiple scan dimensions are not supported.
+            validScanPattern = false;
+            break;
+          }
+          scanDim = d;
+          foundOffset = true;
+        }
+      }
+
+      if (!foundOffset) {
+        // Could not verify the offset relationship.
+        validScanPattern = false;
+        break;
+      }
+    }
+
+    if (!validScanPattern || scanDim == -1) {
+      continue;
+    }
+
+    if (dusOp.getUpdate().getType().getDimSize(scanDim) != 1) {
+      continue;
+    }
+
+    // Gather the type information for the loop-carried operand.
+    auto whileOperand = whileOp->getOperand(argIdx);
+    auto operandType = cast<RankedTensorType>(whileOperand.getType());
+
+    auto elemType = operandType.getElementType();
+    auto inputShape = operandType.getShape();
+    int64_t rank = operandType.getRank();
+
+    // Create the ReduceWindow operation.
+    rewriter.setInsertionPoint(whileOp);
+
+    // Window dimensions: the full extent in the scan dimension (a cumulative
+    // window), 1 in all others.
+    SmallVector<int64_t> windowDims(rank, 1), windowStrides(rank, 1),
+        baseDilations(rank, 1), windowDilations(rank, 1);
+    windowDims[scanDim] = inputShape[scanDim]; // cumulative window
+
+    // Padding: [dimSize - 1, 0] in the scan dimension to make the reduction
+    // cumulative.
+    SmallVector<int64_t> paddingFlat(2 * rank, 0);
+    paddingFlat[2 * scanDim] = inputShape[scanDim] - 1;
+
+    int64_t paddingShape[2] = {rank, 2};
+
+    // Get the identity value for the operation.
+    Value initVal = stablehlo::getIdentityValue(rewriter, whileOp->getLoc(),
+                                                elemType, updateOp);
+    if (!initVal) {
+      continue;
+    }
+
+    // TODO: if there are users inside the loop, replace them with a dynamic
+    // slice.
+
+    // TODO: correctly handle partial indexing.
+
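+    // Worked example (assuming a 1-D cumulative sum over tensor<128xf32>):
+    // windowDims = [128] and paddingFlat = [127, 0], so window i spans input
+    // positions [i - 127, i]; positions below 0 read the identity padding,
+    // which yields out[i] = reduce(in[0], ..., in[i]).
+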
+    if (needsReverse) {
+      whileOperand = stablehlo::ReverseOp::create(rewriter, whileOp->getLoc(),
+                                                  whileOperand, {scanDim});
+    }
+
+    auto reduceWindowOp = stablehlo::ReduceWindowOp::create(
+        rewriter, whileOp->getLoc(), TypeRange{whileOperand.getType()},
+        ValueRange{whileOperand}, ValueRange{initVal},
+        rewriter.getDenseI64ArrayAttr(windowDims),
+        rewriter.getDenseI64ArrayAttr(windowStrides),
+        rewriter.getDenseI64ArrayAttr(baseDilations),
+        rewriter.getDenseI64ArrayAttr(windowDilations),
+        DenseIntElementsAttr::get(
+            RankedTensorType::get(paddingShape, rewriter.getIntegerType(64)),
+            paddingFlat));
+
+    // Create the reduce body.
+    auto scalarType = RankedTensorType::get({}, elemType);
+    Block *block = rewriter.createBlock(&reduceWindowOp.getBody());
+    block->addArgument(scalarType, whileOp->getLoc());
+    block->addArgument(scalarType, whileOp->getLoc());
+
+    {
+      IRRewriter::InsertionGuard guard(rewriter);
+      rewriter.setInsertionPointToStart(block);
+      OperationState state(updateOp->getLoc(), updateOp->getName());
+      state.addTypes(TypeRange{scalarType});
+      if (prevIsLhs) {
+        state.addOperands(
+            ValueRange{block->getArgument(0), block->getArgument(1)});
+      } else {
+        state.addOperands(
+            ValueRange{block->getArgument(1), block->getArgument(0)});
+      }
+      auto *newOp = mlir::Operation::create(state);
+      rewriter.insert(newOp);
+      stablehlo::ReturnOp::create(rewriter, updateOp->getLoc(),
+                                  newOp->getResults());
+    }
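+
+    // The region created above re-emits the matched op on scalars, keeping
+    // the original operand order via prevIsLhs; for an add scan the body is
+    // roughly (illustrative):
+    //   ^bb0(%acc: tensor<f32>, %cur: tensor<f32>):
+    //     %0 = stablehlo.add %acc, %cur : tensor<f32>
+    //     stablehlo.return %0 : tensor<f32>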
+
+    Value scanResult = reduceWindowOp->getResult(0);
+
+    if (needsReverse) {
+      rewriter.setInsertionPointAfter(reduceWindowOp);
+      scanResult = stablehlo::ReverseOp::create(rewriter, whileOp->getLoc(),
+                                                scanResult, {scanDim});
+    }
+
+    // Replace uses of the while result with the reduce_window result.
+    rewriter.replaceAllUsesWith(whileOp->getResult(argIdx), scanResult);
+    anyRewritten = true;
+  }
+
+  return success(anyRewritten);
+}
+
 mlir::LogicalResult
 RemoveLoopCarriedDependenciesFromWhileLoadOperations::matchAndRewriteImpl(
     stablehlo::WhileOp whileOp, PatternRewriter &rewriter) const {
@@ -2101,6 +2372,10 @@ void populateAutoBatchingPassPatterns(RewritePatternSet &patterns,
   if (options.enableRemoveLoopCarriedDependenciesFromWhileLoadOperations) {
     patterns.add<RemoveLoopCarriedDependenciesFromWhileLoadOperations>(ctx);
   }
+
+  if (options.enableWhileRaiseScanLikeOperations) {
+    patterns.add<RaiseScanLikeOperations>(ctx);
+  }
 }
 
 } // namespace enzyme
@@ -2120,7 +2395,8 @@ struct AutoBatchingPass
         while_loop_batching_mode,
         while_elementwise_reduction_to_reduce_passes,
         while_is_copy_simplify_passes,
-        while_remove_loop_carried_dependencies_from_load_operations};
+        while_remove_loop_carried_dependencies_from_load_operations,
+        while_raise_scan_like_operations};
     mlir::enzyme::populateAutoBatchingPassPatterns(patterns, context,
                                                    options);
     GreedyRewriteConfig config;
diff --git a/src/enzyme_ad/jax/Passes/AutoBatching.h b/src/enzyme_ad/jax/Passes/AutoBatching.h
index 79776b5cee..103d2ea3b2 100644
--- a/src/enzyme_ad/jax/Passes/AutoBatching.h
+++ b/src/enzyme_ad/jax/Passes/AutoBatching.h
@@ -284,6 +284,19 @@ struct WhileIsCopySimplify
                       mlir::enzyme::WhileLoopInfo &info) const;
 };
 
+struct RaiseScanLikeOperations
+    : public mlir::enzyme::CheckedOpRewritePattern<mlir::stablehlo::WhileOp,
+                                                   RaiseScanLikeOperations> {
+  using Base = mlir::enzyme::CheckedOpRewritePattern<mlir::stablehlo::WhileOp,
+                                                     RaiseScanLikeOperations>;
+
+  using Base::Base;
+
+  mlir::LogicalResult
+  matchAndRewriteImpl(mlir::stablehlo::WhileOp whileOp,
+                      mlir::PatternRewriter &rewriter) const;
+};
+
 namespace mlir {
 namespace enzyme {
 
@@ -294,6 +307,7 @@ struct AutoBatchingPassPipelineOptions {
   bool enableWhileElementwiseReductionToReduce;
   bool enableWhileIsCopySimplify;
   bool enableRemoveLoopCarriedDependenciesFromWhileLoadOperations;
+  bool enableWhileRaiseScanLikeOperations;
 };
 
 void populateAutoBatchingPassPatterns(RewritePatternSet &patterns,
diff --git a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp
index b4fedb7078..081a36d4b9 100644
--- a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp
+++ b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp
@@ -27925,7 +27925,7 @@ struct EnzymeHLOOptPass
 
     if (enable_auto_batching_passes) {
       mlir::enzyme::AutoBatchingPassPipelineOptions options{
-          true, true, "greedy", true, true, true};
+          true, true, "greedy", true, true, true, true};
       mlir::enzyme::populateAutoBatchingPassPatterns(patterns, context,
                                                      options);
     }
diff --git a/src/enzyme_ad/jax/Passes/Passes.td b/src/enzyme_ad/jax/Passes/Passes.td
index 3cc42b9baa..0db71af312 100644
--- a/src/enzyme_ad/jax/Passes/Passes.td
+++ b/src/enzyme_ad/jax/Passes/Passes.td
@@ -1081,6 +1081,12 @@ def AutoBatchingPass : Pass<"auto-batching"> {
         /*type=*/"bool",
         /*default=*/"true",
        /*description=*/"remove loop carried deps from load operations">,
+    Option<
+        /*C++ variable name=*/"while_raise_scan_like_operations",
+        /*CLI argument=*/"while_raise_scan_like_operations",
+        /*type=*/"bool",
+        /*default=*/"true",
+        /*description=*/"raise scan-like while-loop update patterns to reduce_window">,
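+    // Illustrative toggle (assuming the standard MLIR pass-option syntax):
+    //   --pass-pipeline="builtin.module(auto-batching{while_raise_scan_like_operations=false})"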
     Option<
         /*C++ variable name=*/"max_iterations",
         /*CLI argument=*/"max_iterations",
diff --git a/src/enzyme_ad/jax/TransformOps/TransformOps.td b/src/enzyme_ad/jax/TransformOps/TransformOps.td
index 9b5f83ff5c..333a40c9b0 100644
--- a/src/enzyme_ad/jax/TransformOps/TransformOps.td
+++ b/src/enzyme_ad/jax/TransformOps/TransformOps.td
@@ -2542,6 +2542,11 @@ def ApplyRemoveLoopCarriedDependenciesFromWhileLoadOperationsPatterns : EnzymeHL
   let patterns = ["RemoveLoopCarriedDependenciesFromWhileLoadOperations"];
 }
 
+def ApplyRaiseScanLikeOperationsPatterns : EnzymeHLOPatternOp<
+    "raise_scan_like_operations"> {
+  let patterns = ["RaiseScanLikeOperations"];
+}
+
 def EnzymeHLOUnroll : EnzymeHLOParameterizedPatternOp<
     "enzyme_hlo_unroll"> {
   let arguments = (ins OptionalAttr<I64Attr>:$benefit, I64Attr:$parameter);
diff --git a/src/enzyme_ad/jax/primitives.py b/src/enzyme_ad/jax/primitives.py
index 8e9a10bc74..432e1257cd 100644
--- a/src/enzyme_ad/jax/primitives.py
+++ b/src/enzyme_ad/jax/primitives.py
@@ -428,6 +428,7 @@ def optimization_passes(
         "greedy_while_loop_batch_fission",
         "while_elementwise_reduction_to_reduce",
         "remove_loop_carried_dependencies_from_while_load_operations",
+        "raise_scan_like_operations",
     ]
 
     if enable_licm_optimization_passes:
diff --git a/test/lit_tests/autobatching/scanlike.mlir b/test/lit_tests/autobatching/scanlike.mlir
new file mode 100644
index 0000000000..8ce55a9bef
--- /dev/null
+++ b/test/lit_tests/autobatching/scanlike.mlir
@@ -0,0 +1,57 @@
+module {
+  func.func @main(%arg0: tensor<128xf32>) -> tensor<128xf32> {
+    %c = stablehlo.constant dense<1> : tensor<i32>
+    %c_0 = stablehlo.constant dense<0> : tensor<i64>
+    %c_1 = stablehlo.constant dense<1> : tensor<i64>
+    %c_2 = stablehlo.constant dense<2> : tensor<i64>
+    %c_3 = stablehlo.constant dense<127> : tensor<i64>
+    %0:2 = stablehlo.while(%iterArg = %c_0, %iterArg_4 = %arg0) : tensor<i64>, tensor<128xf32>
+    cond {
+      %1 = stablehlo.compare LT, %iterArg, %c_3 : (tensor<i64>, tensor<i64>) -> tensor<i1>
+      stablehlo.return %1 : tensor<i1>
+    } do {
+      %1 = stablehlo.add %c_2, %iterArg {enzymexla.bounds = [[2, 128]]} : tensor<i64>
+      %2 = stablehlo.add %iterArg, %c_1 {enzymexla.bounds = [[1, 127]]} : tensor<i64>
+      %3 = stablehlo.convert %1 {enzymexla.bounds = [[2, 128]]} : (tensor<i64>) -> tensor<i32>
+      %4 = stablehlo.subtract %3, %c {enzymexla.bounds = [[1, 127]]} : tensor<i32>
+      %5 = stablehlo.dynamic_slice %arg0, %4, sizes = [1] : (tensor<128xf32>, tensor<i32>) -> tensor<1xf32>
+      %6 = stablehlo.subtract %1, %c_1 {enzymexla.bounds = [[1, 127]]} : tensor<i64>
+      %7 = stablehlo.convert %6 {enzymexla.bounds = [[1, 127]]} : (tensor<i64>) -> tensor<i32>
+      %8 = stablehlo.subtract %7, %c {enzymexla.bounds = [[0, 126]]} : tensor<i32>
+      %9 = stablehlo.dynamic_slice %iterArg_4, %8, sizes = [1] : (tensor<128xf32>, tensor<i32>) -> tensor<1xf32>
+      %10 = stablehlo.add %5, %9 : tensor<1xf32>
+      %11 = stablehlo.dynamic_update_slice %iterArg_4, %10, %4 : (tensor<128xf32>, tensor<1xf32>, tensor<i32>) -> tensor<128xf32>
+      stablehlo.return %2, %11 : tensor<i64>, tensor<128xf32>
+    }
+    return %0#1 : tensor<128xf32>
+  }
+}
+
+module @reactant_looped_... attributes {mhlo.num_partitions = 1 : i64, mhlo.num_replicas = 1 : i64} {
+  func.func @main(%arg0: tensor<128xf32> {enzymexla.memory_effects = [], tf.aliasing_output = 0 : i32}) -> tensor<128xf32> attributes {enzymexla.memory_effects = []} {
+    %c = stablehlo.constant dense<1> : tensor<i32>
+    %c_0 = stablehlo.constant dense<1> : tensor<i64>
+    %c_1 = stablehlo.constant dense<0> : tensor<i64>
+    %c_2 = stablehlo.constant dense<128> : tensor<i64>
+    %c_3 = stablehlo.constant dense<127> : tensor<i64>
+    %0:2 = stablehlo.while(%iterArg = %c_1, %iterArg_4 = %arg0) : tensor<i64>, tensor<128xf32> attributes {enzyme.disable_mincut}
+    cond {
+      %1 = stablehlo.compare LT, %iterArg, %c_3 : (tensor<i64>, tensor<i64>) -> tensor<i1>
+      stablehlo.return %1 : tensor<i1>
+    } do {
+      %1 = stablehlo.subtract %c_2, %iterArg {enzymexla.bounds = [[2, 128]]} : tensor<i64>
+      %2 = stablehlo.add %iterArg, %c_0 {enzymexla.bounds = [[1, 127]]} : tensor<i64>
+      %3 = stablehlo.convert %1 {enzymexla.bounds = [[2, 128]]} : (tensor<i64>) -> tensor<i32>
+      %4 = stablehlo.subtract %3, %c {enzymexla.bounds = [[1, 127]]} : tensor<i32>
+      %5 = stablehlo.dynamic_slice %arg0, %4, sizes = [1] : (tensor<128xf32>, tensor<i32>) -> tensor<1xf32>
+      %6 = stablehlo.add %1, %c_0 {enzymexla.bounds = [[3, 129]]} : tensor<i64>
+      %7 = stablehlo.convert %6 {enzymexla.bounds = [[3, 129]]} : (tensor<i64>) -> tensor<i32>
+      %8 = stablehlo.subtract %7, %c {enzymexla.bounds = [[2, 128]]} : tensor<i32>
+      %9 = stablehlo.dynamic_slice %iterArg_4, %8, sizes = [1] : (tensor<128xf32>, tensor<i32>) -> tensor<1xf32>
+      %10 = stablehlo.add %5, %9 : tensor<1xf32>
+      %11 = stablehlo.dynamic_update_slice %iterArg_4, %10, %4 : (tensor<128xf32>, tensor<1xf32>, tensor<i32>) -> tensor<128xf32>
+      stablehlo.return %2, %11 : tensor<i64>, tensor<128xf32>
+    }
+    return %0#1 : tensor<128xf32>
+  }
+}
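+
+// A raised form of the loops above would look roughly like the following
+// (an illustrative sketch of the reduce_window this pattern emits, not
+// verbatim CHECK output; the second loop additionally wraps the input and
+// result in stablehlo.reverse along the scan dimension):
+//   %cst = stablehlo.constant dense<0.000000e+00> : tensor<f32>
+//   %scan = "stablehlo.reduce_window"(%arg0, %cst) ({
+//   ^bb0(%a: tensor<f32>, %b: tensor<f32>):
+//     %s = stablehlo.add %a, %b : tensor<f32>
+//     stablehlo.return %s : tensor<f32>
+//   }) {window_dimensions = array<i64: 128>, window_strides = array<i64: 1>,
+//       padding = dense<[[127, 0]]> : tensor<1x2xi64>}
+//       : (tensor<128xf32>, tensor<f32>) -> tensor<128xf32>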