@@ -9477,43 +9477,6 @@ bool is_iota(ArrayRef<int64_t> idx) {
94779477 return true;
94789478}
94799479
9480- /// Converts gather ops to slice ops in case we have a single set of constant
9481- /// indices.
9482- struct GatherSimplify final
9483- : CheckedOpRewritePattern<stablehlo::GatherOp, GatherSimplify> {
9484- using CheckedOpRewritePattern::CheckedOpRewritePattern;
9485-
9486- LogicalResult matchAndRewriteImpl(stablehlo::GatherOp op,
9487- PatternRewriter &rewriter) const {
9488- DenseIntElementsAttr startIndicesCst;
9489- if (!matchPattern(op.getStartIndices(), m_Constant(&startIndicesCst)))
9490- return failure();
9491-
9492- {
9493- DenseIntElementsAttr operandVals;
9494- if (matchPattern(op.getOperand(), m_Constant(&operandVals))) {
9495- auto out = stablehlo::gatherOp(
9496- stablehlo::constantOp(operandVals),
9497- stablehlo::constantOp(startIndicesCst),
9498- stablehlo::Axes(op.getDimensionNumbers().getOffsetDims()),
9499- stablehlo::Axes(op.getDimensionNumbers().getCollapsedSliceDims()),
9500- stablehlo::Axes(op.getDimensionNumbers().getOperandBatchingDims()),
9501- stablehlo::Axes(
9502- op.getDimensionNumbers().getStartIndicesBatchingDims()),
9503- stablehlo::Axes(op.getDimensionNumbers().getStartIndexMap()),
9504- stablehlo::Axis(op.getDimensionNumbers().getIndexVectorDim()),
9505- stablehlo::Sizes(op.getSliceSizes()), op.getIndicesAreSorted(),
9506- op.getType());
9507-
9508- rewriter.replaceOpWithNewOp<stablehlo::ConstantOp>(op, op.getType(),
9509- fromTensor(out));
9510- return success();
9511- }
9512- }
9513- return failure();
9514- }
9515- };
9516-
95179480struct CSEIota : CheckedOpRewritePattern<stablehlo::IotaOp, CSEIota> {
95189481 using CheckedOpRewritePattern::CheckedOpRewritePattern;
95199482
@@ -19718,6 +19681,72 @@ struct GatherElementwise
1971819681 }
1971919682};
1972019683
19684+ SmallVector<int64_t> applyPermutation(ArrayRef<int64_t> dims,
19685+ ArrayRef<int64_t> permutation,
19686+ bool sort = false) {
19687+ SmallVector<int64_t> newDims(dims.size(), -1);
19688+ for (int64_t i = 0; i < dims.size(); ++i) {
19689+ newDims[i] = permutation[dims[i]];
19690+ }
19691+
19692+ if (sort)
19693+ llvm::sort(newDims);
19694+
19695+ return newDims;
19696+ }
19697+
19698+ struct TransposeScatter
19699+ : public CheckedOpRewritePattern<stablehlo::TransposeOp, TransposeScatter> {
19700+ using CheckedOpRewritePattern<stablehlo::TransposeOp,
19701+ TransposeScatter>::CheckedOpRewritePattern;
19702+
19703+ LogicalResult matchAndRewriteImpl(stablehlo::TransposeOp op,
19704+ PatternRewriter &rewriter) const {
19705+ auto scatterOp = op.getOperand().getDefiningOp<stablehlo::ScatterOp>();
19706+ if (!scatterOp)
19707+ return rewriter.notifyMatchFailure(op,
19708+ "TransposeOp with non-scatter input");
19709+
19710+ if (scatterOp.getInputs().size() != 1)
19711+ return rewriter.notifyMatchFailure(
19712+ op, "TransposeOp with scatter input with more than 1 operand");
19713+
19714+ if (!isOnlyUsedInOperation(scatterOp, op))
19715+ return failure();
19716+
19717+ auto transposedInput = rewriter.create<stablehlo::TransposeOp>(
19718+ op.getLoc(), scatterOp.getInputs()[0], op.getPermutation());
19719+
19720+ auto newScatterOp = rewriter.create<stablehlo::ScatterOp>(
19721+ op.getLoc(), TypeRange(op.getType()), ValueRange(transposedInput),
19722+ scatterOp.getScatterIndices(), scatterOp.getUpdates(),
19723+ transposeScatterDimensionNumbers(
19724+ scatterOp.getScatterDimensionNumbers(),
19725+ getInversePermutation(op.getPermutation()), rewriter),
19726+ scatterOp.getIndicesAreSortedAttr(), scatterOp.getUniqueIndicesAttr());
19727+ newScatterOp.getUpdateComputation().takeBody(
19728+ scatterOp.getUpdateComputation());
19729+ rewriter.replaceOp(op, newScatterOp->getResult(0));
19730+ return success();
19731+ }
19732+
19733+ private:
19734+ stablehlo::ScatterDimensionNumbersAttr transposeScatterDimensionNumbers(
19735+ stablehlo::ScatterDimensionNumbersAttr scatterDimNumbers,
19736+ SmallVector<int64_t> mapping, PatternRewriter &rewriter) const {
19737+ return stablehlo::ScatterDimensionNumbersAttr::get(
19738+ rewriter.getContext(), scatterDimNumbers.getUpdateWindowDims(),
19739+ applyPermutation(scatterDimNumbers.getInsertedWindowDims(), mapping,
19740+ true),
19741+ applyPermutation(scatterDimNumbers.getInputBatchingDims(), mapping,
19742+ true),
19743+ scatterDimNumbers.getScatterIndicesBatchingDims(),
19744+ applyPermutation(scatterDimNumbers.getScatterDimsToOperandDims(),
19745+ mapping, true),
19746+ scatterDimNumbers.getIndexVectorDim());
19747+ }
19748+ };
19749+
1972119750/////////////// End Imported from stablehlo
1972219751
1972319752// clang-format off
@@ -19928,21 +19957,20 @@ struct EnzymeHLOOptPass
1992819957 patterns.add<TransposeExtend>(context);
1992919958 patterns.add<TransposeRotate>(context);
1993019959
19931- patterns
19932- .add<AddSimplify, SubSimplify, AndSimplify, MaxSimplify, MinSimplify,
19933- OrSimplify, XorSimplify, MulSimplify, DivSimplify, RemSimplify,
19934- PowSimplify, NoopSlice, NoopReverse, SliceSlice, PadSimplify,
19935- ShiftRightLogicalSimplify, NegativePadToSlice, SliceSimplify,
19936- ConvertSimplify, TransposeSimplify, DotGeneralSimplify,
19937- DynamicSliceToStatic, DynamicUpdateSliceElim, ReduceToReshape,
19938- BroadcastToReshape, GatherSimplify, ReshapeEmptyBroadcast,
19939- BroadcastReshape, ConstPropThroughBarrier,
19940- ReplaceNegAddWithSubtract, SignAbsSimplify, AbsPositiveSimplify,
19941- SimplifyBoundary<enzymexla::ExtendOp>,
19942- SimplifyBoundary<enzymexla::WrapOp>,
19943- SimplifyBoundary<enzymexla::RotateOp>, TransposeReshapeToBroadcast,
19944- ReshapeTransposeToBroadcast, SelectBroadcastInDim>(
19945- context, PatternBenefit(65000));
19960+ patterns.add<
19961+ AddSimplify, SubSimplify, AndSimplify, MaxSimplify, MinSimplify,
19962+ OrSimplify, XorSimplify, MulSimplify, DivSimplify, RemSimplify,
19963+ PowSimplify, NoopSlice, NoopReverse, SliceSlice, PadSimplify,
19964+ ShiftRightLogicalSimplify, NegativePadToSlice, SliceSimplify,
19965+ ConvertSimplify, TransposeSimplify, DotGeneralSimplify,
19966+ DynamicSliceToStatic, DynamicUpdateSliceElim, ReduceToReshape,
19967+ BroadcastToReshape, ReshapeEmptyBroadcast, BroadcastReshape,
19968+ ConstPropThroughBarrier, ReplaceNegAddWithSubtract, SignAbsSimplify,
19969+ AbsPositiveSimplify, SimplifyBoundary<enzymexla::ExtendOp>,
19970+ SimplifyBoundary<enzymexla::WrapOp>,
19971+ SimplifyBoundary<enzymexla::RotateOp>, TransposeReshapeToBroadcast,
19972+ ReshapeTransposeToBroadcast, SelectBroadcastInDim>(
19973+ context, PatternBenefit(65000));
1994619974
1994719975 patterns.add<IotaSimplify, BroadcastInDimSimplify, ConcatConstProp,
1994819976 DynamicUpdateSliceConstProp, PadSimplify>(
@@ -20130,7 +20158,7 @@ struct EnzymeHLOOptPass
2013020158 TransposeIota, TransposeReduceWindow, TransposeReduce,
2013120159 TransposeSelect, TransposeDynamicSlice, TransposeReverse,
2013220160 TransposeBatchNormTraining, TransposeBatchNormInference,
20133- TransposeBatchNormGrad, TransposeIf>(context);
20161+ TransposeBatchNormGrad, TransposeIf, TransposeScatter >(context);
2013420162 patterns.add<TransposeElementwise>(true, context);
2013520163 }
2013620164
0 commit comments