diff --git a/src/enzyme_ad/jax/Passes/EnzymeBatchToStableHLOPass.cpp b/src/enzyme_ad/jax/Passes/EnzymeBatchToStableHLOPass.cpp
new file mode 100644
index 0000000000..a583fb428c
--- /dev/null
+++ b/src/enzyme_ad/jax/Passes/EnzymeBatchToStableHLOPass.cpp
@@ -0,0 +1,149 @@
+//===- EnzymeBatchToStableHLOPass.cpp ------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass that legalizes the batching-specific enzyme ops
+// (enzyme.extract and enzyme.concat) to the stablehlo dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "Enzyme/MLIR/Dialect/Dialect.h"
+#include "Enzyme/MLIR/Dialect/Ops.h"
+#include "src/enzyme_ad/jax/Passes/Passes.h"
+#include "src/enzyme_ad/jax/Utils.h"
+#include "stablehlo/dialect/StablehloOps.h"
+
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+namespace enzyme {
+#define GEN_PASS_DEF_ENZYMEBATCHTOSTABLEHLOPASS
+#include "src/enzyme_ad/jax/Passes/Passes.h.inc"
+} // namespace enzyme
+} // namespace mlir
+
+using namespace mlir;
+using namespace mlir::enzyme;
+using namespace enzyme;
+namespace {
+
+/// Lowers `enzyme.extract` to a unit-width `stablehlo.slice` along dimension 0
+/// followed by a `stablehlo.reshape` that drops that leading dimension.
+struct ExtractOpConversion : public OpConversionPattern<enzyme::ExtractOp> {
+  using OpConversionPattern<enzyme::ExtractOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(enzyme::ExtractOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Use the adaptor so that already-converted operands are picked up.
+    Value input = adaptor.getInput();
+    Type outTy = op.getOutput().getType();
+
+    // stablehlo only operates on tensors; bail out instead of dereferencing a
+    // null type when either side is not a ranked tensor.
+    auto inRankTy = dyn_cast<RankedTensorType>(input.getType());
+    auto outRankTy = dyn_cast<RankedTensorType>(outTy);
+    if (!inRankTy || !outRankTy)
+      return rewriter.notifyMatchFailure(op, "expected ranked tensor types");
+
+    // stablehlo.slice takes static bounds, so every extent must be known.
+    if (!inRankTy.hasStaticShape() || !outRankTy.hasStaticShape())
+      return rewriter.notifyMatchFailure(op, "expected static shapes");
+
+    // The input carries the batch in dimension 0 in front of the output shape.
+    ArrayRef<int64_t> inShape = inRankTy.getShape();
+    ArrayRef<int64_t> outShape = outRankTy.getShape();
+    if (inShape.empty() || inShape.drop_front() != outShape)
+      return rewriter.notifyMatchFailure(
+          op, "input shape must be the output shape with a leading batch dim");
+
+    const int64_t index = static_cast<int64_t>(op.getIndex());
+    if (index < 0 || index >= inShape.front())
+      return rewriter.notifyMatchFailure(op, "index is out of bounds");
+
+    // Slice [index, index + 1) in dimension 0 and the full extent elsewhere.
+    const int64_t rank = inRankTy.getRank();
+    SmallVector<int64_t> startIndices(rank, 0);
+    startIndices.front() = index;
+    SmallVector<int64_t> limitIndices(inShape.begin(), inShape.end());
+    limitIndices.front() = index + 1;
+    SmallVector<int64_t> strides(rank, 1);
+
+    Value slicedOut = stablehlo::SliceOp::create(
+        rewriter, op->getLoc(), input, startIndices, limitIndices, strides);
+    // Drop the unit batch dimension to obtain the original result type.
+    rewriter.replaceOpWithNewOp<stablehlo::ReshapeOp>(op, outTy, slicedOut);
+    return success();
+  }
+};
+
+/// Lowers `enzyme.concat` by reshaping every input to carry a leading unit
+/// dimension and concatenating the reshaped values along dimension 0.
+struct ConcatOpConversion : public OpConversionPattern<enzyme::ConcatOp> {
+  using OpConversionPattern<enzyme::ConcatOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(enzyme::ConcatOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Use the adaptor so that already-converted operands are picked up.
+    ValueRange inputs = adaptor.getInputs();
+    if (inputs.empty())
+      return rewriter.notifyMatchFailure(op, "expected at least one input");
+
+    // stablehlo.reshape needs statically shaped ranked tensors; check every
+    // input up front instead of asserting halfway through the rewrite.
+    if (!llvm::all_of(inputs, [](Value in) {
+          auto ty = dyn_cast<RankedTensorType>(in.getType());
+          return ty && ty.hasStaticShape();
+        }))
+      return rewriter.notifyMatchFailure(
+          op, "expected statically shaped ranked tensor inputs");
+
+    // Reshape each input from `shape` to `1 x shape`.
+    SmallVector<Value> expandedInputs;
+    expandedInputs.reserve(inputs.size());
+    for (Value in : inputs) {
+      auto inRankTy = cast<RankedTensorType>(in.getType());
+      ArrayRef<int64_t> inShape = inRankTy.getShape();
+      SmallVector<int64_t> newInShape = {1};
+      newInShape.append(inShape.begin(), inShape.end());
+      Value newInput = stablehlo::ReshapeOp::create(
+          rewriter, op->getLoc(), inRankTy.clone(newInShape), in);
+      expandedInputs.push_back(newInput);
+    }
+
+    // Concatenate along the new leading dimension.
+    rewriter.replaceOpWithNewOp<stablehlo::ConcatenateOp>(
+        op, op->getResultTypes(), expandedInputs, /*dim=*/0);
+    return success();
+  }
+};
+
+/// Pass that rewrites `enzyme.extract` and `enzyme.concat` into stablehlo ops
+/// using the two conversion patterns above.
+struct EnzymeBatchToStableHLOPass
+    : public enzyme::impl::EnzymeBatchToStableHLOPassBase<
+          EnzymeBatchToStableHLOPass> {
+  void runOnOperation() override {
+    MLIRContext *context = &getContext();
+    RewritePatternSet patterns(context);
+    patterns.add<ExtractOpConversion, ConcatOpConversion>(context);
+
+    // Only the two batching ops are illegal; the rest of the enzyme dialect
+    // stays legal, so this has to be a partial conversion.
+    ConversionTarget target(*context);
+    target.addLegalDialect<stablehlo::StablehloDialect>();
+    target.addLegalDialect<enzyme::EnzymeDialect>();
+    target.addIllegalOp<enzyme::ExtractOp, enzyme::ConcatOp>();
+
+    if (failed(applyPartialConversion(getOperation(), target,
+                                      std::move(patterns))))
+      signalPassFailure();
+  }
+};
+} // namespace
diff --git 
a/src/enzyme_ad/jax/Passes/Passes.td b/src/enzyme_ad/jax/Passes/Passes.td
index 06c2259c63..3c6db53854 100644
--- a/src/enzyme_ad/jax/Passes/Passes.td
+++ b/src/enzyme_ad/jax/Passes/Passes.td
@@ -1008,4 +1008,13 @@ def SCFCPUify : Pass<"cpuify"> {
     Option<"method", "method", "std::string", /*default=*/"\"distribute\"", "Method of doing distribution">
   ];
 }
