Skip to content

Commit 9fda595

Browse files
authored
feat: gather elementwise simplify (#1038)
1 parent ee040a2 commit 9fda595

File tree

3 files changed

+78
-1
lines changed

3 files changed

+78
-1
lines changed

src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19683,6 +19683,40 @@ struct UnaryElementwiseScatterSimplify final
1968319683
}
1968419684
};
1968519685

19686+
struct GatherElementwise
19687+
: public CheckedOpRewritePattern<stablehlo::GatherOp, GatherElementwise> {
19688+
using CheckedOpRewritePattern<stablehlo::GatherOp,
19689+
GatherElementwise>::CheckedOpRewritePattern;
19690+
19691+
LogicalResult matchAndRewriteImpl(stablehlo::GatherOp op,
19692+
PatternRewriter &rewriter) const {
19693+
auto gatherInput = op.getOperand();
19694+
auto defOp = gatherInput.getDefiningOp();
19695+
if (!defOp || !defOp->hasTrait<mlir::OpTrait::Elementwise>())
19696+
return rewriter.notifyMatchFailure(op,
19697+
"GatherOp with non-elementwise input");
19698+
19699+
if (!isOnlyUsedInOperation(defOp, op))
19700+
return failure();
19701+
19702+
SmallVector<Value> newElementwiseInputs;
19703+
for (auto input : defOp->getOperands()) {
19704+
auto neeGatherOp = rewriter.create<stablehlo::GatherOp>(
19705+
op.getLoc(), input, op.getStartIndices(),
19706+
op.getDimensionNumbersAttr(), op.getSliceSizesAttr(),
19707+
op.getIndicesAreSortedAttr());
19708+
newElementwiseInputs.push_back(neeGatherOp->getResult(0));
19709+
}
19710+
19711+
auto newElemOp = rewriter.create(
19712+
op.getLoc(), defOp->getName().getIdentifier(),
19713+
ValueRange(newElementwiseInputs), TypeRange{op.getResult().getType()},
19714+
defOp->getAttrs(), {}, {});
19715+
rewriter.replaceOp(op, newElemOp->getResult(0));
19716+
return success();
19717+
}
19718+
};
19719+
1968619720
/////////////// End Imported from stablehlo
1968719721

1968819722
// clang-format off
@@ -20205,7 +20239,8 @@ struct EnzymeHLOOptPass
2020520239
ConjComplexSimplify,
2020620240
SplitConvolutionIntoReverseConvolution,
2020720241
ScatterMultiplySimplify,
20208-
UnaryElementwiseScatterSimplify
20242+
UnaryElementwiseScatterSimplify,
20243+
GatherElementwise
2020920244
>(context);
2021020245

2021120246
patterns.add<SumToReduceWindow<stablehlo::AddOp>,

