Commit bb22ea2

Add ConcatOp lowering, fix lit test
Need to find another cafe to work...
1 parent d7c66f3 commit bb22ea2

File tree: 3 files changed (+198, -3 lines)
src/enzyme_ad/jax/Passes/EnzymeBatchToStableHLOPass.cpp

Lines changed: 54 additions & 3 deletions
@@ -9,11 +9,11 @@
 // This file implements a pass to print the MLIR module
 //===----------------------------------------------------------------------===//

-#include "src/enzyme_ad/jax/Passes/Passes.h"
-#include "stablehlo/dialect/StablehloOps.h"
 #include "Enzyme/MLIR/Dialect/Dialect.h"
 #include "Enzyme/MLIR/Dialect/Ops.h"
+#include "src/enzyme_ad/jax/Passes/Passes.h"
 #include "src/enzyme_ad/jax/Utils.h"
+#include "stablehlo/dialect/StablehloOps.h"

 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Transforms/DialectConversion.h"
@@ -29,16 +29,67 @@ using namespace mlir;
 using namespace mlir::enzyme;
 using namespace enzyme;
 namespace {
+
+struct ExtractOpConversion : public OpConversionPattern<enzyme::ExtractOp> {
+  using OpConversionPattern<enzyme::ExtractOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(enzyme::ExtractOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto outTy = op.getOutput().getType();
+    // stablehlo always has tensor type
+    auto outRankTy = dyn_cast<RankedTensorType>(outTy);
+    auto rank = outRankTy.getRank();
+    // TODO: lower to a stablehlo.dynamic_slice op
+    return failure();
+  }
+};
+
+struct ConcatOpConversion : public OpConversionPattern<enzyme::ConcatOp> {
+  using OpConversionPattern<enzyme::ConcatOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(enzyme::ConcatOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    SmallVector<Value> inputs = op.getInputs();
+    if (inputs.empty())
+      return failure();
+
+    auto firstInTy = inputs.front().getType();
+
+    // stablehlo always has tensor type:
+    // reshape each input to 1x<input shape> and concatenate on dim = 0
+    SmallVector<Value> expandedInputs;
+    for (Value in : inputs) {
+      auto inRankTy = cast<RankedTensorType>(in.getType());
+      auto inShape = inRankTy.getShape();
+      SmallVector<int64_t> newInShape = {1};
+      newInShape.append(inShape.begin(), inShape.end());
+      auto newInTy = inRankTy.clone(newInShape);
+      Value newInput = rewriter.create<stablehlo::ReshapeOp>(
+          op->getLoc(), newInTy, in, op->getAttrs());
+      expandedInputs.push_back(newInput);
+    }
+
+    // concatenate on dim = 0
+    rewriter.replaceOpWithNewOp<stablehlo::ConcatenateOp>(
+        op, op->getResultTypes(), expandedInputs, /*dim=*/0);
+    return success();
+  }
+};
 struct EnzymeBatchToStableHLOPass
     : public enzyme::impl::EnzymeBatchToStableHLOPassBase<
           EnzymeBatchToStableHLOPass> {
   void runOnOperation() override {
     MLIRContext *context = &getContext();
     RewritePatternSet patterns(context);
+    patterns.add<ConcatOpConversion, ExtractOpConversion>(context);
+
     ConversionTarget target(*context);
     target.addLegalDialect<stablehlo::StablehloDialect>();
     target.addLegalDialect<enzyme::EnzymeDialect>();
-    target.addIllegalOp<enzyme::ConcatOp, enzyme::ExtractOp>();
+    target.addIllegalOp<enzyme::ConcatOp>();

     if (failed(applyPartialConversion(getOperation(), target,
                                       std::move(patterns)))) {
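For orientation, here is a minimal sketch of the rewrite that ConcatOpConversion performs, using hypothetical values %dr1 and %dr2 of type tensor<10xf64> (the shape is borrowed from the test2 cases below); the printed IR is an approximation of what the pattern would emit, not captured pass output:

// Hypothetical input, as produced by --enzyme-diff-batch:
//   %concat = enzyme.concat(%dr1, %dr2) : (tensor<10xf64>, tensor<10xf64>) -> tensor<2x10xf64>
// Approximate result of ConcatOpConversion: reshape each operand to a
// leading-1 shape, then concatenate all operands along dimension 0.
%e1 = stablehlo.reshape %dr1 : (tensor<10xf64>) -> tensor<1x10xf64>
%e2 = stablehlo.reshape %dr2 : (tensor<10xf64>) -> tensor<1x10xf64>
%concat = stablehlo.concatenate %e1, %e2, dim = 0 : (tensor<1x10xf64>, tensor<1x10xf64>) -> tensor<2x10xf64>

For rank-0 operands (the scalar tests), the same pattern would reshape each tensor<f64> to tensor<1xf64> and concatenate them into tensor<2xf64>. ExtractOpConversion, by contrast, is still a stub that returns failure(); its comment marks stablehlo.dynamic_slice as the intended lowering, and enzyme::ExtractOp is accordingly no longer declared illegal in the conversion target.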
Lines changed: 73 additions & 0 deletions
@@ -0,0 +1,73 @@
// RUN: enzymexlamlir-opt --split-input-file --enzyme-diff-batch %s | FileCheck %s
// RUN: enzymexlamlir-opt --split-input-file --enzyme-diff-batch --enzyme-batch-to-tensor %s | FileCheck %s --check-prefix=LEGAL

// 1. Scalar test
module {
  func.func @square(%x : tensor<f64>) -> tensor<f64> {
    %y = stablehlo.multiply %x, %x : tensor<f64>
    return %y : tensor<f64>
  }
  func.func @test1(%x : tensor<f64>, %dr1 : tensor<f64>, %dr2 : tensor<f64>) -> (tensor<f64>, tensor<f64>) {
    %r, %dx1 = enzyme.autodiff @square(%x, %dr1) { activity=[#enzyme<activity enzyme_active>], ret_activity=[#enzyme<activity enzyme_active>] } : (tensor<f64>, tensor<f64>) -> (tensor<f64>, tensor<f64>)
    %r2, %dx2 = enzyme.autodiff @square(%x, %dr2) { activity=[#enzyme<activity enzyme_active>], ret_activity=[#enzyme<activity enzyme_active>] } : (tensor<f64>, tensor<f64>) -> (tensor<f64>, tensor<f64>)
    return %dx1, %dx2 : tensor<f64>, tensor<f64>
  }
}

// CHECK-LABEL: func.func @test1
// CHECK-SAME: (%[[PRIMAL:.*]]: tensor<f64>, %[[DIFF1:.*]]: tensor<f64>, %[[DIFF2:.*]]: tensor<f64>) -> (tensor<f64>, tensor<f64>)
// CHECK: %[[CONCAT:.*]] = enzyme.concat(%[[DIFF1]], %[[DIFF2]]) : (tensor<f64>, tensor<f64>) -> tensor<2xf64>
// CHECK: %[[BATCHED_RES_BASE:.*]]:2 = enzyme.autodiff @square(%[[PRIMAL]], %[[CONCAT]]) {{.*}} width = 2 {{.*}} : (tensor<f64>, tensor<2xf64>) -> (tensor<f64>, tensor<2xf64>)
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK-NEXT: %[[RES0:.*]] = enzyme.extract %[[BATCHED_RES_BASE]]#1[%[[C0]]] : (tensor<2xf64>) -> tensor<f64>
// CHECK-NEXT: %[[C1:.*]] = arith.constant 1 : index
// CHECK-NEXT: %[[RES1:.*]] = enzyme.extract %[[BATCHED_RES_BASE]]#1[%[[C1]]] : (tensor<2xf64>) -> tensor<f64>
// CHECK-NEXT: return %[[RES0]], %[[RES1]]

// LEGAL-LABEL: func.func @test1
// LEGAL-SAME: (%[[PRIMAL:.*]]: f64, %[[DIFF1:.*]]: f64, %[[DIFF2:.*]]: f64) -> (f64, f64)
// LEGAL: %[[CONCAT:.*]] = tensor.from_elements %[[DIFF1]], %[[DIFF2]] : tensor<2xf64>
// LEGAL: %[[BATCHED_RES_BASE:.*]]:2 = enzyme.autodiff @square(%[[PRIMAL]], %[[CONCAT]]) {{.*}} width = 2 {{.*}} : (f64, tensor<2xf64>) -> (f64, tensor<2xf64>)
// LEGAL: %[[C0:.*]] = arith.constant 0 : index
// LEGAL-NEXT: %[[RES0:.*]] = tensor.extract %[[BATCHED_RES_BASE]]#1[%[[C0]]] : tensor<2xf64>
// LEGAL-NEXT: %[[C1:.*]] = arith.constant 1 : index
// LEGAL-NEXT: %[[RES1:.*]] = tensor.extract %[[BATCHED_RES_BASE]]#1[%[[C1]]] : tensor<2xf64>
// LEGAL-NEXT: return %[[RES0]], %[[RES1]]

// -----

// 2. Tensor test
module {
  func.func @square(%x : tensor<10xf64>) -> tensor<10xf64> {
    %y = stablehlo.multiply %x, %x : tensor<10xf64>
    return %y : tensor<10xf64>
  }
  func.func @test2(%x : tensor<10xf64>, %dr1 : tensor<10xf64>, %dr2 : tensor<10xf64>) -> (tensor<10xf64>, tensor<10xf64>) {
    %r, %dx1 = enzyme.autodiff @square(%x, %dr1) { activity=[#enzyme<activity enzyme_active>], ret_activity=[#enzyme<activity enzyme_active>] } : (tensor<10xf64>, tensor<10xf64>) -> (tensor<10xf64>, tensor<10xf64>)
    %r2, %dx2 = enzyme.autodiff @square(%x, %dr2) { activity=[#enzyme<activity enzyme_active>], ret_activity=[#enzyme<activity enzyme_active>] } : (tensor<10xf64>, tensor<10xf64>) -> (tensor<10xf64>, tensor<10xf64>)
    return %dx1, %dx2 : tensor<10xf64>, tensor<10xf64>
  }
}

// CHECK-LABEL: func.func @test2
// CHECK-SAME: (%[[PRIMAL:.*]]: tensor<10xf64>, %[[DIFF1:.*]]: tensor<10xf64>, %[[DIFF2:.*]]: tensor<10xf64>) -> (tensor<10xf64>, tensor<10xf64>)
// CHECK: %[[CONCAT:.*]] = enzyme.concat(%[[DIFF1]], %[[DIFF2]]) : (tensor<10xf64>, tensor<10xf64>) -> tensor<2x10xf64>
// CHECK: %[[BATCHED_RES_BASE:.*]]:2 = enzyme.autodiff @square(%[[PRIMAL]], %[[CONCAT]]) {{.*}} width = 2 {{.*}} : (tensor<10xf64>, tensor<2x10xf64>) -> (tensor<10xf64>, tensor<2x10xf64>)
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK-NEXT: %[[RES0:.*]] = enzyme.extract %[[BATCHED_RES_BASE]]#1[%[[C0]]] : (tensor<2x10xf64>) -> tensor<10xf64>
// CHECK-NEXT: %[[C1:.*]] = arith.constant 1 : index
// CHECK-NEXT: %[[RES1:.*]] = enzyme.extract %[[BATCHED_RES_BASE]]#1[%[[C1]]] : (tensor<2x10xf64>) -> tensor<10xf64>
// CHECK-NEXT: return %[[RES0]], %[[RES1]]

// LEGAL-LABEL: func.func @test2
// LEGAL-SAME: (%[[PRIMAL:.*]]: tensor<10xf64>, %[[DIFF1:.*]]: tensor<10xf64>, %[[DIFF2:.*]]: tensor<10xf64>) -> (tensor<10xf64>, tensor<10xf64>)
// LEGAL: %[[EDIFF1:.*]] = tensor.expand_shape %[[DIFF1]] {{\[\[0, 1\]\]}} output_shape [1, 10] : tensor<10xf64> into tensor<1x10xf64>
// LEGAL: %[[EDIFF2:.*]] = tensor.expand_shape %[[DIFF2]] {{\[\[0, 1\]\]}} output_shape [1, 10] : tensor<10xf64> into tensor<1x10xf64>
// LEGAL: %[[CONCAT:.*]] = tensor.concat dim(0) %[[EDIFF1]], %[[EDIFF2]] : (tensor<1x10xf64>, tensor<1x10xf64>) -> tensor<2x10xf64>
// LEGAL: %[[BATCHED_RES_BASE:.*]]:2 = enzyme.autodiff @square(%[[PRIMAL]], %[[CONCAT]]) {{.*}} width = 2 {{.*}} : (tensor<10xf64>, tensor<2x10xf64>) -> (tensor<10xf64>, tensor<2x10xf64>)
// LEGAL: %[[C0:.*]] = arith.constant 0 : index
// LEGAL-NEXT: %[[RES0:.*]] = tensor.extract_slice %[[BATCHED_RES_BASE]]#1[%[[C0]], 0] [1, 10] [1, 1] : tensor<2x10xf64> to tensor<10xf64>
// LEGAL-NEXT: %[[C1:.*]] = arith.constant 1 : index
// LEGAL-NEXT: %[[RES1:.*]] = tensor.extract_slice %[[BATCHED_RES_BASE]]#1[%[[C1]], 0] [1, 10] [1, 1] : tensor<2x10xf64> to tensor<10xf64>
// LEGAL-NEXT: return %[[RES0]], %[[RES1]]
Lines changed: 71 additions & 0 deletions
@@ -0,0 +1,71 @@
// RUN: enzymexlamlir-opt --split-input-file --enzyme-diff-batch %s | FileCheck %s
// RUN: enzymexlamlir-opt --split-input-file --enzyme-diff-batch --enzyme-batch-to-tensor %s | FileCheck %s --check-prefix=LEGAL

// 1. Scalar test
module {
  func.func @square(%x : tensor<f64>) -> tensor<f64> {
    %y = stablehlo.multiply %x, %x : tensor<f64>
    return %y : tensor<f64>
  }
  func.func @test1(%x : tensor<f64>, %dx1 : tensor<f64>, %dx2 : tensor<f64>) -> (tensor<f64>, tensor<f64>) {
    %r1 = enzyme.fwddiff @square(%x, %dx1) { activity=[#enzyme<activity enzyme_dup>], ret_activity=[#enzyme<activity enzyme_dupnoneed>] } : (tensor<f64>, tensor<f64>) -> (tensor<f64>)
    %r2 = enzyme.fwddiff @square(%x, %dx2) { activity=[#enzyme<activity enzyme_dup>], ret_activity=[#enzyme<activity enzyme_dupnoneed>] } : (tensor<f64>, tensor<f64>) -> (tensor<f64>)
    return %r1, %r2 : tensor<f64>, tensor<f64>
  }
}

// CHECK-LABEL: func.func @test1
// CHECK-SAME: (%[[PRIMAL:.*]]: tensor<f64>, %[[DIFF1:.*]]: tensor<f64>, %[[DIFF2:.*]]: tensor<f64>) -> (tensor<f64>, tensor<f64>)
// CHECK: %[[CONCAT:.*]] = enzyme.concat(%[[DIFF1]], %[[DIFF2]]) : (tensor<f64>, tensor<f64>) -> tensor<2xf64>
// CHECK: %[[BATCHED_RES:.*]] = enzyme.fwddiff @square(%[[PRIMAL]], %[[CONCAT]]) {{.*}} width = 2 {{.*}} : (tensor<f64>, tensor<2xf64>) -> tensor<2xf64>
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK-NEXT: %[[RES0:.*]] = enzyme.extract %[[BATCHED_RES]][%[[C0]]] : (tensor<2xf64>) -> tensor<f64>
// CHECK-NEXT: %[[C1:.*]] = arith.constant 1 : index
// CHECK-NEXT: %[[RES1:.*]] = enzyme.extract %[[BATCHED_RES]][%[[C1]]] : (tensor<2xf64>) -> tensor<f64>
// CHECK-NEXT: return %[[RES0]], %[[RES1]]

// LEGAL-LABEL: func.func @test1
// LEGAL-SAME: (%[[PRIMAL:.*]]: tensor<f64>, %[[DIFF1:.*]]: tensor<f64>, %[[DIFF2:.*]]: tensor<f64>) -> (tensor<f64>, tensor<f64>)
// LEGAL: %[[CONCAT:.*]] = tensor.from_elements %[[DIFF1]], %[[DIFF2]] : tensor<2xf64>
// LEGAL: %[[BATCHED_RES:.*]] = enzyme.fwddiff @square(%[[PRIMAL]], %[[CONCAT]]) {{.*}} width = 2 {{.*}} : (tensor<f64>, tensor<2xf64>) -> tensor<2xf64>
// LEGAL: %[[C0:.*]] = arith.constant 0 : index
// LEGAL-NEXT: %[[RES0:.*]] = tensor.extract %[[BATCHED_RES]][%[[C0]]] : tensor<2xf64>
// LEGAL-NEXT: %[[C1:.*]] = arith.constant 1 : index
// LEGAL-NEXT: %[[RES1:.*]] = tensor.extract %[[BATCHED_RES]][%[[C1]]] : tensor<2xf64>
// LEGAL-NEXT: return %[[RES0]], %[[RES1]]

// -----

// 2. Tensor test
module {
  func.func @square(%x : tensor<10xf64>) -> tensor<10xf64> {
    %y = stablehlo.multiply %x, %x : tensor<10xf64>
    return %y : tensor<10xf64>
  }
  func.func @test2(%x : tensor<10xf64>, %dx : tensor<10xf64>, %dx2 : tensor<10xf64>) -> (tensor<10xf64>, tensor<10xf64>) {
    %r = enzyme.fwddiff @square(%x, %dx) { activity=[#enzyme<activity enzyme_dup>], ret_activity=[#enzyme<activity enzyme_dupnoneed>] } : (tensor<10xf64>, tensor<10xf64>) -> (tensor<10xf64>)
    %r2 = enzyme.fwddiff @square(%x, %dx2) { activity=[#enzyme<activity enzyme_dup>], ret_activity=[#enzyme<activity enzyme_dupnoneed>] } : (tensor<10xf64>, tensor<10xf64>) -> (tensor<10xf64>)
    return %r, %r2 : tensor<10xf64>, tensor<10xf64>
  }
}

// CHECK-LABEL: func.func @test2
// CHECK-SAME: (%[[PRIMAL:.*]]: tensor<10xf64>, %[[DIFF1:.*]]: tensor<10xf64>, %[[DIFF2:.*]]: tensor<10xf64>) -> (tensor<10xf64>, tensor<10xf64>)
// CHECK: %[[CONCAT:.*]] = enzyme.concat(%[[DIFF1]], %[[DIFF2]]) : (tensor<10xf64>, tensor<10xf64>) -> tensor<2x10xf64>
// CHECK: %[[BATCHED_RES:.*]] = enzyme.fwddiff @square(%[[PRIMAL]], %[[CONCAT]]) {{.*}} width = 2 {{.*}} : (tensor<10xf64>, tensor<2x10xf64>) -> tensor<2x10xf64>
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK-NEXT: %[[RES0:.*]] = enzyme.extract %[[BATCHED_RES]][%[[C0]]] : (tensor<2x10xf64>) -> tensor<10xf64>
// CHECK-NEXT: %[[C1:.*]] = arith.constant 1 : index
// CHECK-NEXT: %[[RES1:.*]] = enzyme.extract %[[BATCHED_RES]][%[[C1]]] : (tensor<2x10xf64>) -> tensor<10xf64>
// CHECK-NEXT: return %[[RES0]], %[[RES1]]

// LEGAL-LABEL: func.func @test2
// LEGAL-SAME: (%[[PRIMAL:.*]]: tensor<10xf64>, %[[DIFF1:.*]]: tensor<10xf64>, %[[DIFF2:.*]]: tensor<10xf64>) -> (tensor<10xf64>, tensor<10xf64>)
// LEGAL: %[[EDIFF1:.*]] = tensor.expand_shape %[[DIFF1]] {{\[\[0, 1\]\]}} output_shape [1, 10] : tensor<10xf64> into tensor<1x10xf64>
// LEGAL: %[[EDIFF2:.*]] = tensor.expand_shape %[[DIFF2]] {{\[\[0, 1\]\]}} output_shape [1, 10] : tensor<10xf64> into tensor<1x10xf64>
// LEGAL: %[[CONCAT:.*]] = tensor.concat dim(0) %[[EDIFF1]], %[[EDIFF2]] : (tensor<1x10xf64>, tensor<1x10xf64>) -> tensor<2x10xf64>
// LEGAL: %[[BATCHED_RES:.*]] = enzyme.fwddiff @square(%[[PRIMAL]], %[[CONCAT]]) {{.*}} width = 2 {{.*}} : (tensor<10xf64>, tensor<2x10xf64>) -> tensor<2x10xf64>
// LEGAL: %[[C0:.*]] = arith.constant 0 : index
// LEGAL-NEXT: %[[RES0:.*]] = tensor.extract_slice %[[BATCHED_RES]][%[[C0]], 0] [1, 10] [1, 1] : tensor<2x10xf64> to tensor<10xf64>
// LEGAL-NEXT: %[[C1:.*]] = arith.constant 1 : index
// LEGAL-NEXT: %[[RES1:.*]] = tensor.extract_slice %[[BATCHED_RES]][%[[C1]], 0] [1, 10] [1, 1] : tensor<2x10xf64> to tensor<10xf64>
// LEGAL-NEXT: return %[[RES0]], %[[RES1]]
