Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 122 additions & 0 deletions src/enzyme_ad/jax/Passes/EnzymeBatchToStableHLOPass.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
//===- EnzymeBatchToStableHLOPass.cpp ------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass to print the MLIR module
//===----------------------------------------------------------------------===//

#include "Enzyme/MLIR/Dialect/Dialect.h"
#include "Enzyme/MLIR/Dialect/Ops.h"
#include "src/enzyme_ad/jax/Passes/Passes.h"
#include "src/enzyme_ad/jax/Utils.h"
#include "stablehlo/dialect/StablehloOps.h"

#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"

namespace mlir {
namespace enzyme {
#define GEN_PASS_DEF_ENZYMEBATCHTOSTABLEHLOPASS
#include "src/enzyme_ad/jax/Passes/Passes.h.inc"
} // namespace enzyme
} // namespace mlir

using namespace mlir;
using namespace mlir::enzyme;
using namespace enzyme;
namespace {

struct ExtractOpConversion : public OpConversionPattern<enzyme::ExtractOp> {
using OpConversionPattern<enzyme::ExtractOp>::OpConversionPattern;

LogicalResult
matchAndRewrite(enzyme::ExtractOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

auto inTy = op.getInput().getType();
auto outTy = op.getOutput().getType();
auto outRankTy = dyn_cast<RankedTensorType>(outTy);
// stablehlo always has tensor type
auto inRankTy = dyn_cast<RankedTensorType>(inTy);
auto ndims = inRankTy.getRank(); // is atleast 1

if (ndims < 1)
return failure();

// static slice
SmallVector<int64_t> start_indices;
start_indices.push_back(op.getIndex());
for (int i = 1; i < ndims; ++i) {
start_indices.push_back(0);
}
SmallVector<int64_t> limit_indices;
limit_indices.push_back(op.getIndex() + 1);
limit_indices.append(outRankTy.getShape().begin(),
outRankTy.getShape().end());
SmallVector<int64_t> strides(ndims, 1);

Value slicedOut =
stablehlo::SliceOp::create(rewriter, op->getLoc(), op.getInput(),
start_indices, limit_indices, strides);
// reshape slicedOut to our final Op
rewriter.replaceOpWithNewOp<stablehlo::ReshapeOp>(op, outTy, slicedOut);
return success();
}
};

struct ConcatOpConversion : public OpConversionPattern<enzyme::ConcatOp> {
using OpConversionPattern<enzyme::ConcatOp>::OpConversionPattern;

LogicalResult
matchAndRewrite(enzyme::ConcatOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
SmallVector<Value> inputs = op.getInputs();
if (inputs.empty())
return failure();

// stablehlo always has tensor type
// reshape each input to 1xinput_rank and concatenate on dim=0

SmallVector<Value> expandedInputs;
for (Value in : inputs) {
auto inRankTy = cast<RankedTensorType>(in.getType());
auto inShape = inRankTy.getShape();
SmallVector<int64_t> newInShape = {1};
newInShape.append(inShape.begin(), inShape.end());
auto newInTy = inRankTy.clone(newInShape);
Value newInput =
stablehlo::ReshapeOp::create(rewriter, op->getLoc(), newInTy, in);
expandedInputs.push_back(newInput);
}

// concatenate on dim=0
rewriter.replaceOpWithNewOp<stablehlo::ConcatenateOp>(
op, op->getResultTypes(), expandedInputs, /*dim=*/0);
return success();
}
};

struct EnzymeBatchToStableHLOPass
: public enzyme::impl::EnzymeBatchToStableHLOPassBase<
EnzymeBatchToStableHLOPass> {
void runOnOperation() override {
MLIRContext *context = &getContext();
RewritePatternSet patterns(context);
patterns.add<ConcatOpConversion, ExtractOpConversion>(context);

ConversionTarget target(*context);
target.addLegalDialect<stablehlo::StablehloDialect>();
target.addLegalDialect<enzyme::EnzymeDialect>();
target.addIllegalOp<enzyme::ConcatOp, enzyme::ExtractOp>();

if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns)))) {
signalPassFailure();
}
};
};
} // namespace
9 changes: 9 additions & 0 deletions src/enzyme_ad/jax/Passes/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -1008,4 +1008,13 @@ def SCFCPUify : Pass<"cpuify"> {
Option<"method", "method", "std::string", /*default=*/"\"distribute\"", "Method of doing distribution">
];
}

