@@ -19411,6 +19411,278 @@ struct SplitConvolutionIntoReverseConvolution final
   }
 };
 
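+// Builds the gather dimension numbers that mirror a scatter's dimension
+// numbers: update_window_dims map to offset_dims, inserted_window_dims to
+// collapsed_slice_dims, and scatter_dims_to_operand_dims to start_index_map.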
+stablehlo::GatherDimensionNumbersAttr
+getGatherDims(mlir::MLIRContext *ctx,
+              stablehlo::ScatterDimensionNumbersAttr scatterDimNumbers) {
+  return stablehlo::GatherDimensionNumbersAttr::get(
+      ctx, scatterDimNumbers.getUpdateWindowDims(),
+      scatterDimNumbers.getInsertedWindowDims(),
+      scatterDimNumbers.getInputBatchingDims(),
+      scatterDimNumbers.getScatterIndicesBatchingDims(),
+      scatterDimNumbers.getScatterDimsToOperandDims(),
+      scatterDimNumbers.getIndexVectorDim());
+}
+
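+// Returns true if the scatter's update computation simply returns the update
+// value, i.e. the scatter overwrites the selected elements ("setindex"
+// semantics) rather than combining them with the existing values:
+//   ^bb0(%old: tensor<f32>, %new: tensor<f32>):
+//     stablehlo.return %new : tensor<f32>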
+bool isScatterSetindexOp(stablehlo::ScatterOp &scatterOp) {
+  auto &updateComputation = scatterOp.getUpdateComputation();
+
+  if (!updateComputation.hasOneBlock())
+    return false;
+
+  auto &block = updateComputation.front();
+  if (block.getNumArguments() != 2)
+    return false;
+
+  // Block arguments are (original value, update value).
+  auto updateValue = block.getArgument(1);
+
+  // The block should have exactly one operation (the return).
+  if (block.getOperations().size() != 1)
+    return false;
+
+  auto &returnOp = block.front();
+  auto stablehloReturn = dyn_cast<stablehlo::ReturnOp>(returnOp);
+  if (!stablehloReturn)
+    return false;
+
+  if (stablehloReturn.getNumOperands() != 1)
+    return false;
+
+  // The returned value should be the update value (second argument).
+  return stablehloReturn.getOperand(0) == updateValue;
+}
+
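+// Computes the slice_sizes for a gather that reads the same windows the
+// scatter writes: 1 for every operand dimension, widened to the update
+// tensor's extent along each update window dimension.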
+SmallVector<int64_t> computeGatherSliceSizes(stablehlo::ScatterOp &scatterOp) {
+  auto inputType = cast<ShapedType>(scatterOp.getInputs()[0].getType());
+  auto updateType = cast<ShapedType>(scatterOp.getUpdates()[0].getType());
+  auto scatterDimNumbers = scatterOp.getScatterDimensionNumbers();
+
+  auto inputShape = inputType.getShape();
+  auto updateShape = updateType.getShape();
+
+  SmallVector<int64_t> sliceSizes(inputShape.size(), 1);
+
+  auto updateWindowDims = scatterDimNumbers.getUpdateWindowDims();
+  auto scatterIndicesBatchingDims =
+      scatterDimNumbers.getScatterIndicesBatchingDims();
+
+  // Map update window dimensions to their corresponding input dimensions.
+  for (size_t i = 0; i < updateWindowDims.size(); ++i) {
+    int64_t inputDim = updateWindowDims[i];
+
+    // Calculate the corresponding dimension in the update tensor.
+    // Update tensor layout: [scatter_indices_batching_dims...,
+    // update_window_dims...]
+    size_t updateDimIndex = scatterIndicesBatchingDims.size() + i;
+
+    if (updateDimIndex < updateShape.size()) {
+      sliceSizes[inputDim] = updateShape[updateDimIndex];
+    }
+  }
+
+  return sliceSizes;
+}
+
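+// Rewrites mul(scatter(zeros, indices, updates), other), where the scatter
+// overwrites a zero input, into
+//   scatter(zeros, indices, mul(gather(other, indices), updates))
+// The product is zero everywhere except at the scattered positions, so only
+// the gathered elements of `other` need to be multiplied.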
+struct ScatterMultiplySimplify final
+    : public CheckedOpRewritePattern<stablehlo::MulOp,
+                                     ScatterMultiplySimplify> {
+  using CheckedOpRewritePattern<
+      stablehlo::MulOp, ScatterMultiplySimplify>::CheckedOpRewritePattern;
+
+  LogicalResult matchAndRewriteImpl(stablehlo::MulOp op,
+                                    PatternRewriter &rewriter) const {
+    auto lhs = op.getLhs();
+    auto rhs = op.getRhs();
+
+    stablehlo::ScatterOp scatterOp;
+    mlir::Value otherValue;
+
+    auto lhsScatterOp = lhs.getDefiningOp<stablehlo::ScatterOp>();
+    auto rhsScatterOp = rhs.getDefiningOp<stablehlo::ScatterOp>();
+    if (!lhsScatterOp && !rhsScatterOp)
+      return failure();
+    if (lhsScatterOp) {
+      scatterOp = lhsScatterOp;
+      otherValue = rhs;
+    } else {
+      scatterOp = rhsScatterOp;
+      otherValue = lhs;
+    }
+
+    if (scatterOp.getInputs().size() != 1)
+      return rewriter.notifyMatchFailure(
+          op, "ScatterOp with more than one input not supported");
+
+    auto input = scatterOp.getInputs()[0];
+    if (!matchPattern(input, m_AnyZeroFloat()) &&
+        !matchPattern(input, m_Zero()))
+      return rewriter.notifyMatchFailure(op, "ScatterOp with non-zero input");
+
+    if (!scatterOp.getResult(0).hasOneUse())
+      return rewriter.notifyMatchFailure(op, "ScatterOp with multiple uses");
+
+    if (!isScatterSetindexOp(scatterOp))
+      return rewriter.notifyMatchFailure(op, "ScatterOp with non-setindex");
+
+    auto scatterDimNumbers = scatterOp.getScatterDimensionNumbers();
+
+    SmallVector<int64_t> sliceSizes = computeGatherSliceSizes(scatterOp);
+
+    auto gatheredValues = rewriter.create<stablehlo::GatherOp>(
+        op.getLoc(), otherValue, scatterOp.getScatterIndices(),
+        getGatherDims(rewriter.getContext(), scatterDimNumbers),
+        rewriter.getDenseI64ArrayAttr(sliceSizes),
+        scatterOp.getIndicesAreSortedAttr());
+
+    auto newUpdates = rewriter.create<stablehlo::MulOp>(
+        op.getLoc(), gatheredValues, scatterOp.getUpdates()[0]);
+
+    auto newScatterOp = rewriter.create<stablehlo::ScatterOp>(
+        op.getLoc(), scatterOp.getResultTypes(), scatterOp.getInputs(),
+        scatterOp.getScatterIndices(), ValueRange(newUpdates),
+        scatterOp.getScatterDimensionNumbersAttr(),
+        scatterOp.getIndicesAreSortedAttr(), scatterOp.getUniqueIndicesAttr());
+    newScatterOp.getUpdateComputation().takeBody(
+        scatterOp.getUpdateComputation());
+    rewriter.replaceOp(op, newScatterOp);
+
+    return success();
+  }
+};
+
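+// Constant-folds stablehlo.gather. A splat operand folds directly to a splat
+// of the result type; otherwise, if the start indices are also constant, the
+// gather is evaluated through the StableHLO reference interpreter.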
+struct GatherConstProp final
+    : public CheckedOpRewritePattern<stablehlo::GatherOp, GatherConstProp> {
+  using CheckedOpRewritePattern<stablehlo::GatherOp,
+                                GatherConstProp>::CheckedOpRewritePattern;
+
+  LogicalResult matchAndRewriteImpl(stablehlo::GatherOp op,
+                                    PatternRewriter &rewriter) const {
+    DenseElementsAttr operandAttr;
+    if (!matchPattern(op.getOperand(), m_Constant(&operandAttr)))
+      return rewriter.notifyMatchFailure(op,
+                                         "GatherOp with non-constant input");
+
+    if (operandAttr.isSplat()) {
+      // In this case the indices don't matter and we can construct a new
+      // splatted result.
+      rewriter.replaceOpWithNewOp<stablehlo::ConstantOp>(
+          op, op.getType(), operandAttr.resizeSplat(op.getType()));
+      return success();
+    }
+
+    DenseElementsAttr startIndicesAttr;
+    if (!matchPattern(op.getStartIndices(), m_Constant(&startIndicesAttr))) {
+      return rewriter.notifyMatchFailure(
+          op, "GatherOp with non-constant start indices and unsplatted input");
+    }
+
+    stablehlo::Tensor operandTensor = stablehlo::constantOp(operandAttr);
+    stablehlo::Tensor startIndicesTensor =
+        stablehlo::constantOp(startIndicesAttr);
+    auto gatherDims = op.getDimensionNumbers();
+
+    auto sliceSizes = op.getSliceSizes();
+    auto elementType = rewriter.getIntegerType(64);
+    auto attrType =
+        RankedTensorType::get({(int64_t)sliceSizes.size()}, elementType);
+    auto sliceSizesAttr = DenseElementsAttr::get(attrType, sliceSizes);
+
+    auto result = stablehlo::gatherOp(
+        operandTensor, startIndicesTensor,
+        stablehlo::Axes(gatherDims.getOffsetDims()),
+        stablehlo::Axes(gatherDims.getCollapsedSliceDims()),
+        stablehlo::Axes(gatherDims.getOperandBatchingDims()),
+        stablehlo::Axes(gatherDims.getStartIndicesBatchingDims()),
+        stablehlo::Axes(gatherDims.getStartIndexMap()),
+        stablehlo::Axis(gatherDims.getIndexVectorDim()),
+        stablehlo::makeSizes(stablehlo::constantOp(sliceSizesAttr)),
+        op.getIndicesAreSorted(), op.getType());
+    rewriter.replaceOpWithNewOp<stablehlo::ConstantOp>(op, op.getType(),
+                                                       fromTensor(result));
+    return success();
+  }
+};
+
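+// Moves a unary elementwise op past a "setindex" scatter with a constant
+// input:
+//   f(scatter(c, indices, updates)) -> scatter(f(c), indices, f(updates))
+// Every result element comes from either the input or the updates, so
+// applying f to both sources is equivalent, and f(c) is folded away by
+// constant propagation.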
+struct UnaryElementwiseScatterSimplify final
+    : public CheckedOpTraitRewritePattern<OpTrait::Elementwise,
+                                          UnaryElementwiseScatterSimplify> {
+  using CheckedOpTraitRewritePattern<
+      OpTrait::Elementwise,
+      UnaryElementwiseScatterSimplify>::CheckedOpTraitRewritePattern;
+
+  LogicalResult matchAndRewriteImpl(Operation *op,
+                                    PatternRewriter &rewriter) const {
+    if (op->getNumOperands() != 1)
+      return rewriter.notifyMatchFailure(op, "not a unary elementwise op");
+
+    auto input = op->getOperand(0);
+    auto scatterOp = input.getDefiningOp<stablehlo::ScatterOp>();
+    if (!scatterOp)
+      return rewriter.notifyMatchFailure(op, "not a scatter op");
+
+    if (scatterOp.getInputs().size() != 1)
+      return rewriter.notifyMatchFailure(
+          op, "ScatterOp with more than one input not supported");
+
+    if (!scatterOp.getResult(0).hasOneUse())
+      return rewriter.notifyMatchFailure(op, "ScatterOp with multiple uses");
+
+    if (!isScatterSetindexOp(scatterOp))
+      return rewriter.notifyMatchFailure(op, "ScatterOp with non-setindex");
+
+    auto scatterInput = scatterOp.getInputs()[0];
+    DenseElementsAttr scatterInputAttr;
+    // Without a constant input, duplicating the elementwise op would
+    // definitely increase the compute cost.
+    if (!matchPattern(scatterInput, m_Constant(&scatterInputAttr)))
+      return rewriter.notifyMatchFailure(op,
+                                         "ScatterOp with non-constant input");
+
+    auto elemType =
+        cast<RankedTensorType>(op->getResult(0).getType()).getElementType();
+
+    // TODO: support ConvertOp. We would need to rewrite the update
+    // computation to take the converted element type.
+    if (isa<stablehlo::ConvertOp>(op))
+      return rewriter.notifyMatchFailure(op,
+                                         "ConvertOp not supported for now.");
+
+    // Applied to a constant input, so this should get constant-propagated.
+    auto scatterInputElem = rewriter.create(
+        op->getLoc(), op->getName().getIdentifier(), ValueRange(scatterInput),
+        TypeRange{RankedTensorType::get(
+            cast<RankedTensorType>(scatterInput.getType()).getShape(),
+            elemType)},
+        op->getAttrs(), {}, {});
+
+    auto scatterUpdatesElem = rewriter.create(
+        op->getLoc(), op->getName().getIdentifier(),
+        ValueRange(scatterOp.getUpdates()),
+        TypeRange{RankedTensorType::get(
+            cast<RankedTensorType>(scatterOp.getUpdates()[0].getType())
+                .getShape(),
+            elemType)},
+        op->getAttrs(), {}, {});
+
+    auto resultType = RankedTensorType::get(
+        cast<RankedTensorType>(scatterOp.getResultTypes()[0]).getShape(),
+        elemType);
+
+    auto newScatterOp = rewriter.create<stablehlo::ScatterOp>(
+        op->getLoc(), TypeRange(resultType),
+        ValueRange(scatterInputElem->getResult(0)),
+        scatterOp.getScatterIndices(),
+        ValueRange(scatterUpdatesElem->getResult(0)),
+        scatterOp.getScatterDimensionNumbersAttr(),
+        scatterOp.getIndicesAreSortedAttr(), scatterOp.getUniqueIndicesAttr());
+    newScatterOp.getUpdateComputation().takeBody(
+        scatterOp.getUpdateComputation());
+    rewriter.replaceOp(op, newScatterOp->getResult(0));
+    return success();
+  }
+};
+
 /////////////// End Imported from stablehlo
 
 // clang-format off
@@ -19697,6 +19969,8 @@ struct EnzymeHLOOptPass
                  BinaryConstProp<stablehlo::SubtractOp, stablehlo::subtractOp>,
                  BinaryConstProp<stablehlo::XorOp, stablehlo::xorOp>>(context);
 
+    patterns.add<GatherConstProp>(context);
+
     patterns.add<BinaryOpTransposeSimplify<stablehlo::AddOp>,
                  BinaryOpTransposeSimplify<stablehlo::SubtractOp>,
                  BinaryOpTransposeSimplify<stablehlo::MulOp>,
@@ -19929,7 +20203,9 @@ struct EnzymeHLOOptPass
         InvolutionSimplify<chlo::ConjOp>,
         RealConjSimplify,
         ConjComplexSimplify,
-        SplitConvolutionIntoReverseConvolution
+        SplitConvolutionIntoReverseConvolution,
+        ScatterMultiplySimplify,
+        UnaryElementwiseScatterSimplify
       >(context);
 
     patterns.add<SumToReduceWindow<stablehlo::AddOp>,