src/enzyme_ad/jax/TransformOps/TransformOps.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2033,3 +2033,7 @@ def ApplyUnaryElementwiseScatterSimplify : EnzymeHLOPatternOp<"unary_elementwise
20332033
"UnaryElementwiseScatterSimplify"
20342034
];
20352035
}
2036+
2037+
// Pattern op with mnemonic `gather_elementwise`; its `patterns` list names
// the C++ rewrite pattern `GatherElementwise`.
def ApplyGatherElementwise : EnzymeHLOPatternOp<"gather_elementwise"> {
2038+
let patterns = ["GatherElementwise"];
2039+
}
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
// RUN: enzymexlamlir-opt --enzyme-hlo-opt %s | FileCheck %s
2+
3+
// Binary case: the gather input is the product of two function arguments.
// The expected output below gathers %arg0 and %arg1 separately with the same
// indices and multiplies the two gathered 32x5 tensors.
func.func @gather_elementwise1(%arg0: tensor<32x1024xf32>, %arg1: tensor<32x1024xf32>, %arg2: tensor<5xi64>) -> tensor<32x5xf32> {
4+
%c = stablehlo.constant dense<1> : tensor<1x5xi64>
5+
%0 = stablehlo.multiply %arg0, %arg1 : tensor<32x1024xf32>
6+
%1 = stablehlo.reshape %arg2 : (tensor<5xi64>) -> tensor<1x5xi64>
7+
%2 = stablehlo.subtract %1, %c : tensor<1x5xi64>
8+
%3 = "stablehlo.gather"(%0, %2) <{dimension_numbers = #stablehlo.gather<offset_dims = [0], collapsed_slice_dims = [1], start_index_map = [1]>, indices_are_sorted = false, slice_sizes = array<i64: 32, 1>}> : (tensor<32x1024xf32>, tensor<1x5xi64>) -> tensor<32x5xf32>
9+
return %3 : tensor<32x5xf32>
10+
}
11+
12+
// CHECK: func.func @gather_elementwise1(%arg0: tensor<32x1024xf32>, %arg1: tensor<32x1024xf32>, %arg2: tensor<5xi64>) -> tensor<32x5xf32> {
13+
// CHECK-NEXT: %c = stablehlo.constant dense<1> : tensor<1x5xi64>
14+
// CHECK-NEXT: %0 = stablehlo.reshape %arg2 : (tensor<5xi64>) -> tensor<1x5xi64>
15+
// CHECK-NEXT: %1 = stablehlo.subtract %0, %c : tensor<1x5xi64>
16+
// CHECK-NEXT: %2 = "stablehlo.gather"(%arg0, %1) <{dimension_numbers = #stablehlo.gather<offset_dims = [0], collapsed_slice_dims = [1], start_index_map = [1]>, indices_are_sorted = false, slice_sizes = array<i64: 32, 1>}> : (tensor<32x1024xf32>, tensor<1x5xi64>) -> tensor<32x5xf32>
17+
// CHECK-NEXT: %3 = "stablehlo.gather"(%arg1, %1) <{dimension_numbers = #stablehlo.gather<offset_dims = [0], collapsed_slice_dims = [1], start_index_map = [1]>, indices_are_sorted = false, slice_sizes = array<i64: 32, 1>}> : (tensor<32x1024xf32>, tensor<1x5xi64>) -> tensor<32x5xf32>
18+
// CHECK-NEXT: %4 = stablehlo.multiply %2, %3 : tensor<32x5xf32>
19+
// CHECK-NEXT: return %4 : tensor<32x5xf32>
20+
// CHECK-NEXT: }
21+
22+
// Unary, type-changing case: the gather input is an f32 -> f16 convert.
// The expected output below gathers the f32 argument first and converts the
// gathered 32x5 tensor to f16 afterwards.
func.func @gather_elementwise2(%arg0: tensor<32x1024xf32>, %arg1: tensor<5xi64>) -> tensor<32x5xf16> {
23+
%c = stablehlo.constant dense<1> : tensor<5xi64>
24+
%0 = stablehlo.convert %arg0 : (tensor<32x1024xf32>) -> tensor<32x1024xf16>
25+
%1 = stablehlo.subtract %arg1, %c : tensor<5xi64>
26+
%2 = stablehlo.reshape %1 : (tensor<5xi64>) -> tensor<1x5xi64>
27+
%3 = "stablehlo.gather"(%0, %2) <{dimension_numbers = #stablehlo.gather<offset_dims = [0], collapsed_slice_dims = [1], start_index_map = [1]>, indices_are_sorted = false, slice_sizes = array<i64: 32, 1>}> : (tensor<32x1024xf16>, tensor<1x5xi64>) -> tensor<32x5xf16>
28+
return %3 : tensor<32x5xf16>
29+
}
30+
31+
// CHECK: func.func @gather_elementwise2(%arg0: tensor<32x1024xf32>, %arg1: tensor<5xi64>) -> tensor<32x5xf16> {
32+
// CHECK-NEXT: %c = stablehlo.constant dense<1> : tensor<5xi64>
33+
// CHECK-NEXT: %0 = stablehlo.subtract %arg1, %c : tensor<5xi64>
34+
// CHECK-NEXT: %1 = stablehlo.reshape %0 : (tensor<5xi64>) -> tensor<1x5xi64>
35+
// CHECK-NEXT: %2 = "stablehlo.gather"(%arg0, %1) <{dimension_numbers = #stablehlo.gather<offset_dims = [0], collapsed_slice_dims = [1], start_index_map = [1]>, indices_are_sorted = false, slice_sizes = array<i64: 32, 1>}> : (tensor<32x1024xf32>, tensor<1x5xi64>) -> tensor<32x5xf32>
36+
// CHECK-NEXT: %3 = stablehlo.convert %2 : (tensor<32x5xf32>) -> tensor<32x5xf16>
37+
// CHECK-NEXT: return %3 : tensor<32x5xf16>
38+
// CHECK-NEXT: }

0 commit comments

Comments
 (0)