def EnzymeBatchToStableHLOPass : Pass<"enzyme-batch-to-stablehlo"> {
let summary = "Legalize batching specific enzyme ops to stablehlo dialect";
let dependentDialects = [
"stablehlo::StablehloDialect",
"enzyme::EnzymeDialect"
];
}

#endif
71 changes: 71 additions & 0 deletions test/lit_tests/OptimizeAD/bwd_batch.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
// RUN: enzymexlamlir-opt --split-input-file --enzyme-diff-batch %s | FileCheck %s
// RUN: enzymexlamlir-opt --split-input-file --enzyme-diff-batch --enzyme-batch-to-tensor %s | FileCheck %s --check-prefix=LEGAL

//1. Scalar test
module {
func.func @square(%x : tensor<f64>) -> tensor<f64>{
%y = stablehlo.multiply %x, %x : tensor<f64>
return %y : tensor<f64>
}
func.func @test1(%x : tensor<f64>, %dr1 : tensor<f64>, %dr2 : tensor<f64>) -> (tensor<f64>,tensor<f64>) {
%r, %dx1 = enzyme.autodiff @square(%x, %dr1) { activity=[#enzyme<activity enzyme_active>], ret_activity=[#enzyme<activity enzyme_active>] } : (tensor<f64>, tensor<f64>) -> (tensor<f64>,tensor<f64>)
%r2, %dx2 = enzyme.autodiff @square(%x, %dr2) { activity=[#enzyme<activity enzyme_active>], ret_activity=[#enzyme<activity enzyme_active>] } : (tensor<f64>, tensor<f64>) -> (tensor<f64>,tensor<f64>)
return %dx1,%dx2 : tensor<f64>, tensor<f64>
}
}

// CHECK-LABEL: func.func @test1
// CHECK-SAME: (%[[PRIMAL:.*]]: tensor<f64>, %[[DIFF1:.*]]: tensor<f64>, %[[DIFF2:.*]]: tensor<f64>) -> (tensor<f64>, tensor<f64>)
// CHECK: %[[CONCAT:.*]] = enzyme.concat(%[[DIFF1]], %[[DIFF2]]) : (tensor<f64>, tensor<f64>) -> tensor<2xf64>
// CHECK: %[[BATCHED_RES_BASE:.*]]:2 = enzyme.autodiff @square(%[[PRIMAL]], %[[CONCAT]]) {{.*}} width = 2 {{.*}} : (tensor<f64>, tensor<2xf64>) -> (tensor<f64>, tensor<2xf64>)
// CHECK: %[[RES0:.*]] = enzyme.extract %[[BATCHED_RES_BASE]]#1[0] : (tensor<2xf64>) -> tensor<f64>
// CHECK-NEXT: %[[RES1:.*]] = enzyme.extract %[[BATCHED_RES_BASE]]#1[1] : (tensor<2xf64>) -> tensor<f64>
// CHECK-NEXT: return %[[RES0]], %[[RES1]]

// LEGAL-LABEL: func.func @test1
// LEGAL-SAME: (%[[PRIMAL:.*]]: tensor<f64>, %[[DIFF1:.*]]: tensor<f64>, %[[DIFF2:.*]]: tensor<f64>) -> (tensor<f64>, tensor<f64>)
// LEGAL: %[[EDIFF1:.*]] = stablehlo.reshape %[[DIFF1]] : (tensor<f64>) -> tensor<1xf64>
// LEGAL: %[[EDIFF2:.*]] = stablehlo.reshape %[[DIFF2]] : (tensor<f64>) -> tensor<1xf64>
// LEGAL: %[[CONCAT:.*]] = stablehlo.concatenate %[[EDIFF1]], %[[EDIFF2]], dim = 0 : (tensor<1xf64>, tensor<1xf64>) -> tensor<2xf64>
// LEGAL: %[[BATCHED_RES_BASE:.*]]:2 = enzyme.autodiff @square(%[[PRIMAL]], %[[CONCAT]]) {{.*}} width = 2 {{.*}} : (tensor<f64>, tensor<2xf64>) -> (tensor<f64>, tensor<2xf64>)
// LEGAL: %[[R0:.*]] = stablehlo.slice %[[BATCHED_RES_BASE]]#1 [0:1] : (tensor<2xf64>) -> tensor<1xf64>
// LEGAL-NEXT: %[[RES0:.*]] = stablehlo.reshape %[[R0]] : (tensor<1xf64>) -> tensor<f64>
// LEGAL-NEXT: %[[R1:.*]] = stablehlo.slice %[[BATCHED_RES_BASE]]#1 [1:2] : (tensor<2xf64>) -> tensor<1xf64>
// LEGAL-NEXT: %[[RES1:.*]] = stablehlo.reshape %[[R1]] : (tensor<1xf64>) -> tensor<f64>
// LEGAL-NEXT: return %[[RES0]], %[[RES1]]

// -----

//2. Tensor test
module {
func.func @square(%x : tensor<10xf64>) -> tensor<10xf64>{
%y = stablehlo.multiply %x, %x : tensor<10xf64>
return %y : tensor<10xf64>
}
func.func @test2(%x : tensor<10xf64>, %dr1 : tensor<10xf64>, %dr2 : tensor<10xf64>) -> (tensor<10xf64>,tensor<10xf64>) {
%r, %dx1 = enzyme.autodiff @square(%x, %dr1) { activity=[#enzyme<activity enzyme_active>], ret_activity=[#enzyme<activity enzyme_active>]} : (tensor<10xf64>, tensor<10xf64>) -> (tensor<10xf64>, tensor<10xf64>)
%r2, %dx2 = enzyme.autodiff @square(%x, %dr2) { activity=[#enzyme<activity enzyme_active>], ret_activity=[#enzyme<activity enzyme_active>]} : (tensor<10xf64>, tensor<10xf64>) -> (tensor<10xf64>, tensor<10xf64>)
return %dx1,%dx2 : tensor<10xf64>,tensor<10xf64>
}
}


// CHECK-LABEL: func.func @test2
// CHECK-SAME: (%[[PRIMAL:.*]]: tensor<10xf64>, %[[DIFF1:.*]]: tensor<10xf64>, %[[DIFF2:.*]]: tensor<10xf64>) -> (tensor<10xf64>, tensor<10xf64>)
// CHECK: %[[CONCAT:.*]] = enzyme.concat(%[[DIFF1]], %[[DIFF2]]) : (tensor<10xf64>, tensor<10xf64>) -> tensor<2x10xf64>
// CHECK: %[[BATCHED_RES_BASE:.*]]:2 = enzyme.autodiff @square(%[[PRIMAL]], %[[CONCAT]]) {{.*}} width = 2 {{.*}} : (tensor<10xf64>, tensor<2x10xf64>) -> (tensor<10xf64>, tensor<2x10xf64>)
// CHECK-NEXT: %[[RES0:.*]] = enzyme.extract %[[BATCHED_RES_BASE]]#1[0] : (tensor<2x10xf64>) -> tensor<10xf64>
// CHECK-NEXT: %[[RES1:.*]] = enzyme.extract %[[BATCHED_RES_BASE]]#1[1] : (tensor<2x10xf64>) -> tensor<10xf64>
// CHECK-NEXT: return %[[RES0]], %[[RES1]]

// LEGAL-LABEL: func.func @test2
// LEGAL-SAME: (%[[PRIMAL:.*]]: tensor<10xf64>, %[[DIFF1:.*]]: tensor<10xf64>, %[[DIFF2:.*]]: tensor<10xf64>) -> (tensor<10xf64>, tensor<10xf64>)
// LEGAL: %[[EDIFF1:.*]] = stablehlo.reshape %[[DIFF1]] : (tensor<10xf64>) -> tensor<1x10xf64>
// LEGAL: %[[EDIFF2:.*]] = stablehlo.reshape %[[DIFF2]] : (tensor<10xf64>) -> tensor<1x10xf64>
// LEGAL: %[[CONCAT:.*]] = stablehlo.concatenate %[[EDIFF1]], %[[EDIFF2]], dim = 0 : (tensor<1x10xf64>, tensor<1x10xf64>) -> tensor<2x10xf64>
// LEGAL: %[[BATCHED_RES_BASE:.*]]:2 = enzyme.autodiff @square(%[[PRIMAL]], %[[CONCAT]]) {{.*}} width = 2 {{.*}} : (tensor<10xf64>, tensor<2x10xf64>) -> (tensor<10xf64>, tensor<2x10xf64>)
// LEGAL: %[[R0:.*]] = stablehlo.slice %[[BATCHED_RES_BASE]]#1 [0:1, 0:10] : (tensor<2x10xf64>) -> tensor<1x10xf64>
// LEGAL-NEXT: %[[RES0:.*]] = stablehlo.reshape %[[R0]] : (tensor<1x10xf64>) -> tensor<10xf64>
// LEGAL-NEXT: %[[R1:.*]] = stablehlo.slice %[[BATCHED_RES_BASE]]#1 [1:2, 0:10] : (tensor<2x10xf64>) -> tensor<1x10xf64>
// LEGAL-NEXT: %[[RES1:.*]] = stablehlo.reshape %[[R1]] : (tensor<1x10xf64>) -> tensor<10xf64>
// LEGAL-NEXT: return %[[RES0]], %[[RES1]]
70 changes: 70 additions & 0 deletions test/lit_tests/OptimizeAD/fwd_batch.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
// RUN: enzymexlamlir-opt --split-input-file --enzyme-diff-batch %s | FileCheck %s
// RUN: enzymexlamlir-opt --split-input-file --enzyme-diff-batch --enzyme-batch-to-tensor %s | FileCheck %s --check-prefix=LEGAL
// 1. Scalar test
module {
func.func @square(%x : tensor<f64>) -> tensor<f64>{
%y = stablehlo.multiply %x, %x : tensor<f64>
return %y : tensor<f64>
}
func.func @test1(%x : tensor<f64>, %dx1 : tensor<f64>, %dx2 : tensor<f64>) -> (tensor<f64>,tensor<f64>) {
%r1 = enzyme.fwddiff @square(%x, %dx1) { activity=[#enzyme<activity enzyme_dup>], ret_activity=[#enzyme<activity enzyme_dupnoneed>]} : (tensor<f64>, tensor<f64>) -> (tensor<f64>)
%r2 = enzyme.fwddiff @square(%x, %dx2) { activity=[#enzyme<activity enzyme_dup>], ret_activity=[#enzyme<activity enzyme_dupnoneed>] } : (tensor<f64>, tensor<f64>) -> (tensor<f64>)
return %r1,%r2 : tensor<f64>, tensor<f64>
}
}

// CHECK-LABEL: func.func @test1
// CHECK-SAME: (%[[PRIMAL:.*]]: tensor<f64>, %[[DIFF1:.*]]: tensor<f64>, %[[DIFF2:.*]]: tensor<f64>) -> (tensor<f64>, tensor<f64>)
// CHECK: %[[CONCAT:.*]] = enzyme.concat(%[[DIFF1]], %[[DIFF2]]) : (tensor<f64>, tensor<f64>) -> tensor<2xf64>
// CHECK: %[[BATCHED_RES:.*]] = enzyme.fwddiff @square(%[[PRIMAL]], %[[CONCAT]]) {{.*}} width = 2 {{.*}} : (tensor<f64>, tensor<2xf64>) -> tensor<2xf64>
// CHECK: %[[RES0:.*]] = enzyme.extract %[[BATCHED_RES]][0] : (tensor<2xf64>) -> tensor<f64>
// CHECK-NEXT: %[[RES1:.*]] = enzyme.extract %[[BATCHED_RES]][1] : (tensor<2xf64>) -> tensor<f64>
// CHECK-NEXT: return %[[RES0]], %[[RES1]]

// LEGAL-LABEL: func.func @test1
// LEGAL-SAME: (%[[PRIMAL:.*]]: tensor<f64>, %[[DIFF1:.*]]: tensor<f64>, %[[DIFF2:.*]]: tensor<f64>) -> (tensor<f64>, tensor<f64>)
// LEGAL: %[[EDIFF1:.*]] = stablehlo.reshape %[[DIFF1]] : (tensor<f64>) -> tensor<1xf64>
// LEGAL: %[[EDIFF2:.*]] = stablehlo.reshape %[[DIFF2]] : (tensor<f64>) -> tensor<1xf64>
// LEGAL: %[[CONCAT:.*]] = stablehlo.concatenate %[[EDIFF1]], %[[EDIFF2]], dim = 0 : (tensor<1xf64>, tensor<1xf64>) -> tensor<2xf64>
// LEGAL: %[[BATCHED_RES:.*]] = enzyme.fwddiff @square(%[[PRIMAL]], %[[CONCAT]]) {{.*}} width = 2 {{.*}} : (tensor<f64>, tensor<2xf64>) -> tensor<2xf64>
// LEGAL: %[[R0:.*]] = stablehlo.slice %[[BATCHED_RES]] [0:1] : (tensor<2xf64>) -> tensor<1xf64>
// LEGAL-NEXT: %[[RES0:.*]] = stablehlo.reshape %[[R0]] : (tensor<1xf64>) -> tensor<f64>
// LEGAL-NEXT: %[[R1:.*]] = stablehlo.slice %[[BATCHED_RES]] [1:2] : (tensor<2xf64>) -> tensor<1xf64>
// LEGAL-NEXT: %[[RES1:.*]] = stablehlo.reshape %[[R1]] : (tensor<1xf64>) -> tensor<f64>
// LEGAL-NEXT: return %[[RES0]], %[[RES1]]

// -----

// 2. Tensor test
module {
func.func @square(%x : tensor<10xf64>) -> tensor<10xf64>{
%y = stablehlo.multiply %x, %x : tensor<10xf64>
return %y : tensor<10xf64>
}
func.func @test2(%x : tensor<10xf64>, %dx : tensor<10xf64>, %dx2 : tensor<10xf64>) -> (tensor<10xf64>,tensor<10xf64>) {
%r = enzyme.fwddiff @square(%x, %dx) { activity=[#enzyme<activity enzyme_dup>], ret_activity=[#enzyme<activity enzyme_dupnoneed>]} : (tensor<10xf64>, tensor<10xf64>) -> (tensor<10xf64>)
%r2 = enzyme.fwddiff @square(%x, %dx2) { activity=[#enzyme<activity enzyme_dup>], ret_activity=[#enzyme<activity enzyme_dupnoneed>] } : (tensor<10xf64>, tensor<10xf64>) -> (tensor<10xf64>)
return %r,%r2 : tensor<10xf64>,tensor<10xf64>
}
}


// CHECK-LABEL: func.func @test2
// CHECK-SAME: (%[[PRIMAL:.*]]: tensor<10xf64>, %[[DIFF1:.*]]: tensor<10xf64>, %[[DIFF2:.*]]: tensor<10xf64>) -> (tensor<10xf64>, tensor<10xf64>)
// CHECK: %[[CONCAT:.*]] = enzyme.concat(%[[DIFF1]], %[[DIFF2]]) : (tensor<10xf64>, tensor<10xf64>) -> tensor<2x10xf64>
// CHECK: %[[BATCHED_RES:.*]] = enzyme.fwddiff @square(%[[PRIMAL]], %[[CONCAT]]) {{.*}} width = 2 {{.*}} : (tensor<10xf64>, tensor<2x10xf64>) -> tensor<2x10xf64>
// CHECK: %[[RES0:.*]] = enzyme.extract %[[BATCHED_RES]][0] : (tensor<2x10xf64>) -> tensor<10xf64>
// CHECK-NEXT: %[[RES1:.*]] = enzyme.extract %[[BATCHED_RES]][1] : (tensor<2x10xf64>) -> tensor<10xf64>
// CHECK-NEXT: return %[[RES0]], %[[RES1]]

// LEGAL-LABEL: func.func @test2
// LEGAL-SAME: (%[[PRIMAL:.*]]: tensor<10xf64>, %[[DIFF1:.*]]: tensor<10xf64>, %[[DIFF2:.*]]: tensor<10xf64>) -> (tensor<10xf64>, tensor<10xf64>)
// LEGAL: %[[EDIFF1:.*]] = stablehlo.reshape %[[DIFF1]] : (tensor<10xf64>) -> tensor<1x10xf64>
// LEGAL: %[[EDIFF2:.*]] = stablehlo.reshape %[[DIFF2]] : (tensor<10xf64>) -> tensor<1x10xf64>
// LEGAL: %[[CONCAT:.*]] = stablehlo.concatenate %[[EDIFF1]], %[[EDIFF2]], dim = 0 : (tensor<1x10xf64>, tensor<1x10xf64>) -> tensor<2x10xf64>
// LEGAL: %[[BATCHED_RES:.*]] = enzyme.fwddiff @square(%[[PRIMAL]], %[[CONCAT]]) {{.*}} width = 2 {{.*}} : (tensor<10xf64>, tensor<2x10xf64>) -> tensor<2x10xf64>
// LEGAL: %[[R0:.*]] = stablehlo.slice %[[BATCHED_RES]] [0:1, 0:10] : (tensor<2x10xf64>) -> tensor<1x10xf64>
// LEGAL-NEXT: %[[RES0:.*]] = stablehlo.reshape %[[R0]] : (tensor<1x10xf64>) -> tensor<10xf64>
// LEGAL-NEXT: %[[R1:.*]] = stablehlo.slice %[[BATCHED_RES]] [1:2, 0:10] : (tensor<2x10xf64>) -> tensor<1x10xf64>
// LEGAL-NEXT: %[[RES1:.*]] = stablehlo.reshape %[[R1]] : (tensor<1x10xf64>) -> tensor<10xf64>
// LEGAL-NEXT: return %[[RES0]], %[[RES1]]
Loading