
Commit 32c0646

feat: transpose of scatter (#1042)
* feat: transpose of scatter
* chore: remove redundant gather_simplify
* chore: run fmt
* fix: window dims
* chore: run fmt
* fix: remove gather_simplify
1 parent 1631a64 commit 32c0646
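In short, the new `TransposeScatter` pattern sinks a `stablehlo.transpose` through a single-input `stablehlo.scatter`: the transpose is applied to the scatter input instead, and the scatter dimension numbers are remapped through the inverse permutation. A minimal before/after sketch, adapted from the lit test added in this commit (the value names `%x`, `%c`, and `%upd` are illustrative, not from the diff):

```mlir
// Before: scatter on a tensor<2x5xf32> operand, then transpose of the result.
%2 = "stablehlo.scatter"(%x, %c, %upd) <{scatter_dimension_numbers =
    #stablehlo.scatter<update_window_dims = [0], inserted_window_dims = [1],
                       scatter_dims_to_operand_dims = [1]>}> ({
^bb0(%a: tensor<f32>, %b: tensor<f32>):
  %m = stablehlo.multiply %a, %b : tensor<f32>
  stablehlo.return %m : tensor<f32>
}) : (tensor<2x5xf32>, tensor<1x3x4xi64>, tensor<2x3x4xf32>) -> tensor<2x5xf32>
%3 = stablehlo.transpose %2, dims = [1, 0] : (tensor<2x5xf32>) -> tensor<5x2xf32>

// After: only the scatter input is transposed; the dimension numbers are
// remapped through the inverse permutation ([1, 0] is its own inverse here).
%xt = stablehlo.transpose %x, dims = [1, 0] : (tensor<2x5xf32>) -> tensor<5x2xf32>
%3 = "stablehlo.scatter"(%xt, %c, %upd) <{scatter_dimension_numbers =
    #stablehlo.scatter<update_window_dims = [0], inserted_window_dims = [0],
                       scatter_dims_to_operand_dims = [0]>}> ({
^bb0(%a: tensor<f32>, %b: tensor<f32>):
  %m = stablehlo.multiply %a, %b : tensor<f32>
  stablehlo.return %m : tensor<f32>
}) : (tensor<5x2xf32>, tensor<1x3x4xi64>, tensor<2x3x4xf32>) -> tensor<5x2xf32>
```

When the scatter input is itself a transpose (as in the test), the new input transpose folds away entirely.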

File tree: 7 files changed, +130 −83 lines changed

src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp

Lines changed: 81 additions & 53 deletions
@@ -9477,43 +9477,6 @@ bool is_iota(ArrayRef<int64_t> idx) {
   return true;
 }
 
-/// Converts gather ops to slice ops in case we have a single set of constant
-/// indices.
-struct GatherSimplify final
-    : CheckedOpRewritePattern<stablehlo::GatherOp, GatherSimplify> {
-  using CheckedOpRewritePattern::CheckedOpRewritePattern;
-
-  LogicalResult matchAndRewriteImpl(stablehlo::GatherOp op,
-                                    PatternRewriter &rewriter) const {
-    DenseIntElementsAttr startIndicesCst;
-    if (!matchPattern(op.getStartIndices(), m_Constant(&startIndicesCst)))
-      return failure();
-
-    {
-      DenseIntElementsAttr operandVals;
-      if (matchPattern(op.getOperand(), m_Constant(&operandVals))) {
-        auto out = stablehlo::gatherOp(
-            stablehlo::constantOp(operandVals),
-            stablehlo::constantOp(startIndicesCst),
-            stablehlo::Axes(op.getDimensionNumbers().getOffsetDims()),
-            stablehlo::Axes(op.getDimensionNumbers().getCollapsedSliceDims()),
-            stablehlo::Axes(op.getDimensionNumbers().getOperandBatchingDims()),
-            stablehlo::Axes(
-                op.getDimensionNumbers().getStartIndicesBatchingDims()),
-            stablehlo::Axes(op.getDimensionNumbers().getStartIndexMap()),
-            stablehlo::Axis(op.getDimensionNumbers().getIndexVectorDim()),
-            stablehlo::Sizes(op.getSliceSizes()), op.getIndicesAreSorted(),
-            op.getType());
-
-        rewriter.replaceOpWithNewOp<stablehlo::ConstantOp>(op, op.getType(),
-                                                           fromTensor(out));
-        return success();
-      }
-    }
-    return failure();
-  }
-};
-
 struct CSEIota : CheckedOpRewritePattern<stablehlo::IotaOp, CSEIota> {
   using CheckedOpRewritePattern::CheckedOpRewritePattern;
 
@@ -19718,6 +19681,72 @@ struct GatherElementwise
   }
 };
 
+SmallVector<int64_t> applyPermutation(ArrayRef<int64_t> dims,
+                                      ArrayRef<int64_t> permutation,
+                                      bool sort = false) {
+  SmallVector<int64_t> newDims(dims.size(), -1);
+  for (int64_t i = 0; i < dims.size(); ++i) {
+    newDims[i] = permutation[dims[i]];
+  }
+
+  if (sort)
+    llvm::sort(newDims);
+
+  return newDims;
+}
+
+struct TransposeScatter
+    : public CheckedOpRewritePattern<stablehlo::TransposeOp, TransposeScatter> {
+  using CheckedOpRewritePattern<stablehlo::TransposeOp,
+                                TransposeScatter>::CheckedOpRewritePattern;
+
+  LogicalResult matchAndRewriteImpl(stablehlo::TransposeOp op,
+                                    PatternRewriter &rewriter) const {
+    auto scatterOp = op.getOperand().getDefiningOp<stablehlo::ScatterOp>();
+    if (!scatterOp)
+      return rewriter.notifyMatchFailure(op,
+                                         "TransposeOp with non-scatter input");
+
+    if (scatterOp.getInputs().size() != 1)
+      return rewriter.notifyMatchFailure(
+          op, "TransposeOp with scatter input with more than 1 operand");
+
+    if (!isOnlyUsedInOperation(scatterOp, op))
+      return failure();
+
+    auto transposedInput = rewriter.create<stablehlo::TransposeOp>(
+        op.getLoc(), scatterOp.getInputs()[0], op.getPermutation());
+
+    auto newScatterOp = rewriter.create<stablehlo::ScatterOp>(
+        op.getLoc(), TypeRange(op.getType()), ValueRange(transposedInput),
+        scatterOp.getScatterIndices(), scatterOp.getUpdates(),
+        transposeScatterDimensionNumbers(
+            scatterOp.getScatterDimensionNumbers(),
+            getInversePermutation(op.getPermutation()), rewriter),
+        scatterOp.getIndicesAreSortedAttr(), scatterOp.getUniqueIndicesAttr());
+    newScatterOp.getUpdateComputation().takeBody(
+        scatterOp.getUpdateComputation());
+    rewriter.replaceOp(op, newScatterOp->getResult(0));
+    return success();
+  }
+
+private:
+  stablehlo::ScatterDimensionNumbersAttr transposeScatterDimensionNumbers(
+      stablehlo::ScatterDimensionNumbersAttr scatterDimNumbers,
+      SmallVector<int64_t> mapping, PatternRewriter &rewriter) const {
+    return stablehlo::ScatterDimensionNumbersAttr::get(
+        rewriter.getContext(), scatterDimNumbers.getUpdateWindowDims(),
+        applyPermutation(scatterDimNumbers.getInsertedWindowDims(), mapping,
+                         true),
+        applyPermutation(scatterDimNumbers.getInputBatchingDims(), mapping,
+                         true),
+        scatterDimNumbers.getScatterIndicesBatchingDims(),
+        applyPermutation(scatterDimNumbers.getScatterDimsToOperandDims(),
+                         mapping, true),
+        scatterDimNumbers.getIndexVectorDim());
+  }
+};
+
 /////////////// End Imported from stablehlo
 
 // clang-format off
@@ -19928,21 +19957,20 @@ struct EnzymeHLOOptPass
     patterns.add<TransposeExtend>(context);
     patterns.add<TransposeRotate>(context);
 
-    patterns
-        .add<AddSimplify, SubSimplify, AndSimplify, MaxSimplify, MinSimplify,
-             OrSimplify, XorSimplify, MulSimplify, DivSimplify, RemSimplify,
-             PowSimplify, NoopSlice, NoopReverse, SliceSlice, PadSimplify,
-             ShiftRightLogicalSimplify, NegativePadToSlice, SliceSimplify,
-             ConvertSimplify, TransposeSimplify, DotGeneralSimplify,
-             DynamicSliceToStatic, DynamicUpdateSliceElim, ReduceToReshape,
-             BroadcastToReshape, GatherSimplify, ReshapeEmptyBroadcast,
-             BroadcastReshape, ConstPropThroughBarrier,
-             ReplaceNegAddWithSubtract, SignAbsSimplify, AbsPositiveSimplify,
-             SimplifyBoundary<enzymexla::ExtendOp>,
-             SimplifyBoundary<enzymexla::WrapOp>,
-             SimplifyBoundary<enzymexla::RotateOp>, TransposeReshapeToBroadcast,
-             ReshapeTransposeToBroadcast, SelectBroadcastInDim>(
-            context, PatternBenefit(65000));
+    patterns.add<
+        AddSimplify, SubSimplify, AndSimplify, MaxSimplify, MinSimplify,
+        OrSimplify, XorSimplify, MulSimplify, DivSimplify, RemSimplify,
+        PowSimplify, NoopSlice, NoopReverse, SliceSlice, PadSimplify,
+        ShiftRightLogicalSimplify, NegativePadToSlice, SliceSimplify,
+        ConvertSimplify, TransposeSimplify, DotGeneralSimplify,
+        DynamicSliceToStatic, DynamicUpdateSliceElim, ReduceToReshape,
+        BroadcastToReshape, ReshapeEmptyBroadcast, BroadcastReshape,
+        ConstPropThroughBarrier, ReplaceNegAddWithSubtract, SignAbsSimplify,
+        AbsPositiveSimplify, SimplifyBoundary<enzymexla::ExtendOp>,
+        SimplifyBoundary<enzymexla::WrapOp>,
+        SimplifyBoundary<enzymexla::RotateOp>, TransposeReshapeToBroadcast,
+        ReshapeTransposeToBroadcast, SelectBroadcastInDim>(
+        context, PatternBenefit(65000));
 
     patterns.add<IotaSimplify, BroadcastInDimSimplify, ConcatConstProp,
                  DynamicUpdateSliceConstProp, PadSimplify>(
@@ -20130,7 +20158,7 @@ struct EnzymeHLOOptPass
                  TransposeIota, TransposeReduceWindow, TransposeReduce,
                  TransposeSelect, TransposeDynamicSlice, TransposeReverse,
                  TransposeBatchNormTraining, TransposeBatchNormInference,
-                 TransposeBatchNormGrad, TransposeIf>(context);
+                 TransposeBatchNormGrad, TransposeIf, TransposeScatter>(context);
     patterns.add<TransposeElementwise>(true, context);
   }
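The subtle piece above is `transposeScatterDimensionNumbers`: every dimension list that refers to operand dimensions (`inserted_window_dims`, `input_batching_dims`, `scatter_dims_to_operand_dims`) is pushed through the inverse permutation via `applyPermutation(..., /*sort=*/true)` (i.e. `newDims[i] = permutation[dims[i]]`, then sorted), while the update- and index-side lists are left untouched. A worked sketch of the effect, taken from the lit test added in this commit, where the permutation and its inverse are both `[1, 0]`:

```mlir
// Scatter dims before the rewrite (operand type tensor<2x5xf32>):
#stablehlo.scatter<update_window_dims = [0], inserted_window_dims = [1],
                   scatter_dims_to_operand_dims = [1]>
// After remapping through the inverse permutation (operand type tensor<5x2xf32>):
#stablehlo.scatter<update_window_dims = [0], inserted_window_dims = [0],
                   scatter_dims_to_operand_dims = [0]>
```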

src/enzyme_ad/jax/TransformOps/TransformOps.td

Lines changed: 4 additions & 4 deletions
@@ -135,10 +135,6 @@ def ApplyBroadcastToReshapePatterns : EnzymeHLOPatternOp<
   "broadcast_to_reshape"> {
   let patterns = ["BroadcastToReshape"];
 }
-def ApplyGatherSimplifyPatterns : EnzymeHLOPatternOp<
-  "gather_simplify"> {
-  let patterns = ["GatherSimplify"];
-}
 def ApplyNotConstProp : EnzymeHLOPatternOp<
   "not_const_prop"> {
   let patterns = ["UnaryConstProp<stablehlo::NotOp,stablehlo::notOp>"];
@@ -2037,3 +2033,7 @@ def ApplyUnaryElementwiseScatterSimplify : EnzymeHLOPatternOp<"unary_elementwise
 def ApplyGatherElementwise : EnzymeHLOPatternOp<"gather_elementwise"> {
   let patterns = ["GatherElementwise"];
 }
+
+def ApplyTransposeScatter : EnzymeHLOPatternOp<"transpose_scatter"> {
+  let patterns = ["TransposeScatter"];
+}

src/enzyme_ad/jax/primitives.py

Lines changed: 2 additions & 1 deletion
@@ -174,7 +174,6 @@ def optimization_passes(
         "concat_to_broadcast<16>",
         "reduce_to_reshape<16>",
         "broadcast_to_reshape<16>",
-        "gather_simplify<16>",
         "slice_internal",
         f"iota_simplify<16>({max_constant_threshold})",
         f"broadcast_in_dim_simplify<16>({max_constant_threshold})",
@@ -331,6 +330,7 @@ def optimization_passes(
         # "concat_to_onedim_dusslice",
         "scatter_multiply_simplify",
         "unary_elementwise_scatter_simplify",
+        "gather_elementwise",
     ]
 
     # constant propagation patterns
@@ -431,6 +431,7 @@ def optimization_passes(
             "transpose_batch_norm_inference",
             "transpose_batch_norm_grad",
             "transpose_if",
+            "transpose_scatter",
         ]
     elif transpose_propagate == "down":
         transform_passes_list += [

test/lit_tests/mulscatter.mlir

Lines changed: 4 additions & 6 deletions
@@ -1,4 +1,4 @@
-// RUN: enzymexlamlir-opt %s --enzyme-hlo-opt | FileCheck %s
+// RUN: enzymexlamlir-opt %s --pass-pipeline='builtin.module(enzyme-hlo-opt{passses=65536},enzyme-hlo-opt)' | FileCheck %s
 
 func.func @main(%arg0: tensor<4xi64>, %arg1: tensor<6xi64>, %arg2: tensor<1024x1024xf32>) -> tensor<1024x1024xf32> {
   %cst = stablehlo.constant dense<2.000000e+00> : tensor<24xf32>
@@ -21,13 +21,11 @@ func.func @main(%arg0: tensor<4xi64>, %arg1: tensor<6xi64>, %arg2: tensor<1024x1
 }
 
 // CHECK: func.func @main(%arg0: tensor<4xi64>, %arg1: tensor<6xi64>, %arg2: tensor<1024x1024xf32>) -> tensor<1024x1024xf32> {
+// CHECK: %[[CST_1:.*]] = stablehlo.constant dense<0.000000e+00> : tensor<1024x1024xf32>
 // CHECK: %[[CST:.*]] = stablehlo.constant dense<2.000000e+00> : tensor<24xf32>
 // CHECK: %[[CST_0:.*]] = stablehlo.constant dense<1> : tensor<24x2xi64>
-// CHECK: %[[CST_1:.*]] = stablehlo.constant dense<0.000000e+00> : tensor<1024x1024xf32>
-// CHECK: %[[arg2_T:.*]] = stablehlo.transpose %arg2, dims = [1, 0] : (tensor<1024x1024xf32>) -> tensor<1024x1024xf32>
-// CHECK: %[[GATHER:.*]] = "stablehlo.gather"(%[[arg2_T]], %[[SCATTER_INDICES:.*]]) <{dimension_numbers = #stablehlo.gather<collapsed_slice_dims = [0, 1], start_index_map = [0, 1], index_vector_dim = 1>, slice_sizes = array<i64: 1, 1>}> : (tensor<1024x1024xf32>, tensor<24x2xi64>) -> tensor<24xf32>
+// CHECK: %[[GATHER:.*]] = "stablehlo.gather"(%arg2, %[[SCATTER_INDICES:.*]]) <{dimension_numbers = #stablehlo.gather<collapsed_slice_dims = [0, 1], start_index_map = [0, 1], index_vector_dim = 1>, slice_sizes = array<i64: 1, 1>}> : (tensor<1024x1024xf32>, tensor<24x2xi64>) -> tensor<24xf32>
 // CHECK: %[[MUL:.*]] = stablehlo.multiply %[[GATHER]], %[[CST]] : tensor<24xf32>
 // CHECK: %[[SCATTER:.*]] = "stablehlo.scatter"(%[[CST_1]], %[[SCATTER_INDICES]], %[[MUL]]) <{scatter_dimension_numbers = #stablehlo.scatter<inserted_window_dims = [0, 1], scatter_dims_to_operand_dims = [0, 1], index_vector_dim = 1>}> ({
-// CHECK: %[[RESULT:.*]] = stablehlo.transpose %[[SCATTER]], dims = [1, 0] : (tensor<1024x1024xf32>) -> tensor<1024x1024xf32>
-// CHECK: return %[[RESULT]] : tensor<1024x1024xf32>
+// CHECK: return %[[SCATTER]] : tensor<1024x1024xf32>
 // CHECK: }
Lines changed: 25 additions & 0 deletions

@@ -0,0 +1,25 @@
+// RUN: enzymexlamlir-opt %s --pass-pipeline='builtin.module(enzyme-hlo-opt{passses=65536},enzyme-hlo-opt)' | FileCheck %s
+
+func.func @main(%arg0: tensor<5x2xf32> {tf.aliasing_output = 0 : i32}, %arg1: tensor<4x3x2xf32>) -> tensor<5x2xf32> {
+  %c = stablehlo.constant dense<[[[0, 1, 2, 3], [3, 1, 0, 2], [2, 4, 4, 2]]]> : tensor<1x3x4xi64>
+  %0 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<5x2xf32>) -> tensor<2x5xf32>
+  %1 = stablehlo.transpose %arg1, dims = [2, 1, 0] : (tensor<4x3x2xf32>) -> tensor<2x3x4xf32>
+  %2 = "stablehlo.scatter"(%0, %c, %1) <{scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1]>}> ({
+  ^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):
+    %4 = stablehlo.multiply %arg2, %arg3 : tensor<f32>
+    stablehlo.return %4 : tensor<f32>
+  }) : (tensor<2x5xf32>, tensor<1x3x4xi64>, tensor<2x3x4xf32>) -> tensor<2x5xf32>
+  %3 = stablehlo.transpose %2, dims = [1, 0] : (tensor<2x5xf32>) -> tensor<5x2xf32>
+  return %3 : tensor<5x2xf32>
+}
+
+// CHECK: func.func @main(%arg0: tensor<5x2xf32> {tf.aliasing_output = 0 : i32}, %arg1: tensor<4x3x2xf32>) -> tensor<5x2xf32> {
+// CHECK-NEXT{LITERAL}: %c = stablehlo.constant dense<[[[0, 1, 2, 3], [3, 1, 0, 2], [2, 4, 4, 2]]]> : tensor<1x3x4xi64>
+// CHECK-NEXT: %0 = stablehlo.transpose %arg1, dims = [2, 1, 0] : (tensor<4x3x2xf32>) -> tensor<2x3x4xf32>
+// CHECK-NEXT: %1 = "stablehlo.scatter"(%arg0, %c, %0) <{scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0], inserted_window_dims = [0], scatter_dims_to_operand_dims = [0]>}> ({
+// CHECK-NEXT: ^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):
+// CHECK-NEXT:   %2 = stablehlo.multiply %arg2, %arg3 : tensor<f32>
+// CHECK-NEXT:   stablehlo.return %2 : tensor<f32>
+// CHECK-NEXT: }) : (tensor<5x2xf32>, tensor<1x3x4xi64>, tensor<2x3x4xf32>) -> tensor<5x2xf32>
+// CHECK-NEXT: return %1 : tensor<5x2xf32>
+// CHECK-NEXT: }

test/lit_tests/unaryscatter.mlir

Lines changed: 14 additions & 18 deletions
@@ -1,4 +1,4 @@
-// RUN: enzymexlamlir-opt %s --enzyme-hlo-opt | FileCheck %s
+// RUN: enzymexlamlir-opt %s --pass-pipeline='builtin.module(enzyme-hlo-opt{passses=65536},enzyme-hlo-opt)' | FileCheck %s
 
 func.func @unaryscatter(%arg0: tensor<4xi64>, %arg1: tensor<6xi64>, %arg2: tensor<1024x1024xf32>) -> tensor<1024x1024xf32> {
   %cst = stablehlo.constant dense<2.000000e+00> : tensor<24xf32>
@@ -33,8 +33,7 @@ func.func @unaryscatter(%arg0: tensor<4xi64>, %arg1: tensor<6xi64>, %arg2: tenso
 // CHECK-NEXT: ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
 // CHECK-NEXT: stablehlo.return %arg4 : tensor<f32>
 // CHECK-NEXT: }) : (tensor<1024x1024xf32>, tensor<24x2xi64>, tensor<24xf32>) -> tensor<1024x1024xf32>
-// CHECK-NEXT: %7 = stablehlo.transpose %6, dims = [1, 0] : (tensor<1024x1024xf32>) -> tensor<1024x1024xf32>
-// CHECK-NEXT: return %7 : tensor<1024x1024xf32>
+// CHECK-NEXT: return %6 : tensor<1024x1024xf32>
 // CHECK-NEXT: }
 
 func.func @expscatter(%arg0: tensor<4xi64>, %arg1: tensor<6xi64>, %arg2: tensor<1024x1024xf32>) -> tensor<1024x1024xf32> {
@@ -70,8 +69,7 @@ func.func @expscatter(%arg0: tensor<4xi64>, %arg1: tensor<6xi64>, %arg2: tensor<
 // CHECK-NEXT: ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
 // CHECK-NEXT: stablehlo.return %arg4 : tensor<f32>
 // CHECK-NEXT: }) : (tensor<1024x1024xf32>, tensor<24x2xi64>, tensor<24xf32>) -> tensor<1024x1024xf32>
-// CHECK-NEXT: %7 = stablehlo.transpose %6, dims = [1, 0] : (tensor<1024x1024xf32>) -> tensor<1024x1024xf32>
-// CHECK-NEXT: return %7 : tensor<1024x1024xf32>
+// CHECK-NEXT: return %6 : tensor<1024x1024xf32>
 // CHECK-NEXT: }
 
 func.func @convertscatter(%arg0: tensor<5x4xf32>, %arg1: tensor<5xui32>) -> tensor<5x4xf32> {
@@ -99,23 +97,21 @@ func.func @convertscatter(%arg0: tensor<5x4xf32>, %arg1: tensor<5xui32>) -> tens
 }
 
 // CHECK: func.func @convertscatter(%arg0: tensor<5x4xf32>, %arg1: tensor<5xui32>) -> tensor<5x4xf32> {
-// CHECK-NEXT: %cst = stablehlo.constant dense<0.000000e+00> : tensor<4x5xf32>
+// CHECK-NEXT: %cst = stablehlo.constant dense<0.000000e+00> : tensor<5x4xf32>
 // CHECK-NEXT{LITERAL}: %c = stablehlo.constant dense<[[4, 5], [4, 5], [4, 5], [4, 5], [4, 5]]> : tensor<5x2xi64>
 // CHECK-NEXT: %c_0 = stablehlo.constant dense<[-1, 3, 7, 11, 15]> : tensor<5xi64>
 // CHECK-NEXT: %c_1 = stablehlo.constant dense<4> : tensor<5xi64>
-// CHECK-NEXT: %0 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<5x4xf32>) -> tensor<4x5xf32>
-// CHECK-NEXT: %1 = stablehlo.convert %arg1 : (tensor<5xui32>) -> tensor<5xi64>
-// CHECK-NEXT: %2 = stablehlo.add %1, %c_0 : tensor<5xi64>
-// CHECK-NEXT: %3 = stablehlo.divide %2, %c_1 : tensor<5xi64>
+// CHECK-NEXT: %0 = stablehlo.convert %arg1 : (tensor<5xui32>) -> tensor<5xi64>
+// CHECK-NEXT: %1 = stablehlo.add %0, %c_0 : tensor<5xi64>
+// CHECK-NEXT: %2 = stablehlo.divide %1, %c_1 : tensor<5xi64>
+// CHECK-NEXT: %3 = stablehlo.reshape %1 : (tensor<5xi64>) -> tensor<5x1xi64>
 // CHECK-NEXT: %4 = stablehlo.reshape %2 : (tensor<5xi64>) -> tensor<5x1xi64>
-// CHECK-NEXT: %5 = stablehlo.reshape %3 : (tensor<5xi64>) -> tensor<5x1xi64>
-// CHECK-NEXT: %6 = stablehlo.concatenate %4, %5, dim = 1 : (tensor<5x1xi64>, tensor<5x1xi64>) -> tensor<5x2xi64>
-// CHECK-NEXT: %7 = stablehlo.remainder %6, %c : tensor<5x2xi64>
-// CHECK-NEXT: %8 = "stablehlo.gather"(%0, %7) <{dimension_numbers = #stablehlo.gather<collapsed_slice_dims = [0, 1], start_index_map = [0, 1], index_vector_dim = 1>, slice_sizes = array<i64: 1, 1>}> : (tensor<4x5xf32>, tensor<5x2xi64>) -> tensor<5xf32>
-// CHECK-NEXT: %9 = "stablehlo.scatter"(%cst, %7, %8) <{scatter_dimension_numbers = #stablehlo.scatter<inserted_window_dims = [0, 1], scatter_dims_to_operand_dims = [0, 1], index_vector_dim = 1>}> ({
+// CHECK-NEXT: %5 = stablehlo.concatenate %3, %4, dim = 1 : (tensor<5x1xi64>, tensor<5x1xi64>) -> tensor<5x2xi64>
+// CHECK-NEXT: %6 = stablehlo.remainder %5, %c : tensor<5x2xi64>
+// CHECK-NEXT: %7 = "stablehlo.gather"(%arg0, %6) <{dimension_numbers = #stablehlo.gather<collapsed_slice_dims = [0, 1], start_index_map = [0, 1], index_vector_dim = 1>, slice_sizes = array<i64: 1, 1>}> : (tensor<5x4xf32>, tensor<5x2xi64>) -> tensor<5xf32>
+// CHECK-NEXT: %8 = "stablehlo.scatter"(%cst, %6, %7) <{scatter_dimension_numbers = #stablehlo.scatter<inserted_window_dims = [0, 1], scatter_dims_to_operand_dims = [0, 1], index_vector_dim = 1>}> ({
 // CHECK-NEXT: ^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):
 // CHECK-NEXT: stablehlo.return %arg3 : tensor<f32>
-// CHECK-NEXT: }) : (tensor<4x5xf32>, tensor<5x2xi64>, tensor<5xf32>) -> tensor<4x5xf32>
-// CHECK-NEXT: %10 = stablehlo.transpose %9, dims = [1, 0] : (tensor<4x5xf32>) -> tensor<5x4xf32>
-// CHECK-NEXT: return %10 : tensor<5x4xf32>
+// CHECK-NEXT: }) : (tensor<5x4xf32>, tensor<5x2xi64>, tensor<5xf32>) -> tensor<5x4xf32>
+// CHECK-NEXT: return %8 : tensor<5x4xf32>
 // CHECK-NEXT: }

test/test_utils.py

Lines changed: 0 additions & 1 deletion
@@ -359,7 +359,6 @@ def AllPipelines():
             concat_to_broadcast<16>;
             reduce_to_reshape<16>;
             broadcast_to_reshape<16>;
-            gather_simplify<16>;
            iota_simplify<16>(1024);
             broadcast_in_dim_simplify<16>(1024);
             convert_concat<1>;
