Skip to content

Commit 00f0ff5

Browse files
committed
feat: gather of scatter simplify
1 parent 99d2b63 commit 00f0ff5

File tree

4 files changed

+133
-1
lines changed

4 files changed

+133
-1
lines changed

src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28764,6 +28764,53 @@ struct FuseReshapeCollapseOrExpandDimsIntoReduce final
2876428764
}
2876528765
};
2876628766

28767+
struct GatherOfScatterSimplify final
28768+
: CheckedOpRewritePattern<stablehlo::GatherOp, GatherOfScatterSimplify> {
28769+
using CheckedOpRewritePattern::CheckedOpRewritePattern;
28770+
28771+
LogicalResult matchAndRewriteImpl(stablehlo::GatherOp gatherOp,
28772+
PatternRewriter &rewriter) {
28773+
auto input = gatherOp.getOperand();
28774+
auto scatterOp = input.getDefiningOp<stablehlo::ScatterOp>();
28775+
28776+
if (!scatterOp ||
28777+
scatterOp.getScatterIndices() != gatherOp.getStartIndices() ||
28778+
computeGatherSliceSizes(scatterOp) != gatherOp.getSliceSizes() ||
28779+
getGatherDims(scatterOp->getContext(),
28780+
scatterOp.getScatterDimensionNumbersAttr()) !=
28781+
gatherOp.getDimensionNumbersAttr()) {
28782+
return failure();
28783+
}
28784+
28785+
auto opResult = cast<OpResult>(input);
28786+
auto opNum = opResult.getResultNumber();
28787+
28788+
SplatElementsAttr constSetIndexValue;
28789+
if (!detectConstantSetindexScatterOp(
28790+
scatterOp, true, [](auto input) { return true; },
28791+
constSetIndexValue)
28792+
.ok()) {
28793+
return failure();
28794+
}
28795+
28796+
if (constSetIndexValue) {
28797+
auto constResult = stablehlo::ConstantOp::create(
28798+
rewriter, gatherOp.getLoc(),
28799+
constSetIndexValue.resizeSplat(cast<ShapedType>(gatherOp.getType())));
28800+
rewriter.replaceOp(gatherOp, constResult);
28801+
return success();
28802+
}
28803+
28804+
if (!scatterOp.getUniqueIndices()) {
28805+
return failure();
28806+
}
28807+
28808+
auto newResult = scatterOp.getUpdates()[opNum];
28809+
rewriter.replaceOp(gatherOp, newResult);
28810+
return success();
28811+
}
28812+
};
28813+
2876728814
/////////////// End Imported from stablehlo
2876828815

2876928816
// clang-format off
@@ -29477,7 +29524,8 @@ struct EnzymeHLOOptPass
2947729524
DeleteDimsReduce,
2947829525
ReduceDeleteDims,
2947929526
DotGeneralInsertDimContractionSimplification,
29480-
FuseReshapeCollapseOrExpandDimsIntoReduce
29527+
FuseReshapeCollapseOrExpandDimsIntoReduce,
29528+
GatherOfScatterSimplify
2948129529
>(context);
2948229530

2948329531
patterns.add<ReshapeElementwise>(true, true, context);

src/enzyme_ad/jax/TransformOps/TransformOps.td

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2721,3 +2721,8 @@ def ApplyWhileElementwiseReductionToReducePatterns : EnzymeHLOPatternOp<
27212721
"while_elementwise_reduction_to_reduce"> {
27222722
let patterns = ["WhileElementwiseReductionToReduce"];
27232723
}
2724+
2725+
// Pattern op with mnemonic "gather_of_scatter_simplify"; its `patterns` list
// names the single C++ rewrite pattern `GatherOfScatterSimplify`.
def ApplyGatherOfScatterSimplifyPatterns : EnzymeHLOPatternOp<
2726+
"gather_of_scatter_simplify"> {
2727+
let patterns = ["GatherOfScatterSimplify"];
2728+
}

src/enzyme_ad/jax/primitives.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -472,6 +472,7 @@ def optimization_passes(
472472
f"scatter_const_fold({max_constant_threshold})",
473473
"cse_gather",
474474
"cse_scatter",
475+
"gather_of_scatter_simplify",
475476
]
476477

477478
if enable_pad_optimization_passes:

