Skip to content

Commit 1b1b510

Browse files
authored
feat: scatter op batch interface (#2072)
1 parent 52d94b6 commit 1b1b510

File tree

5 files changed

+141
-12
lines changed

5 files changed

+141
-12
lines changed

src/enzyme_ad/jax/Implementations/StableHLOAutoDiffOpInterfaceImpl.cpp

Lines changed: 80 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -467,10 +467,10 @@ class AutoDiffWhileFwd
467467
auto idx = arg.getArgNumber();
468468
if (resultPositionsToShadow.count(idx)) {
469469
if (gutils->isConstantValue(arg)) {
470-
nb->insertArgument(
471-
curidx,
472-
cast<AutoDiffTypeInterface>(arg.getType()).getShadowType(gutils->width),
473-
op.getLoc());
470+
nb->insertArgument(curidx,
471+
cast<AutoDiffTypeInterface>(arg.getType())
472+
.getShadowType(gutils->width),
473+
op.getLoc());
474474
}
475475
curidx++;
476476
}
@@ -4124,6 +4124,81 @@ struct SHLOConvolutionOpBatchInterface
41244124
}
41254125
};
41264126

4127+
struct SHLOScatterOpBatchInterface
4128+
: public BatchOpInterface::ExternalModel<SHLOScatterOpBatchInterface,
4129+
stablehlo::ScatterOp> {
4130+
mlir::LogicalResult createBatch(Operation *src, OpBuilder &builder,
4131+
IRMapping &mapper,
4132+
ArrayRef<int64_t> batchSizes) const {
4133+
auto op = cast<stablehlo::ScatterOp>(src);
4134+
4135+
SmallVector<Value> newInputs;
4136+
newInputs.reserve(op.getInputs().size());
4137+
for (auto input : op.getInputs()) {
4138+
newInputs.push_back(mapper.lookup(input));
4139+
}
4140+
4141+
auto newScatterIndices = mapper.lookup(op.getScatterIndices());
4142+
4143+
SmallVector<Value> newUpdates;
4144+
newUpdates.reserve(op.getUpdates().size());
4145+
for (auto update : op.getUpdates()) {
4146+
newUpdates.push_back(mapper.lookup(update));
4147+
}
4148+
4149+
auto dimNumbers = op.getScatterDimensionNumbers();
4150+
int64_t nBatch = batchSizes.size();
4151+
4152+
SmallVector<int64_t> newUpdateWindowDims;
4153+
for (auto dim : dimNumbers.getUpdateWindowDims()) {
4154+
newUpdateWindowDims.push_back(dim + nBatch);
4155+
}
4156+
4157+
SmallVector<int64_t> newInsertedWindowDims;
4158+
for (auto dim : dimNumbers.getInsertedWindowDims()) {
4159+
newInsertedWindowDims.push_back(dim + nBatch);
4160+
}
4161+
4162+
SmallVector<int64_t> newInputBatchingDims, newScatterIndicesBatchingDims;
4163+
for (int64_t i = 0; i < nBatch; ++i) {
4164+
newInputBatchingDims.push_back(i);
4165+
newScatterIndicesBatchingDims.push_back(i);
4166+
}
4167+
for (auto dim : dimNumbers.getInputBatchingDims()) {
4168+
newInputBatchingDims.push_back(dim + nBatch);
4169+
}
4170+
for (auto dim : dimNumbers.getScatterIndicesBatchingDims()) {
4171+
newScatterIndicesBatchingDims.push_back(dim + nBatch);
4172+
}
4173+
4174+
SmallVector<int64_t> newScatterDimsToOperandDims;
4175+
for (auto dim : dimNumbers.getScatterDimsToOperandDims()) {
4176+
newScatterDimsToOperandDims.push_back(dim + nBatch);
4177+
}
4178+
4179+
auto newIndexVectorDim = dimNumbers.getIndexVectorDim() + nBatch;
4180+
4181+
auto newDimNumbers = stablehlo::ScatterDimensionNumbersAttr::get(
4182+
builder.getContext(), newUpdateWindowDims, newInsertedWindowDims,
4183+
newInputBatchingDims, newScatterIndicesBatchingDims,
4184+
newScatterDimsToOperandDims, newIndexVectorDim);
4185+
4186+
auto newScatterOp = stablehlo::ScatterOp::create(
4187+
builder, op.getLoc(), newInputs, newScatterIndices, newUpdates,
4188+
newDimNumbers, op.getIndicesAreSortedAttr(), op.getUniqueIndicesAttr());
4189+
4190+
IRMapping regionMapper;
4191+
op.getUpdateComputation().cloneInto(&newScatterOp.getUpdateComputation(),
4192+
regionMapper);
4193+
4194+
for (int i = 0; i < op.getNumResults(); ++i) {
4195+
mapper.map(op.getResult(i), newScatterOp.getResult(i));
4196+
}
4197+
4198+
return success();
4199+
}
4200+
};
4201+
41274202
struct StablehloAddSimplifyMathInterface
41284203
: public MathSimplifyInterface::ExternalModel<
41294204
StablehloAddSimplifyMathInterface, stablehlo::AddOp> {
@@ -4239,8 +4314,7 @@ void mlir::enzyme::registerStableHLODialectAutoDiffInterface(
42394314
ConvolutionOp::attachInterface<SHLOConvolutionOpBatchInterface>(*context);
42404315
PadOp::attachInterface<SHLOPadOpBatchInterface>(*context);
42414316

4242-
ScatterOp::attachInterface<SHLOGenericBatchOpInterface<ScatterOp>>(
4243-
*context); // TODO: simpler version with newly named dims
4317+
ScatterOp::attachInterface<SHLOScatterOpBatchInterface>(*context);
42444318

42454319
AddOp::attachInterface<StablehloAddSimplifyMathInterface>(*context);
42464320
SubtractOp::attachInterface<StablehloSubSimplifyMathInterface>(*context);

src/enzyme_ad/jax/Passes/AutoBatching.cpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -938,9 +938,8 @@ LogicalResult GreedyWhileLoopBatchFission::matchAndRewriteImpl(
938938
.Case<stablehlo::ReshapeOp, stablehlo::SliceOp, stablehlo::ReturnOp,
939939
// avoid ops that use SHLOGenericBatchOpInterface since that
940940
// lowers to loop
941-
stablehlo::ScatterOp, stablehlo::IfOp, stablehlo::CaseOp,
942-
stablehlo::WhileOp, stablehlo::CustomCallOp>(
943-
[](auto op) { return true; })
941+
stablehlo::IfOp, stablehlo::CaseOp, stablehlo::WhileOp,
942+
stablehlo::CustomCallOp>([](auto op) { return true; })
944943
.Case<stablehlo::BroadcastInDimOp, stablehlo::TransposeOp>(
945944
[](auto op) { return stablehlo::OpIsReshapeLike(op); })
946945
.Default([](auto op) { return false; });
@@ -2327,6 +2326,7 @@ void populateAutoBatchingPassPatterns(RewritePatternSet &patterns,
23272326
SliceToBatch<stablehlo::GetDimensionSizeOp>,
23282327
SliceToBatch<stablehlo::ReverseOp>,
23292328
SliceToBatch<stablehlo::ConvolutionOp>,
2329+
SliceToBatch<stablehlo::ScatterOp>,
23302330
SliceToBatchWithReshapeLikeCheck<stablehlo::BroadcastInDimOp>,
23312331
SliceToBatchWithReshapeLikeCheck<stablehlo::TransposeOp>,
23322332
SliceToBatchElementwise>(ctx);
@@ -2338,8 +2338,7 @@ void populateAutoBatchingPassPatterns(RewritePatternSet &patterns,
23382338
ConcatInsertDimToBatch<stablehlo::IotaOp>,
23392339
ConcatInsertDimToBatchReduceLike<stablehlo::ReduceOp>,
23402340
ConcatInsertDimToBatchReduceLike<stablehlo::ReduceWindowOp>,
2341-
// ConcatInsertDimToBatch<stablehlo::ScatterOp>, after batch
2342-
// op interface is implemented
2341+
ConcatInsertDimToBatch<stablehlo::ScatterOp>,
23432342
ConcatInsertDimToBatch<stablehlo::SortOp>,
23442343
ConcatInsertDimToBatch<stablehlo::ConcatenateOp>,
23452344
ConcatInsertDimToBatch<stablehlo::GetDimensionSizeOp>,

src/enzyme_ad/jax/TransformOps/TransformOps.td

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2744,6 +2744,11 @@ def ApplyConcatInsertDimConvolutionPatterns : EnzymeHLOPatternOp<
27442744
let patterns = ["ConcatInsertDimToBatch<stablehlo::ConvolutionOp>"];
27452745
}
27462746

2747+
// Declares the `concat_insert_dim_scatter` EnzymeHLOPatternOp; its pattern list
// contains `ConcatInsertDimToBatch<stablehlo::ScatterOp>`.
def ApplyConcatInsertDimScatterPatterns : EnzymeHLOPatternOp<
2748+
"concat_insert_dim_scatter"> {
2749+
let patterns = ["ConcatInsertDimToBatch<stablehlo::ScatterOp>"];
2750+
}
2751+
27472752
def ApplyConcatInsertDimElementwisePatterns : EnzymeHLOPatternOp<
27482753
"concat_insert_dim_elementwise"> {
27492754
let patterns = ["ConcatInsertDimElementwiseToBatch"];
@@ -2769,6 +2774,10 @@ def ApplySortSliceToBatchPatterns : EnzymeHLOPatternOp<
27692774
"sort_slice_to_batch"> {
27702775
let patterns = ["SliceToBatch<stablehlo::SortOp>"];
27712776
}
2777+
// Declares the `scatter_slice_to_batch` EnzymeHLOPatternOp; its pattern list
// contains `SliceToBatch<stablehlo::ScatterOp>`.
def ApplyScatterSliceToBatchPatterns : EnzymeHLOPatternOp<
2778+
"scatter_slice_to_batch"> {
2779+
let patterns = ["SliceToBatch<stablehlo::ScatterOp>"];
2780+
}
27722781
def ApplyTransposeSliceToBatchPatterns : EnzymeHLOPatternOp<
27732782
"transpose_slice_to_batch"> {
27742783
let patterns = ["SliceToBatchWithReshapeLikeCheck<stablehlo::TransposeOp>"];

src/enzyme_ad/jax/primitives.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -395,7 +395,7 @@ def optimization_passes(
395395
"sub_const_prop",
396396
"xor_const_prop",
397397
# other constant propagations
398-
"const_prop_through_barrier<16>",
398+
# "const_prop_through_barrier<16>",
399399
f"concat_const_prop<1>({max_constant_threshold})",
400400
f"dynamic_update_slice_const_prop({max_constant_threshold})",
401401
"clamp_const_prop",
@@ -418,6 +418,7 @@ def optimization_passes(
418418
"reducewindow_slice_to_batch",
419419
"elementwise_slice_to_batch",
420420
"convolution_slice_to_batch",
421+
"scatter_slice_to_batch",
421422
]
422423

423424
if enable_concat_to_batch_passes:
@@ -430,6 +431,7 @@ def optimization_passes(
430431
"concat_insert_dim_reduce_window",
431432
"concat_insert_dim_elementwise",
432433
"concat_insert_dim_convolution",
434+
"concat_insert_dim_scatter",
433435
]
434436

435437
if enable_reduce_slice_fusion_passes:
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
// RUN: enzymexlamlir-opt %s --enzyme-batch | FileCheck %s
2+
3+
// Test input: @unbatched_scatter scatters two scalar updates into an
// 8-element tensor with an add update computation; @main applies it through
// `enzyme.batch` with a batch shape of [4].
module {
4+
func.func private @unbatched_scatter(%arg0: tensor<8xf32>, %arg1: tensor<2x1xi32>, %arg2: tensor<2xf32>) -> tensor<8xf32> {
5+
%0 = "stablehlo.scatter"(%arg0, %arg1, %arg2) <{
6+
indices_are_sorted = false,
7+
scatter_dimension_numbers = #stablehlo.scatter<
8+
update_window_dims = [],
9+
inserted_window_dims = [0],
10+
scatter_dims_to_operand_dims = [0],
11+
index_vector_dim = 1
12+
>,
13+
unique_indices = false
14+
}> ({
15+
^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
16+
%1 = stablehlo.add %arg3, %arg4 : tensor<f32>
17+
stablehlo.return %1 : tensor<f32>
18+
}) : (tensor<8xf32>, tensor<2x1xi32>, tensor<2xf32>) -> tensor<8xf32>
19+
return %0 : tensor<8xf32>
20+
}
21+
22+
// Every operand and the result carry a leading dimension of size 4, matching
// `batch_shape = array<i64: 4>`.
func.func @main(%arg0: tensor<4x8xf32>, %arg1: tensor<4x2x1xi32>, %arg2: tensor<4x2xf32>) -> tensor<4x8xf32> {
23+
%0 = enzyme.batch @unbatched_scatter(%arg0, %arg1, %arg2) {batch_shape = array<i64: 4>} : (tensor<4x8xf32>, tensor<4x2x1xi32>, tensor<4x2xf32>) -> tensor<4x8xf32>
24+
return %0 : tensor<4x8xf32>
25+
}
26+
}
27+
28+
// CHECK: func.func private @batched_unbatched_scatter(%arg0: tensor<4x8xf32>, %arg1: tensor<4x2x1xi32>, %arg2: tensor<4x2xf32>) -> tensor<4x8xf32> {
29+
// CHECK-NEXT: %0 = "stablehlo.scatter"(%arg0, %arg1, %arg2) <{
30+
// CHECK-SAME: indices_are_sorted = false,
31+
// CHECK-SAME: scatter_dimension_numbers = #stablehlo.scatter<
32+
// CHECK-SAME: inserted_window_dims = [1],
33+
// CHECK-SAME: input_batching_dims = [0],
34+
// CHECK-SAME: scatter_indices_batching_dims = [0],
35+
// CHECK-SAME: scatter_dims_to_operand_dims = [1],
36+
// CHECK-SAME: index_vector_dim = 2
37+
// CHECK-SAME: >,
38+
// CHECK-SAME: unique_indices = false
39+
// CHECK-SAME: }> ({
40+
// CHECK-NEXT: ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
41+
// CHECK-NEXT: %1 = stablehlo.add %arg3, %arg4 : tensor<f32>
42+
// CHECK-NEXT: stablehlo.return %1 : tensor<f32>
43+
// CHECK-NEXT: }) : (tensor<4x8xf32>, tensor<4x2x1xi32>, tensor<4x2xf32>) -> tensor<4x8xf32>
44+
// CHECK-NEXT: return %0 : tensor<4x8xf32>
45+
// CHECK-NEXT: }

0 commit comments

Comments
 (0)