
Commit a9db3a2

Finalize changes

1 parent 1406731 · commit a9db3a2
File tree

3 files changed: +43 −55 lines changed

src/enzyme_ad/jax/Passes/EnzymeBatchToStableHLOPass.cpp

Lines changed: 16 additions & 21 deletions
@@ -47,26 +47,23 @@ struct ExtractOpConversion : public OpConversionPattern<enzyme::ExtractOp> {
     if (ndims < 1)
       return failure();

-    // dynamic_slice followed by reshape
-    auto i64Ty = IntegerType::get(rewriter.getContext(), 64);
-    auto tensor0i64Ty = RankedTensorType::get({}, i64Ty);
-    auto zero = rewriter.create<stablehlo::ConstantOp>(
-        op.getLoc(), rewriter.getZeroAttr(tensor0i64Ty));
-
-    SmallVector<Value> dynamicSliceStartSlices(ndims, zero);
-    dynamicSliceStartSlices[0] = op.getIndex(); // assume its legal for no
-
-    SmallVector<int64_t> localRetShape = {1};
-    localRetShape.append(outRankTy.getShape().begin(),
+    // static slice
+    SmallVector<int64_t> start_indices;
+    start_indices.push_back(op.getIndex());
+    for (int i = 1; i < ndims; ++i) {
+      start_indices.push_back(0);
+    }
+    SmallVector<int64_t> limit_indices;
+    limit_indices.push_back(op.getIndex() + 1);
+    limit_indices.append(outRankTy.getShape().begin(),
                          outRankTy.getShape().end());
-    ;
-    auto slicedOut = rewriter.create<stablehlo::DynamicSliceOp>(
-        op->getLoc(), op.getInput(), dynamicSliceStartSlices, localRetShape);
+    SmallVector<int64_t> strides(ndims, 1);

+    Value slicedOut =
+        stablehlo::SliceOp::create(rewriter, op->getLoc(), op.getInput(),
+                                   start_indices, limit_indices, strides);
     // reshape slicedOut to our final Op
-    rewriter.replaceOpWithNewOp<stablehlo::ReshapeOp>(op, op->getLoc(), outTy,
-                                                      slicedOut);
-
+    rewriter.replaceOpWithNewOp<stablehlo::ReshapeOp>(op, outTy, slicedOut);
     return success();
   }
 };
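Note on the new lowering: extracting batch entry k from a batched input of shape [width, d1, ..., dm] now uses purely static bounds (start = {k, 0, ..., 0}, limit = {k+1, d1, ..., dm}, unit strides), and the resulting [1, d1, ..., dm] slice is reshaped to the un-batched [d1, ..., dm] result. A minimal standalone C++ sketch of that index arithmetic follows; the helper name and types are illustrative, not part of the pass:

#include <cstdint>
#include <vector>

// Hypothetical helper mirroring ExtractOpConversion: given the batch index
// and the un-batched result shape, compute static slice bounds that select
// one entry along the leading (batch) dimension.
struct SliceBounds {
  std::vector<int64_t> start, limit, strides;
};

SliceBounds extractSliceBounds(int64_t index,
                               const std::vector<int64_t> &outShape) {
  SliceBounds b;
  b.start.push_back(index);      // batch dim: start at the requested entry
  b.limit.push_back(index + 1);  // and stop right after it
  for (int64_t d : outShape) {
    b.start.push_back(0);        // all other dims are taken in full
    b.limit.push_back(d);
  }
  b.strides.assign(outShape.size() + 1, 1); // unit stride everywhere
  return b;
}

Because the bounds are compile-time constants, no arith.constant index values need to be materialized, which is what the updated lit tests below check: stablehlo.slice [0:1] or [1:2] followed by stablehlo.reshape.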
@@ -81,8 +78,6 @@ struct ConcatOpConversion : public OpConversionPattern<enzyme::ConcatOp> {
     if (inputs.empty())
       return failure();

-    auto firstInTy = inputs.front().getType();
-
     // stablehlo always has tensor type
     // reshape each input to 1xinput_rank and concatenate on dim=0

@@ -94,7 +89,7 @@ struct ConcatOpConversion : public OpConversionPattern<enzyme::ConcatOp> {
       newInShape.append(inShape.begin(), inShape.end());
       auto newInTy = inRankTy.clone(newInShape);
       Value newInput =
-          rewriter.create<stablehlo::ReshapeOp>(op->getLoc(), newInTy, in);
+          stablehlo::ReshapeOp::create(rewriter, op->getLoc(), newInTy, in);
       expandedInputs.push_back(newInput);
     }
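For context, ConcatOpConversion batches its operands by reshaping each input of shape [d1, ..., dm] to [1, d1, ..., dm] and concatenating the expanded inputs along dim 0, giving [n, d1, ..., dm]. A small C++ sketch of the shape bookkeeping, with hypothetical helper names:

#include <cstdint>
#include <vector>

// Hypothetical shape helpers mirroring the reshape-then-concatenate scheme.
std::vector<int64_t> expandedShape(const std::vector<int64_t> &inShape) {
  std::vector<int64_t> s{1};                        // new leading batch dim
  s.insert(s.end(), inShape.begin(), inShape.end());
  return s;                                         // [1, d1, ..., dm]
}

std::vector<int64_t> concatShape(const std::vector<int64_t> &inShape,
                                 int64_t numInputs) {
  std::vector<int64_t> s = expandedShape(inShape);
  s[0] = numInputs;                                 // [n, d1, ..., dm]
  return s;
}

For example, two tensor<10xf64> operands expand to tensor<1x10xf64> each and concatenate to tensor<2x10xf64>, matching the LEGAL checks in the tests below.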

@@ -116,7 +111,7 @@ struct EnzymeBatchToStableHLOPass
     ConversionTarget target(*context);
     target.addLegalDialect<stablehlo::StablehloDialect>();
     target.addLegalDialect<enzyme::EnzymeDialect>();
-    target.addIllegalOp<enzyme::ConcatOp>();
+    target.addIllegalOp<enzyme::ConcatOp, enzyme::ExtractOp>();

     if (failed(applyPartialConversion(getOperation(), target,
                                       std::move(patterns)))) {
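Marking enzyme.extract illegal alongside enzyme.concat means applyPartialConversion must rewrite every such op through one of the registered patterns; any occurrence left unconverted makes the pass fail. A minimal sketch of how this kind of partial-conversion driver is typically wired up in MLIR; dialect and pattern names follow the diff above, while the surrounding boilerplate (headers, error handling) is assumed rather than taken from this commit:

#include "mlir/Transforms/DialectConversion.h"
// Headers declaring the enzyme and stablehlo dialects and the two
// conversion patterns are omitted here for brevity.

using namespace mlir;

// Hypothetical driver mirroring runOnOperation in the pass above.
void lowerBatchOps(Operation *root, MLIRContext *context) {
  RewritePatternSet patterns(context);
  patterns.add<ConcatOpConversion, ExtractOpConversion>(context);

  ConversionTarget target(*context);
  target.addLegalDialect<stablehlo::StablehloDialect>();
  target.addLegalDialect<enzyme::EnzymeDialect>();
  // Ops listed here must be rewritten by some pattern or conversion fails.
  target.addIllegalOp<enzyme::ConcatOp, enzyme::ExtractOp>();

  if (failed(applyPartialConversion(root, target, std::move(patterns))))
    root->emitError("failed to lower enzyme batch ops to StableHLO");
}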

test/lit_tests/adbatching/bwd_batch.mlir renamed to test/lit_tests/OptimizeAD/bwd_batch.mlir

Lines changed: 14 additions & 18 deletions
@@ -18,22 +18,20 @@ module {
 // CHECK-SAME: (%[[PRIMAL:.*]]: tensor<f64>, %[[DIFF1:.*]]: tensor<f64>, %[[DIFF2:.*]]: tensor<f64>) -> (tensor<f64>, tensor<f64>)
 // CHECK: %[[CONCAT:.*]] = enzyme.concat(%[[DIFF1]], %[[DIFF2]]) : (tensor<f64>, tensor<f64>) -> tensor<2xf64>
 // CHECK: %[[BATCHED_RES_BASE:.*]]:2 = enzyme.autodiff @square(%[[PRIMAL]], %[[CONCAT]]) {{.*}} width = 2 {{.*}} : (tensor<f64>, tensor<2xf64>) -> (tensor<f64>, tensor<2xf64>)
-// CHECK: %[[C0:.*]] = arith.constant 0 : index
-// CHECK-NEXT: %[[RES0:.*]] = enzyme.extract %[[BATCHED_RES_BASE]]#1[%[[C0]]] : (tensor<2xf64>) -> tensor<f64>
-// CHECK-NEXT: %[[C1:.*]] = arith.constant 1 : index
-// CHECK-NEXT: %[[RES1:.*]] = enzyme.extract %[[BATCHED_RES_BASE]]#1[%[[C1]]] : (tensor<2xf64>) -> tensor<f64>
+// CHECK: %[[RES0:.*]] = enzyme.extract %[[BATCHED_RES_BASE]]#1[0] : (tensor<2xf64>) -> tensor<f64>
+// CHECK-NEXT: %[[RES1:.*]] = enzyme.extract %[[BATCHED_RES_BASE]]#1[1] : (tensor<2xf64>) -> tensor<f64>
 // CHECK-NEXT: return %[[RES0]], %[[RES1]]

 // LEGAL-LABEL: func.func @test1
-// LEGAL-SAME: (%[[PRIMAL:.*]]: f64, %[[DIFF1:.*]]: f64, %[[DIFF2:.*]]: f64) -> (f64, f64)
+// LEGAL-SAME: (%[[PRIMAL:.*]]: tensor<f64>, %[[DIFF1:.*]]: tensor<f64>, %[[DIFF2:.*]]: tensor<f64>) -> (tensor<f64>, tensor<f64>)
 // LEGAL: %[[EDIFF1:.*]] = stablehlo.reshape %[[DIFF1]] : (tensor<f64>) -> tensor<1xf64>
 // LEGAL: %[[EDIFF2:.*]] = stablehlo.reshape %[[DIFF2]] : (tensor<f64>) -> tensor<1xf64>
 // LEGAL: %[[CONCAT:.*]] = stablehlo.concatenate %[[EDIFF1]], %[[EDIFF2]], dim = 0 : (tensor<1xf64>, tensor<1xf64>) -> tensor<2xf64>
-// LEGAL: %[[BATCHED_RES_BASE:.*]]:2 = enzyme.autodiff @square(%[[PRIMAL]], %[[CONCAT]]) {{.*}} width = 2 {{.*}} : (f64, tensor<2xf64>) -> (f64, tensor<2xf64>)
-// LEGAL: %[[C0:.*]] = arith.constant 0 : index
-// LEGAL-NEXT: %[[RES0:.*]] = tensor.extract %[[BATCHED_RES_BASE]]#1[%[[C0]]] : tensor<2xf64>
-// LEGAL-NEXT: %[[C1:.*]] = arith.constant 1 : index
-// LEGAL-NEXT: %[[RES1:.*]] = tensor.extract %[[BATCHED_RES_BASE]]#1[%[[C1]]] : tensor<2xf64>
+// LEGAL: %[[BATCHED_RES_BASE:.*]]:2 = enzyme.autodiff @square(%[[PRIMAL]], %[[CONCAT]]) {{.*}} width = 2 {{.*}} : (tensor<f64>, tensor<2xf64>) -> (tensor<f64>, tensor<2xf64>)
+// LEGAL: %[[R0:.*]] = stablehlo.slice %[[BATCHED_RES_BASE]]#1 [0:1] : (tensor<2xf64>) -> tensor<1xf64>
+// LEGAL-NEXT: %[[RES0:.*]] = stablehlo.reshape %[[R0]] : (tensor<1xf64>) -> tensor<f64>
+// LEGAL-NEXT: %[[R1:.*]] = stablehlo.slice %[[BATCHED_RES_BASE]]#1 [1:2] : (tensor<2xf64>) -> tensor<1xf64>
+// LEGAL-NEXT: %[[RES1:.*]] = stablehlo.reshape %[[R1]] : (tensor<1xf64>) -> tensor<f64>
 // LEGAL-NEXT: return %[[RES0]], %[[RES1]]

 // -----
@@ -56,10 +54,8 @@ module {
 // CHECK-SAME: (%[[PRIMAL:.*]]: tensor<10xf64>, %[[DIFF1:.*]]: tensor<10xf64>, %[[DIFF2:.*]]: tensor<10xf64>) -> (tensor<10xf64>, tensor<10xf64>)
 // CHECK: %[[CONCAT:.*]] = enzyme.concat(%[[DIFF1]], %[[DIFF2]]) : (tensor<10xf64>, tensor<10xf64>) -> tensor<2x10xf64>
 // CHECK: %[[BATCHED_RES_BASE:.*]]:2 = enzyme.autodiff @square(%[[PRIMAL]], %[[CONCAT]]) {{.*}} width = 2 {{.*}} : (tensor<10xf64>, tensor<2x10xf64>) -> (tensor<10xf64>, tensor<2x10xf64>)
-// CHECK: %[[C0:.*]] = arith.constant 0 : index
-// CHECK-NEXT: %[[RES0:.*]] = enzyme.extract %[[BATCHED_RES_BASE]]#1[%[[C0]]] : (tensor<2x10xf64>) -> tensor<10xf64>
-// CHECK-NEXT: %[[C1:.*]] = arith.constant 1 : index
-// CHECK-NEXT: %[[RES1:.*]] = enzyme.extract %[[BATCHED_RES_BASE]]#1[%[[C1]]] : (tensor<2x10xf64>) -> tensor<10xf64>
+// CHECK-NEXT: %[[RES0:.*]] = enzyme.extract %[[BATCHED_RES_BASE]]#1[0] : (tensor<2x10xf64>) -> tensor<10xf64>
+// CHECK-NEXT: %[[RES1:.*]] = enzyme.extract %[[BATCHED_RES_BASE]]#1[1] : (tensor<2x10xf64>) -> tensor<10xf64>
 // CHECK-NEXT: return %[[RES0]], %[[RES1]]

 // LEGAL-LABEL: func.func @test2
@@ -68,8 +64,8 @@ module {
 // LEGAL: %[[EDIFF2:.*]] = stablehlo.reshape %[[DIFF2]] : (tensor<10xf64>) -> tensor<1x10xf64>
 // LEGAL: %[[CONCAT:.*]] = stablehlo.concatenate %[[EDIFF1]], %[[EDIFF2]], dim = 0 : (tensor<1x10xf64>, tensor<1x10xf64>) -> tensor<2x10xf64>
 // LEGAL: %[[BATCHED_RES_BASE:.*]]:2 = enzyme.autodiff @square(%[[PRIMAL]], %[[CONCAT]]) {{.*}} width = 2 {{.*}} : (tensor<10xf64>, tensor<2x10xf64>) -> (tensor<10xf64>, tensor<2x10xf64>)
-// LEGAL: %[[C0:.*]] = arith.constant 0 : index
-// LEGAL-NEXT: %[[RES0:.*]] = tensor.extract_slice %[[BATCHED_RES_BASE]]#1[%[[C0]], 0] [1, 10] [1, 1] : tensor<2x10xf64> to tensor<10xf64>
-// LEGAL-NEXT: %[[C1:.*]] = arith.constant 1 : index
-// LEGAL-NEXT: %[[RES1:.*]] = tensor.extract_slice %[[BATCHED_RES_BASE]]#1[%[[C1]], 0] [1, 10] [1, 1] : tensor<2x10xf64> to tensor<10xf64>
+// LEGAL: %[[R0:.*]] = stablehlo.slice %[[BATCHED_RES_BASE]]#1 [0:1, 0:10] : (tensor<2x10xf64>) -> tensor<1x10xf64>
+// LEGAL-NEXT: %[[RES0:.*]] = stablehlo.reshape %[[R0]] : (tensor<1x10xf64>) -> tensor<10xf64>
+// LEGAL-NEXT: %[[R1:.*]] = stablehlo.slice %[[BATCHED_RES_BASE]]#1 [1:2, 0:10] : (tensor<2x10xf64>) -> tensor<1x10xf64>
+// LEGAL-NEXT: %[[RES1:.*]] = stablehlo.reshape %[[R1]] : (tensor<1x10xf64>) -> tensor<10xf64>
 // LEGAL-NEXT: return %[[RES0]], %[[RES1]]

test/lit_tests/adbatching/fwd_batch.mlir renamed to test/lit_tests/OptimizeAD/fwd_batch.mlir

Lines changed: 13 additions & 16 deletions
@@ -17,10 +17,8 @@ module {
 // CHECK-SAME: (%[[PRIMAL:.*]]: tensor<f64>, %[[DIFF1:.*]]: tensor<f64>, %[[DIFF2:.*]]: tensor<f64>) -> (tensor<f64>, tensor<f64>)
 // CHECK: %[[CONCAT:.*]] = enzyme.concat(%[[DIFF1]], %[[DIFF2]]) : (tensor<f64>, tensor<f64>) -> tensor<2xf64>
 // CHECK: %[[BATCHED_RES:.*]] = enzyme.fwddiff @square(%[[PRIMAL]], %[[CONCAT]]) {{.*}} width = 2 {{.*}} : (tensor<f64>, tensor<2xf64>) -> tensor<2xf64>
-// CHECK: %[[C0:.*]] = arith.constant 0 : index
-// CHECK-NEXT: %[[RES0:.*]] = enzyme.extract %[[BATCHED_RES]][%[[C0]]] : (tensor<2xf64>) -> tensor<f64>
-// CHECK-NEXT: %[[C1:.*]] = arith.constant 1 : index
-// CHECK-NEXT: %[[RES1:.*]] = enzyme.extract %[[BATCHED_RES]][%[[C1]]] : (tensor<2xf64>) -> tensor<f64>
+// CHECK: %[[RES0:.*]] = enzyme.extract %[[BATCHED_RES]][0] : (tensor<2xf64>) -> tensor<f64>
+// CHECK-NEXT: %[[RES1:.*]] = enzyme.extract %[[BATCHED_RES]][1] : (tensor<2xf64>) -> tensor<f64>
 // CHECK-NEXT: return %[[RES0]], %[[RES1]]

 // LEGAL-LABEL: func.func @test1
@@ -29,10 +27,10 @@ module {
 // LEGAL: %[[EDIFF2:.*]] = stablehlo.reshape %[[DIFF2]] : (tensor<f64>) -> tensor<1xf64>
 // LEGAL: %[[CONCAT:.*]] = stablehlo.concatenate %[[EDIFF1]], %[[EDIFF2]], dim = 0 : (tensor<1xf64>, tensor<1xf64>) -> tensor<2xf64>
 // LEGAL: %[[BATCHED_RES:.*]] = enzyme.fwddiff @square(%[[PRIMAL]], %[[CONCAT]]) {{.*}} width = 2 {{.*}} : (tensor<f64>, tensor<2xf64>) -> tensor<2xf64>
-// LEGAL: %[[C0:.*]] = arith.constant 0 : index
-// LEGAL-NEXT: %[[RES0:.*]] = tensor.extract %[[BATCHED_RES]][%[[C0]]] : tensor<2xf64>
-// LEGAL-NEXT: %[[C1:.*]] = arith.constant 1 : index
-// LEGAL-NEXT: %[[RES1:.*]] = tensor.extract %[[BATCHED_RES]][%[[C1]]] : tensor<2xf64>
+// LEGAL: %[[R0:.*]] = stablehlo.slice %[[BATCHED_RES]] [0:1] : (tensor<2xf64>) -> tensor<1xf64>
+// LEGAL-NEXT: %[[RES0:.*]] = stablehlo.reshape %[[R0]] : (tensor<1xf64>) -> tensor<f64>
+// LEGAL-NEXT: %[[R1:.*]] = stablehlo.slice %[[BATCHED_RES]] [1:2] : (tensor<2xf64>) -> tensor<1xf64>
+// LEGAL-NEXT: %[[RES1:.*]] = stablehlo.reshape %[[R1]] : (tensor<1xf64>) -> tensor<f64>
 // LEGAL-NEXT: return %[[RES0]], %[[RES1]]

 // -----
@@ -50,14 +48,13 @@ module {
 }
 }

+
 // CHECK-LABEL: func.func @test2
 // CHECK-SAME: (%[[PRIMAL:.*]]: tensor<10xf64>, %[[DIFF1:.*]]: tensor<10xf64>, %[[DIFF2:.*]]: tensor<10xf64>) -> (tensor<10xf64>, tensor<10xf64>)
 // CHECK: %[[CONCAT:.*]] = enzyme.concat(%[[DIFF1]], %[[DIFF2]]) : (tensor<10xf64>, tensor<10xf64>) -> tensor<2x10xf64>
 // CHECK: %[[BATCHED_RES:.*]] = enzyme.fwddiff @square(%[[PRIMAL]], %[[CONCAT]]) {{.*}} width = 2 {{.*}} : (tensor<10xf64>, tensor<2x10xf64>) -> tensor<2x10xf64>
-// CHECK: %[[C0:.*]] = arith.constant 0 : index
-// CHECK-NEXT: %[[RES0:.*]] = enzyme.extract %[[BATCHED_RES]][%[[C0]]] : (tensor<2x10xf64>) -> tensor<10xf64>
-// CHECK-NEXT: %[[C1:.*]] = arith.constant 1 : index
-// CHECK-NEXT: %[[RES1:.*]] = enzyme.extract %[[BATCHED_RES]][%[[C1]]] : (tensor<2x10xf64>) -> tensor<10xf64>
+// CHECK: %[[RES0:.*]] = enzyme.extract %[[BATCHED_RES]][0] : (tensor<2x10xf64>) -> tensor<10xf64>
+// CHECK-NEXT: %[[RES1:.*]] = enzyme.extract %[[BATCHED_RES]][1] : (tensor<2x10xf64>) -> tensor<10xf64>
 // CHECK-NEXT: return %[[RES0]], %[[RES1]]

 // LEGAL-LABEL: func.func @test2
@@ -66,8 +63,8 @@ module {
 // LEGAL: %[[EDIFF2:.*]] = stablehlo.reshape %[[DIFF2]] : (tensor<10xf64>) -> tensor<1x10xf64>
 // LEGAL: %[[CONCAT:.*]] = stablehlo.concatenate %[[EDIFF1]], %[[EDIFF2]], dim = 0 : (tensor<1x10xf64>, tensor<1x10xf64>) -> tensor<2x10xf64>
 // LEGAL: %[[BATCHED_RES:.*]] = enzyme.fwddiff @square(%[[PRIMAL]], %[[CONCAT]]) {{.*}} width = 2 {{.*}} : (tensor<10xf64>, tensor<2x10xf64>) -> tensor<2x10xf64>
-// LEGAL: %[[C0:.*]] = arith.constant 0 : index
-// LEGAL-NEXT: %[[RES0:.*]] = tensor.extract_slice %[[BATCHED_RES]][%[[C0]], 0] [1, 10] [1, 1] : tensor<2x10xf64> to tensor<10xf64>
-// LEGAL-NEXT: %[[C1:.*]] = arith.constant 1 : index
-// LEGAL-NEXT: %[[RES1:.*]] = tensor.extract_slice %[[BATCHED_RES]][%[[C1]], 0] [1, 10] [1, 1] : tensor<2x10xf64> to tensor<10xf64>
+// LEGAL: %[[R0:.*]] = stablehlo.slice %[[BATCHED_RES]] [0:1, 0:10] : (tensor<2x10xf64>) -> tensor<1x10xf64>
+// LEGAL-NEXT: %[[RES0:.*]] = stablehlo.reshape %[[R0]] : (tensor<1x10xf64>) -> tensor<10xf64>
+// LEGAL-NEXT: %[[R1:.*]] = stablehlo.slice %[[BATCHED_RES]] [1:2, 0:10] : (tensor<2x10xf64>) -> tensor<1x10xf64>
+// LEGAL-NEXT: %[[RES1:.*]] = stablehlo.reshape %[[R1]] : (tensor<1x10xf64>) -> tensor<10xf64>
 // LEGAL-NEXT: return %[[RES0]], %[[RES1]]