test/lit_tests/gather_scatter.mlir

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
// RUN: enzymexlamlir-opt %s --enzyme-hlo-opt | FileCheck %s
2+
3+
// Case 1: the scatter's update region ignores both block arguments and
// returns the constant %cst_1, and the gather reads the scatter result back
// through the same index value (%1). The CHECK lines that follow expect the
// gather to disappear and a constant to be returned in its place, while the
// scatter itself stays (its result is still used by the second transpose).
module {
4+
func.func @main(%arg0: tensor<7x6xf64>) -> (tensor<4x3xf64>, tensor<7x6xf64>) {
5+
%c = stablehlo.constant dense<1> : tensor<3x1xi64>
6+
%cst = stablehlo.constant dense<1.000000e+00> : tensor<3x4xf64>
7+
%cst_1 = stablehlo.constant dense<1.000000e+00> : tensor<f64>
8+
%c_0 = stablehlo.constant dense<[[1], [3], [2]]> : tensor<3x1xi64>
9+
%0 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<7x6xf64>) -> tensor<6x7xf64>
10+
%1 = stablehlo.subtract %c_0, %c : tensor<3x1xi64>
11+
%2 = "stablehlo.scatter"(%0, %1, %cst) <{scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [1], inserted_window_dims = [0], scatter_dims_to_operand_dims = [0], index_vector_dim = 1>}> ({
12+
^bb0(%arg1: tensor<f64>, %arg2: tensor<f64>):
13+
stablehlo.return %cst_1 : tensor<f64>
14+
}) : (tensor<6x7xf64>, tensor<3x1xi64>, tensor<3x4xf64>) -> tensor<6x7xf64>
15+
%3 = "stablehlo.gather"(%2, %1) <{dimension_numbers = #stablehlo.gather<offset_dims = [1], collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 1>, indices_are_sorted = false, slice_sizes = array<i64: 1, 4>}> : (tensor<6x7xf64>, tensor<3x1xi64>) -> tensor<3x4xf64>
16+
%4 = stablehlo.transpose %3, dims = [1, 0] : (tensor<3x4xf64>) -> tensor<4x3xf64>
17+
%5 = stablehlo.transpose %2, dims = [1, 0] : (tensor<6x7xf64>) -> tensor<7x6xf64>
18+
return %4, %5 : tensor<4x3xf64>, tensor<7x6xf64>
19+
}
20+
}
21+
22+
// CHECK: func.func @main
23+
// CHECK: %[[SCATTER:.*]] = "stablehlo.scatter"(%0, %c, %cst_0)
24+
// CHECK-NOT: "stablehlo.gather"
25+
// CHECK: %[[TS:.*]] = stablehlo.transpose %[[SCATTER]]
26+
// CHECK: return %cst, %[[TS]] : tensor<4x3xf64>, tensor<7x6xf64>
27+
28+
// Case 2: the scatter's update region returns the incoming update (%arg3),
// the scatter is marked `unique_indices = true`, and the gather reads the
// result back through the same index value (%2). The CHECK lines that follow
// expect the gather to disappear and %arg1 to be returned directly.
module {
29+
func.func @main(%arg0: tensor<7x6xf64>, %arg1: tensor<4x3xf64>) -> (tensor<4x3xf64>, tensor<7x6xf64>) {
30+
%c = stablehlo.constant dense<1> : tensor<3x1xi64>
31+
%c_0 = stablehlo.constant dense<[[1], [3], [2]]> : tensor<3x1xi64>
32+
%0 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<7x6xf64>) -> tensor<6x7xf64>
33+
%1 = stablehlo.transpose %arg1, dims = [1, 0] : (tensor<4x3xf64>) -> tensor<3x4xf64>
34+
%2 = stablehlo.subtract %c_0, %c : tensor<3x1xi64>
35+
%3 = "stablehlo.scatter"(%0, %2, %1) <{scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [1], inserted_window_dims = [0], scatter_dims_to_operand_dims = [0], index_vector_dim = 1>, unique_indices = true}> ({
36+
^bb0(%arg2: tensor<f64>, %arg3: tensor<f64>):
37+
stablehlo.return %arg3 : tensor<f64>
38+
}) : (tensor<6x7xf64>, tensor<3x1xi64>, tensor<3x4xf64>) -> tensor<6x7xf64>
39+
%4 = "stablehlo.gather"(%3, %2) <{dimension_numbers = #stablehlo.gather<offset_dims = [1], collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 1>, indices_are_sorted = false, slice_sizes = array<i64: 1, 4>}> : (tensor<6x7xf64>, tensor<3x1xi64>) -> tensor<3x4xf64>
40+
%5 = stablehlo.transpose %4, dims = [1, 0] : (tensor<3x4xf64>) -> tensor<4x3xf64>
41+
%6 = stablehlo.transpose %3, dims = [1, 0] : (tensor<6x7xf64>) -> tensor<7x6xf64>
42+
return %5, %6 : tensor<4x3xf64>, tensor<7x6xf64>
43+
}
44+
}
45+
46+
// CHECK: func.func @main
47+
// CHECK: %[[SCATTER:.*]] = "stablehlo.scatter"(%0, %c, %1)
48+
// CHECK-NOT: "stablehlo.gather"
49+
// CHECK: %[[TS:.*]] = stablehlo.transpose %[[SCATTER]]
50+
// CHECK: return %arg1, %[[TS]] : tensor<4x3xf64>, tensor<7x6xf64>
51+
52+
// Case 3: element-wise variant. The index tensor is tensor<12x2xi64> (two
// index components per update), both operand dimensions are inserted on the
// scatter side and collapsed on the gather side, and the scatter is marked
// `unique_indices = true` with an update region that returns the incoming
// update (%arg5). The CHECK lines that follow expect the gather to disappear
// and %arg1 to be returned directly.
module {
53+
func.func @main(%arg0: tensor<7x6xf64>, %arg1: tensor<4x3xf64>, %arg2: tensor<3xi64>, %arg3: tensor<4xi64>) -> (tensor<4x3xf64>, tensor<7x6xf64>) {
54+
%c = stablehlo.constant dense<1> : tensor<12x2xi64>
55+
%0 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<7x6xf64>) -> tensor<6x7xf64>
56+
%1 = stablehlo.broadcast_in_dim %arg2, dims = [1] : (tensor<3xi64>) -> tensor<4x3xi64>
57+
%2 = stablehlo.broadcast_in_dim %arg3, dims = [0] : (tensor<4xi64>) -> tensor<4x3xi64>
58+
%3 = stablehlo.reshape %2 : (tensor<4x3xi64>) -> tensor<12x1xi64>
59+
%4 = stablehlo.reshape %1 : (tensor<4x3xi64>) -> tensor<12x1xi64>
60+
%5 = stablehlo.concatenate %4, %3, dim = 1 : (tensor<12x1xi64>, tensor<12x1xi64>) -> tensor<12x2xi64>
61+
%6 = stablehlo.reshape %arg1 : (tensor<4x3xf64>) -> tensor<12xf64>
62+
%7 = stablehlo.subtract %5, %c : tensor<12x2xi64>
63+
%8 = "stablehlo.scatter"(%0, %7, %6) <{scatter_dimension_numbers = #stablehlo.scatter<inserted_window_dims = [0, 1], scatter_dims_to_operand_dims = [0, 1], index_vector_dim = 1>, unique_indices = true}> ({
64+
^bb0(%arg4: tensor<f64>, %arg5: tensor<f64>):
65+
stablehlo.return %arg5 : tensor<f64>
66+
}) : (tensor<6x7xf64>, tensor<12x2xi64>, tensor<12xf64>) -> tensor<6x7xf64>
67+
%9 = "stablehlo.gather"(%8, %7) <{dimension_numbers = #stablehlo.gather<collapsed_slice_dims = [0, 1], start_index_map = [0, 1], index_vector_dim = 1>, indices_are_sorted = false, slice_sizes = array<i64: 1, 1>}> : (tensor<6x7xf64>, tensor<12x2xi64>) -> tensor<12xf64>
68+
%10 = stablehlo.reshape %9 : (tensor<12xf64>) -> tensor<4x3xf64>
69+
%11 = stablehlo.transpose %8, dims = [1, 0] : (tensor<6x7xf64>) -> tensor<7x6xf64>
70+
return %10, %11 : tensor<4x3xf64>, tensor<7x6xf64>
71+
}
72+
}
73+
74+
// CHECK: func.func @main
75+
// CHECK: %[[SCATTER:.*]] = "stablehlo.scatter"(%0, %7, %6)
76+
// CHECK-NOT: "stablehlo.gather"
77+
// CHECK: %[[TS:.*]] = stablehlo.transpose %[[SCATTER]]
78+
// CHECK: return %arg1, %[[TS]] : tensor<4x3xf64>, tensor<7x6xf64>

0 commit comments

Comments
 (0)