+
+def EnzymeBatchToStableHLOPass : Pass<"enzyme-batch-to-stablehlo"> {
+  let summary = "Legalize batching specific enzyme ops to stablehlo dialect";
+  let dependentDialects = [
+    "stablehlo::StablehloDialect",
+    "enzyme::EnzymeDialect"
+  ];
+}
+
 #endif
diff --git a/test/lit_tests/OptimizeAD/bwd_batch.mlir b/test/lit_tests/OptimizeAD/bwd_batch.mlir
new file mode 100644
index 0000000000..3bcceb3af9
--- /dev/null
+++ b/test/lit_tests/OptimizeAD/bwd_batch.mlir
@@ -0,0 +1,71 @@
+// RUN: enzymexlamlir-opt --split-input-file --enzyme-diff-batch %s | FileCheck %s
+// RUN: enzymexlamlir-opt --split-input-file --enzyme-diff-batch --enzyme-batch-to-stablehlo %s | FileCheck %s --check-prefix=LEGAL
+
+//1. Scalar test
+module {
+  func.func @square(%x : tensor<f64>) -> tensor<f64>{
+    %y = stablehlo.multiply %x, %x : tensor<f64>
+    return %y : tensor<f64>
+  }
+  func.func @test1(%x : tensor<f64>, %dr1 : tensor<f64>, %dr2 : tensor<f64>) -> (tensor<f64>,tensor<f64>) {
+    %r, %dx1 = enzyme.autodiff @square(%x, %dr1) { activity=[#enzyme<activity enzyme_active>], ret_activity=[#enzyme<activity enzyme_active>] } : (tensor<f64>, tensor<f64>) -> (tensor<f64>,tensor<f64>)
+    %r2, %dx2 = enzyme.autodiff @square(%x, %dr2) { activity=[#enzyme<activity enzyme_active>], ret_activity=[#enzyme<activity enzyme_active>] } : (tensor<f64>, tensor<f64>) -> (tensor<f64>,tensor<f64>)
+    return %dx1,%dx2 : tensor<f64>, tensor<f64>
+  }
+}
+
+// CHECK-LABEL: func.func @test1
+// CHECK-SAME: (%[[PRIMAL:.*]]: tensor<f64>, %[[DIFF1:.*]]: tensor<f64>, %[[DIFF2:.*]]: tensor<f64>) -> (tensor<f64>, tensor<f64>)
+// CHECK: %[[CONCAT:.*]] = enzyme.concat(%[[DIFF1]], %[[DIFF2]]) : (tensor<f64>, tensor<f64>) -> tensor<2xf64>
+// CHECK: %[[BATCHED_RES_BASE:.*]]:2 = enzyme.autodiff @square(%[[PRIMAL]], %[[CONCAT]]) {{.*}} width = 2 {{.*}} : (tensor<f64>, tensor<2xf64>) -> (tensor<f64>, tensor<2xf64>)
+// CHECK: %[[RES0:.*]] = enzyme.extract %[[BATCHED_RES_BASE]]#1[0] : (tensor<2xf64>) -> tensor<f64>
+// CHECK-NEXT: %[[RES1:.*]] = enzyme.extract %[[BATCHED_RES_BASE]]#1[1] : (tensor<2xf64>) -> tensor<f64>
+// CHECK-NEXT: return %[[RES0]], %[[RES1]]
+
+// LEGAL-LABEL: func.func @test1
+// LEGAL-SAME: (%[[PRIMAL:.*]]: tensor<f64>, %[[DIFF1:.*]]: tensor<f64>, %[[DIFF2:.*]]: tensor<f64>) -> (tensor<f64>, tensor<f64>)
+// LEGAL: %[[EDIFF1:.*]] = stablehlo.reshape %[[DIFF1]] : (tensor<f64>) -> tensor<1xf64>
+// LEGAL: %[[EDIFF2:.*]] = stablehlo.reshape %[[DIFF2]] : (tensor<f64>) -> tensor<1xf64>
+// LEGAL: %[[CONCAT:.*]] = stablehlo.concatenate %[[EDIFF1]], %[[EDIFF2]], dim = 0 : (tensor<1xf64>, tensor<1xf64>) -> tensor<2xf64>
+// LEGAL: %[[BATCHED_RES_BASE:.*]]:2 = enzyme.autodiff @square(%[[PRIMAL]], %[[CONCAT]]) {{.*}} width = 2 {{.*}} : (tensor<f64>, tensor<2xf64>) -> (tensor<f64>, tensor<2xf64>)
+// LEGAL: %[[R0:.*]] = stablehlo.slice %[[BATCHED_RES_BASE]]#1 [0:1] : (tensor<2xf64>) -> tensor<1xf64>
+// LEGAL-NEXT: %[[RES0:.*]] = stablehlo.reshape %[[R0]] : (tensor<1xf64>) -> tensor<f64>
+// LEGAL-NEXT: %[[R1:.*]] = stablehlo.slice %[[BATCHED_RES_BASE]]#1 [1:2] : (tensor<2xf64>) -> tensor<1xf64>
+// LEGAL-NEXT: %[[RES1:.*]] = stablehlo.reshape %[[R1]] : (tensor<1xf64>) -> tensor<f64>
+// LEGAL-NEXT: return %[[RES0]], %[[RES1]]
+
+// -----
+
+//2. Tensor test
+module {
+  func.func @square(%x : tensor<10xf64>) -> tensor<10xf64>{
+    %y = stablehlo.multiply %x, %x : tensor<10xf64>
+    return %y : tensor<10xf64>
+  }
+  func.func @test2(%x : tensor<10xf64>, %dr1 : tensor<10xf64>, %dr2 : tensor<10xf64>) -> (tensor<10xf64>,tensor<10xf64>) {
+    %r, %dx1 = enzyme.autodiff @square(%x, %dr1) { activity=[#enzyme<activity enzyme_active>], ret_activity=[#enzyme<activity enzyme_active>]} : (tensor<10xf64>, tensor<10xf64>) -> (tensor<10xf64>, tensor<10xf64>)
+    %r2, %dx2 = enzyme.autodiff @square(%x, %dr2) { activity=[#enzyme<activity enzyme_active>], ret_activity=[#enzyme<activity enzyme_active>]} : (tensor<10xf64>, tensor<10xf64>) -> (tensor<10xf64>, tensor<10xf64>)
+    return %dx1,%dx2 : tensor<10xf64>,tensor<10xf64>
+  }
+}
+
+
+// CHECK-LABEL: func.func @test2
+// CHECK-SAME: (%[[PRIMAL:.*]]: tensor<10xf64>, %[[DIFF1:.*]]: tensor<10xf64>, %[[DIFF2:.*]]: tensor<10xf64>) -> (tensor<10xf64>, tensor<10xf64>)
+// CHECK: %[[CONCAT:.*]] = enzyme.concat(%[[DIFF1]], %[[DIFF2]]) : (tensor<10xf64>, tensor<10xf64>) -> tensor<2x10xf64>
+// CHECK: %[[BATCHED_RES_BASE:.*]]:2 = enzyme.autodiff @square(%[[PRIMAL]], %[[CONCAT]]) {{.*}} width = 2 {{.*}} : (tensor<10xf64>, tensor<2x10xf64>) -> (tensor<10xf64>, tensor<2x10xf64>)
+// CHECK-NEXT: %[[RES0:.*]] = enzyme.extract %[[BATCHED_RES_BASE]]#1[0] : (tensor<2x10xf64>) -> tensor<10xf64>
+// CHECK-NEXT: %[[RES1:.*]] = enzyme.extract %[[BATCHED_RES_BASE]]#1[1] : (tensor<2x10xf64>) -> tensor<10xf64>
+// CHECK-NEXT: return %[[RES0]], %[[RES1]]
+
+// LEGAL-LABEL: func.func @test2
+// LEGAL-SAME: (%[[PRIMAL:.*]]: tensor<10xf64>, %[[DIFF1:.*]]: tensor<10xf64>, %[[DIFF2:.*]]: tensor<10xf64>) -> (tensor<10xf64>, tensor<10xf64>)
+// LEGAL: %[[EDIFF1:.*]] = stablehlo.reshape %[[DIFF1]] : (tensor<10xf64>) -> tensor<1x10xf64>
+// LEGAL: %[[EDIFF2:.*]] = stablehlo.reshape %[[DIFF2]] : (tensor<10xf64>) -> tensor<1x10xf64>
+// LEGAL: %[[CONCAT:.*]] = stablehlo.concatenate %[[EDIFF1]], %[[EDIFF2]], dim = 0 : (tensor<1x10xf64>, tensor<1x10xf64>) -> tensor<2x10xf64>
+// LEGAL: %[[BATCHED_RES_BASE:.*]]:2 = enzyme.autodiff @square(%[[PRIMAL]], %[[CONCAT]]) {{.*}} width = 2 {{.*}} : (tensor<10xf64>, tensor<2x10xf64>) -> (tensor<10xf64>, tensor<2x10xf64>)
+// LEGAL: %[[R0:.*]] = stablehlo.slice %[[BATCHED_RES_BASE]]#1 [0:1, 0:10] : (tensor<2x10xf64>) -> tensor<1x10xf64>
+// LEGAL-NEXT: %[[RES0:.*]] = stablehlo.reshape %[[R0]] : (tensor<1x10xf64>) -> tensor<10xf64>
+// LEGAL-NEXT: %[[R1:.*]] = stablehlo.slice %[[BATCHED_RES_BASE]]#1 [1:2, 0:10] : (tensor<2x10xf64>) -> tensor<1x10xf64>
+// LEGAL-NEXT: %[[RES1:.*]] = stablehlo.reshape %[[R1]] : (tensor<1x10xf64>) -> tensor<10xf64>
+// LEGAL-NEXT: return %[[RES0]], %[[RES1]]
diff --git a/test/lit_tests/OptimizeAD/fwd_batch.mlir b/test/lit_tests/OptimizeAD/fwd_batch.mlir
new file mode 100644
index 0000000000..c48f8500b3
--- /dev/null
+++ b/test/lit_tests/OptimizeAD/fwd_batch.mlir
@@ -0,0 +1,70 @@
+// RUN: enzymexlamlir-opt --split-input-file --enzyme-diff-batch %s | FileCheck %s
+// RUN: enzymexlamlir-opt --split-input-file --enzyme-diff-batch --enzyme-batch-to-stablehlo %s | FileCheck %s --check-prefix=LEGAL
+// 1. Scalar test
+module {
+  func.func @square(%x : tensor<f64>) -> tensor<f64>{
+    %y = stablehlo.multiply %x, %x : tensor<f64>
+    return %y : tensor<f64>
+  }
+  func.func @test1(%x : tensor<f64>, %dx1 : tensor<f64>, %dx2 : tensor<f64>) -> (tensor<f64>,tensor<f64>) {
+    %r1 = enzyme.fwddiff @square(%x, %dx1) { activity=[#enzyme<activity enzyme_dup>], ret_activity=[#enzyme<activity enzyme_dupnoneed>]} : (tensor<f64>, tensor<f64>) -> (tensor<f64>)
+    %r2 = enzyme.fwddiff @square(%x, %dx2) { activity=[#enzyme<activity enzyme_dup>], ret_activity=[#enzyme<activity enzyme_dupnoneed>] } : (tensor<f64>, tensor<f64>) -> (tensor<f64>)
+    return %r1,%r2 : tensor<f64>, tensor<f64>
+  }
+}
+
+// CHECK-LABEL: func.func @test1
+// CHECK-SAME: (%[[PRIMAL:.*]]: tensor<f64>, %[[DIFF1:.*]]: tensor<f64>, %[[DIFF2:.*]]: tensor<f64>) -> (tensor<f64>, tensor<f64>)
+// CHECK: %[[CONCAT:.*]] = enzyme.concat(%[[DIFF1]], %[[DIFF2]]) : (tensor<f64>, tensor<f64>) -> tensor<2xf64>
+// CHECK: %[[BATCHED_RES:.*]] = enzyme.fwddiff @square(%[[PRIMAL]], %[[CONCAT]]) {{.*}} width = 2 {{.*}} : (tensor<f64>, tensor<2xf64>) -> tensor<2xf64>
+// CHECK: %[[RES0:.*]] = enzyme.extract %[[BATCHED_RES]][0] : (tensor<2xf64>) -> tensor<f64>
+// CHECK-NEXT: %[[RES1:.*]] = enzyme.extract %[[BATCHED_RES]][1] : (tensor<2xf64>) -> tensor<f64>
+// CHECK-NEXT: return %[[RES0]], %[[RES1]]
+
+// LEGAL-LABEL: func.func @test1
+// LEGAL-SAME: (%[[PRIMAL:.*]]: tensor<f64>, %[[DIFF1:.*]]: tensor<f64>, %[[DIFF2:.*]]: tensor<f64>) -> (tensor<f64>, tensor<f64>)
+// LEGAL: %[[EDIFF1:.*]] = stablehlo.reshape %[[DIFF1]] : (tensor<f64>) -> tensor<1xf64>
+// LEGAL: %[[EDIFF2:.*]] = stablehlo.reshape %[[DIFF2]] : (tensor<f64>) -> tensor<1xf64>
+// LEGAL: %[[CONCAT:.*]] = stablehlo.concatenate %[[EDIFF1]], %[[EDIFF2]], dim = 0 : (tensor<1xf64>, tensor<1xf64>) -> tensor<2xf64>
+// LEGAL: %[[BATCHED_RES:.*]] = enzyme.fwddiff @square(%[[PRIMAL]], %[[CONCAT]]) {{.*}} width = 2 {{.*}} : (tensor<f64>, tensor<2xf64>) -> tensor<2xf64>
+// LEGAL: %[[R0:.*]] = stablehlo.slice %[[BATCHED_RES]] [0:1] : (tensor<2xf64>) -> tensor<1xf64>
+// LEGAL-NEXT: %[[RES0:.*]] = stablehlo.reshape %[[R0]] : (tensor<1xf64>) -> tensor<f64>
+// LEGAL-NEXT: %[[R1:.*]] = stablehlo.slice %[[BATCHED_RES]] [1:2] : (tensor<2xf64>) -> tensor<1xf64>
+// LEGAL-NEXT: %[[RES1:.*]] = stablehlo.reshape %[[R1]] : (tensor<1xf64>) -> tensor<f64>
+// LEGAL-NEXT: return %[[RES0]], %[[RES1]]
+
+// -----
+
+// 2. Tensor test
+module {
+  func.func @square(%x : tensor<10xf64>) -> tensor<10xf64>{
+    %y = stablehlo.multiply %x, %x : tensor<10xf64>
+    return %y : tensor<10xf64>
+  }
+  func.func @test2(%x : tensor<10xf64>, %dx : tensor<10xf64>, %dx2 : tensor<10xf64>) -> (tensor<10xf64>,tensor<10xf64>) {
+    %r = enzyme.fwddiff @square(%x, %dx) { activity=[#enzyme<activity enzyme_dup>], ret_activity=[#enzyme<activity enzyme_dupnoneed>]} : (tensor<10xf64>, tensor<10xf64>) -> (tensor<10xf64>)
+    %r2 = enzyme.fwddiff @square(%x, %dx2) { activity=[#enzyme<activity enzyme_dup>], ret_activity=[#enzyme<activity enzyme_dupnoneed>] } : (tensor<10xf64>, tensor<10xf64>) -> (tensor<10xf64>)
+    return %r,%r2 : tensor<10xf64>,tensor<10xf64>
+  }
+}
+
+
+// CHECK-LABEL: func.func @test2
+// CHECK-SAME: (%[[PRIMAL:.*]]: tensor<10xf64>, %[[DIFF1:.*]]: tensor<10xf64>, %[[DIFF2:.*]]: tensor<10xf64>) -> (tensor<10xf64>, tensor<10xf64>)
+// CHECK: %[[CONCAT:.*]] = enzyme.concat(%[[DIFF1]], %[[DIFF2]]) : (tensor<10xf64>, tensor<10xf64>) -> tensor<2x10xf64>
+// CHECK: %[[BATCHED_RES:.*]] = enzyme.fwddiff @square(%[[PRIMAL]], %[[CONCAT]]) {{.*}} width = 2 {{.*}} : (tensor<10xf64>, tensor<2x10xf64>) -> tensor<2x10xf64>
+// CHECK: %[[RES0:.*]] = enzyme.extract %[[BATCHED_RES]][0] : (tensor<2x10xf64>) -> tensor<10xf64>
+// CHECK-NEXT: %[[RES1:.*]] = enzyme.extract %[[BATCHED_RES]][1] : (tensor<2x10xf64>) -> tensor<10xf64>
+// CHECK-NEXT: return %[[RES0]], %[[RES1]]
+
+// LEGAL-LABEL: func.func @test2
+// LEGAL-SAME: (%[[PRIMAL:.*]]: tensor<10xf64>, %[[DIFF1:.*]]: tensor<10xf64>, %[[DIFF2:.*]]: tensor<10xf64>) -> (tensor<10xf64>, tensor<10xf64>)
+// LEGAL: %[[EDIFF1:.*]] = stablehlo.reshape %[[DIFF1]] : (tensor<10xf64>) -> tensor<1x10xf64>
+// LEGAL: %[[EDIFF2:.*]] = stablehlo.reshape %[[DIFF2]] : (tensor<10xf64>) -> tensor<1x10xf64>
+// LEGAL: %[[CONCAT:.*]] = stablehlo.concatenate %[[EDIFF1]], %[[EDIFF2]], dim = 0 : (tensor<1x10xf64>, tensor<1x10xf64>) -> tensor<2x10xf64>
+// LEGAL: %[[BATCHED_RES:.*]] = enzyme.fwddiff @square(%[[PRIMAL]], %[[CONCAT]]) {{.*}} width = 2 {{.*}} : (tensor<10xf64>, tensor<2x10xf64>) -> tensor<2x10xf64>
+// LEGAL: %[[R0:.*]] = stablehlo.slice %[[BATCHED_RES]] [0:1, 0:10] : (tensor<2x10xf64>) -> tensor<1x10xf64>
+// LEGAL-NEXT: %[[RES0:.*]] = stablehlo.reshape %[[R0]] : (tensor<1x10xf64>) -> tensor<10xf64>
+// LEGAL-NEXT: %[[R1:.*]] = stablehlo.slice %[[BATCHED_RES]] [1:2, 0:10] : (tensor<2x10xf64>) -> tensor<1x10xf64>
+// LEGAL-NEXT: %[[RES1:.*]] = stablehlo.reshape %[[R1]] : (tensor<1x10xf64>) -> tensor<10xf64>
+// LEGAL-NEXT: return %[[RES0]], %[[RES1]]