3333#include "src/enzyme_ad/jax/Implementations/WhileLoopInfo.h"
3434#include "src/enzyme_ad/jax/Passes/EnzymeHLOPatterns.h"
3535#include "src/enzyme_ad/jax/Passes/Passes.h"
36- #include "src/enzyme_ad/jax/Passes/StructuredTensors.h"
3736#include "src/enzyme_ad/jax/Utils.h"
3837#include "stablehlo/dialect/Base.h"
3938#include "stablehlo/dialect/ChloOps.h"
5554#include "llvm/ADT/MapVector.h"
5655#include <cstddef>
5756#include <iterator>
57+ #include <mlir/IR/BuiltinAttributes.h>
5858#include <mlir/IR/Value.h>
5959#include <numeric>
6060#define DEBUG_TYPE "enzymehloopt"
@@ -22676,29 +22676,30 @@ struct SplitConvolutionIntoReverseConvolution final
2267622676 }
2267722677};
2267822678
22679- struct ScatterMultiplySimplify final
22680- : public CheckedOpRewritePattern<stablehlo::MulOp,
22681- ScatterMultiplySimplify> {
22682- using CheckedOpRewritePattern<
22683- stablehlo::MulOp, ScatterMultiplySimplify>::CheckedOpRewritePattern;
22679+ template <typename OpTy, typename Child>
22680+ struct ScatterBinaryOpSimplifyBase
22681+ : public CheckedOpRewritePattern<OpTy, Child> {
22682+ using CheckedOpRewritePattern<OpTy, Child>::CheckedOpRewritePattern;
2268422683
22685- LogicalResult matchAndRewriteImpl(stablehlo::MulOp op,
22686- PatternRewriter &rewriter) const {
22684+ LogicalResult matchAndRewriteImpl(OpTy op, PatternRewriter &rewriter) const {
2268722685 auto lhs = op.getLhs();
2268822686 auto rhs = op.getRhs();
2268922687
2269022688 stablehlo::ScatterOp scatterOp;
2269122689 mlir::Value otherValue;
2269222690
22693- auto lhsScatterOp = lhs.getDefiningOp<stablehlo::ScatterOp>();
22694- auto rhsScatterOp = rhs.getDefiningOp<stablehlo::ScatterOp>();
22691+ auto lhsScatterOp = lhs.template getDefiningOp<stablehlo::ScatterOp>();
22692+ auto rhsScatterOp = rhs.template getDefiningOp<stablehlo::ScatterOp>();
22693+ bool lhsIsScatter;
2269522694 if (!lhsScatterOp && !rhsScatterOp) {
2269622695 return failure();
2269722696 } else {
2269822697 if (lhsScatterOp) {
22698+ lhsIsScatter = true;
2269922699 scatterOp = lhsScatterOp;
2270022700 otherValue = rhs;
2270122701 } else {
22702+ lhsIsScatter = false;
2270222703 scatterOp = rhsScatterOp;
2270322704 otherValue = lhs;
2270422705 }
@@ -22729,30 +22730,46 @@ struct ScatterMultiplySimplify final
2272922730
2273022731 SmallVector<int64_t> sliceSizes = computeGatherSliceSizes(scatterOp);
2273122732
22732- if (isAllZeros) { // non scattered values before zeros
22733- auto gatheredValues = stablehlo::GatherOp::create(
22734- rewriter, op.getLoc(), otherValue, scatterOp.getScatterIndices(),
22735- getGatherDims(rewriter.getContext(),
22736- scatterOp.getScatterDimensionNumbers()),
22737- rewriter.getDenseI64ArrayAttr(sliceSizes),
22738- scatterOp.getIndicesAreSortedAttr());
22733+ return ((Child *)this)
22734+ ->rewriteScatterElementwise(op, rewriter, scatterOp, otherValue,
22735+ isAllZeros, isAllOnes, lhsIsScatter,
22736+ constSetIndexValue, sliceSizes);
22737+ }
22738+
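  // Replaces `op` (an elementwise binary op that has `scatterOp` as one of
  // its operands) with a single new scatter. `%updates` below is the first
  // update operand of `scatterOp`, or, when `constSetIndexValue` is non-null,
  // a splat constant of the same type built from that attribute.
  //
  // if !fuseIntoScatter
  //   %g = gather(%other_value, %scatter_indices)
  //   %tmp = elementwise_op(%g, %updates) if !lhsIsScatter else
  //          elementwise_op(%updates, %g)
  //   %new_scatter_op = scatter(%input, %scatter_indices, %tmp) {
  //     return %arg1   // second block argument is returned unchanged
  //   }
  // else   (asserts %input == %other_value)
  //   %new_scatter_op = scatter(%input, %scatter_indices, %updates) {
  //     // block arguments are numbered from zero: %arg0, %arg1
  //     elementwise_op(%arg0, %arg1) if !lhsIsScatter else
  //     elementwise_op(%arg1, %arg0)
  //   }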
22748+ template <auto CreateOpFn>
22749+ void gatherElementwiseSetIndex(OpTy op, PatternRewriter &rewriter,
22750+ stablehlo::ScatterOp scatterOp,
22751+ Value otherValue, Value scatterInput,
22752+ bool lhsIsScatter,
22753+ SplatElementsAttr constSetIndexValue,
22754+ SmallVectorImpl<int64_t> &sliceSizes,
22755+ bool fuseIntoScatter) const {
22756+ Value updateVal;
22757+ if (constSetIndexValue) {
22758+ updateVal = stablehlo::ConstantOp::create(
22759+ rewriter, scatterOp.getLoc(),
22760+ constSetIndexValue.resizeSplat(
22761+ cast<ShapedType>(scatterOp.getUpdates()[0].getType())));
22762+ } else {
22763+ updateVal = scatterOp.getUpdates()[0];
22764+ }
2273922765
22740- Value mulRhs;
22741- if (constSetIndexValue) {
22742- mulRhs = stablehlo::ConstantOp::create(
22743- rewriter, op.getLoc(),
22744- constSetIndexValue.resizeSplat(
22745- cast<ShapedType>(gatheredValues.getType())));
22746- } else {
22747- mulRhs = scatterOp.getUpdates()[0];
22748- }
22749- auto newUpdates =
22750- stablehlo::MulOpCreate(rewriter, op.getLoc(), gatheredValues, mulRhs);
22766+ if (fuseIntoScatter) {
22767+ assert(scatterInput == otherValue);
2275122768
2275222769 auto newScatterOp = stablehlo::ScatterOp::create(
2275322770 rewriter, op.getLoc(), scatterOp.getResultTypes(),
22754- scatterOp.getInputs( ), scatterOp.getScatterIndices(),
22755- ValueRange(newUpdates ), scatterOp.getScatterDimensionNumbersAttr(),
22771+ ValueRange(scatterInput ), scatterOp.getScatterIndices(),
22772+ ValueRange(updateVal ), scatterOp.getScatterDimensionNumbersAttr(),
2275622773 scatterOp.getIndicesAreSortedAttr(),
2275722774 scatterOp.getUniqueIndicesAttr());
2275822775
@@ -22764,9 +22781,67 @@ struct ScatterMultiplySimplify final
2276422781 block->addArgument(argType, op.getLoc());
2276522782 block->addArgument(argType, op.getLoc());
2276622783 rewriter.setInsertionPointToStart(block);
22767- stablehlo::ReturnOp::create(rewriter, op.getLoc(), block->getArgument(1));
22784+
22785+ stablehlo::ReturnOp::create(
22786+ rewriter, op.getLoc(),
22787+ CreateOpFn(rewriter, op.getLoc(), block->getArgument(lhsIsScatter),
22788+ block->getArgument(!lhsIsScatter), std::nullopt));
2276822789
2276922790 rewriter.replaceOp(op, newScatterOp);
22791+
22792+ return;
22793+ }
22794+
22795+ auto gatheredValues =
22796+ stablehlo::GatherOp::create(
22797+ rewriter, scatterOp.getLoc(), otherValue,
22798+ scatterOp.getScatterIndices(),
22799+ getGatherDims(rewriter.getContext(),
22800+ scatterOp.getScatterDimensionNumbers()),
22801+ rewriter.getDenseI64ArrayAttr(sliceSizes),
22802+ scatterOp.getIndicesAreSortedAttr())
22803+ .getResult();
22804+
22805+ Value newUpdates = CreateOpFn(
22806+ rewriter, scatterOp.getLoc(), lhsIsScatter ? updateVal : gatheredValues,
22807+ lhsIsScatter ? gatheredValues : updateVal, std::nullopt);
22808+
22809+ auto newScatterOp = stablehlo::ScatterOp::create(
22810+ rewriter, op.getLoc(), scatterOp.getResultTypes(),
22811+ ValueRange(scatterInput), scatterOp.getScatterIndices(),
22812+ ValueRange(newUpdates), scatterOp.getScatterDimensionNumbersAttr(),
22813+ scatterOp.getIndicesAreSortedAttr(), scatterOp.getUniqueIndicesAttr());
22814+
22815+ auto &updateRegion = newScatterOp.getUpdateComputation();
22816+ auto *block = rewriter.createBlock(&updateRegion);
22817+ auto elemType =
22818+ cast<RankedTensorType>(scatterOp.getResultTypes()[0]).getElementType();
22819+ auto argType = RankedTensorType::get({}, elemType);
22820+ block->addArgument(argType, op.getLoc());
22821+ block->addArgument(argType, op.getLoc());
22822+ rewriter.setInsertionPointToStart(block);
22823+ stablehlo::ReturnOp::create(rewriter, op.getLoc(), block->getArgument(1));
22824+
22825+ rewriter.replaceOp(op, newScatterOp);
22826+ }
22827+ };
22828+
22829+ struct ScatterMultiplySimplify final
22830+ : public ScatterBinaryOpSimplifyBase<stablehlo::MulOp,
22831+ ScatterMultiplySimplify> {
22832+ using ScatterBinaryOpSimplifyBase<
22833+ stablehlo::MulOp, ScatterMultiplySimplify>::ScatterBinaryOpSimplifyBase;
22834+
22835+ LogicalResult
22836+ rewriteScatterElementwise(stablehlo::MulOp op, PatternRewriter &rewriter,
22837+ stablehlo::ScatterOp scatterOp, Value otherValue,
22838+ bool isAllZeros, bool isAllOnes, bool lhsIsScatter,
22839+ SplatElementsAttr constSetIndexValue,
22840+ SmallVectorImpl<int64_t> &sliceSizes) const {
22841+ if (isAllZeros) { // non scattered values before zeros
22842+ gatherElementwiseSetIndex<stablehlo::MulOpCreate>(
22843+ op, rewriter, scatterOp, otherValue, scatterOp.getInputs()[0],
22844+ lhsIsScatter, constSetIndexValue, sliceSizes, false);
2277022845 return success();
2277122846 }
2277222847
@@ -22808,6 +22883,82 @@ struct ScatterMultiplySimplify final
2280822883 }
2280922884};
2281022885
22886+ struct ScatterDivSimplify final
22887+ : public ScatterBinaryOpSimplifyBase<stablehlo::DivOp, ScatterDivSimplify> {
22888+ using ScatterBinaryOpSimplifyBase<
22889+ stablehlo::DivOp, ScatterDivSimplify>::ScatterBinaryOpSimplifyBase;
22890+
22891+ LogicalResult
22892+ rewriteScatterElementwise(stablehlo::DivOp op, PatternRewriter &rewriter,
22893+ stablehlo::ScatterOp scatterOp, Value otherValue,
22894+ bool isAllZeros, bool isAllOnes, bool lhsIsScatter,
22895+ SplatElementsAttr constSetIndexValue,
22896+ SmallVectorImpl<int64_t> &sliceSizes) const {
22897+ if (isAllOnes && !lhsIsScatter) {
22898+ // x / 1 -> setindex into x
22899+ gatherElementwiseSetIndex<stablehlo::DivOpCreate>(
22900+ op, rewriter, scatterOp, otherValue, otherValue, lhsIsScatter,
22901+ constSetIndexValue, sliceSizes, true);
22902+ return success();
22903+ }
22904+
22905+ return failure();
22906+ }
22907+ };
22908+
22909+ struct ScatterAddSimplify final
22910+ : public ScatterBinaryOpSimplifyBase<stablehlo::AddOp, ScatterAddSimplify> {
22911+ using ScatterBinaryOpSimplifyBase<
22912+ stablehlo::AddOp, ScatterAddSimplify>::ScatterBinaryOpSimplifyBase;
22913+
22914+ LogicalResult
22915+ rewriteScatterElementwise(stablehlo::AddOp op, PatternRewriter &rewriter,
22916+ stablehlo::ScatterOp scatterOp, Value otherValue,
22917+ bool isAllZeros, bool isAllOnes, bool lhsIsScatter,
22918+ SplatElementsAttr constSetIndexValue,
22919+ SmallVectorImpl<int64_t> &sliceSizes) const {
22920+ if (isAllZeros) {
22921+ gatherElementwiseSetIndex<stablehlo::AddOpCreate>(
22922+ op, rewriter, scatterOp, otherValue, otherValue, lhsIsScatter,
22923+ constSetIndexValue, sliceSizes, true);
22924+ return success();
22925+ }
22926+
22927+ return failure();
22928+ }
22929+ };
22930+
22931+ struct ScatterSubSimplify final
22932+ : public ScatterBinaryOpSimplifyBase<stablehlo::SubtractOp,
22933+ ScatterSubSimplify> {
22934+ using ScatterBinaryOpSimplifyBase<
22935+ stablehlo::SubtractOp, ScatterSubSimplify>::ScatterBinaryOpSimplifyBase;
22936+
22937+ LogicalResult
22938+ rewriteScatterElementwise(stablehlo::SubtractOp op, PatternRewriter &rewriter,
22939+ stablehlo::ScatterOp scatterOp, Value otherValue,
22940+ bool isAllZeros, bool isAllOnes, bool lhsIsScatter,
22941+ SplatElementsAttr constSetIndexValue,
22942+ SmallVectorImpl<int64_t> &sliceSizes) const {
22943+ if (isAllZeros) {
22944+ if (lhsIsScatter) {
22945+ otherValue =
22946+ stablehlo::NegOp::create(rewriter, op.getLoc(), otherValue);
22947+ gatherElementwiseSetIndex<stablehlo::AddOpCreate>(
22948+ op, rewriter, scatterOp, otherValue, otherValue, lhsIsScatter,
22949+ constSetIndexValue, sliceSizes, true);
22950+ } else {
22951+ gatherElementwiseSetIndex<stablehlo::SubtractOpCreate>(
22952+ op, rewriter, scatterOp, otherValue, otherValue, lhsIsScatter,
22953+ constSetIndexValue, sliceSizes, true);
22954+ }
22955+ return success();
22956+ }
22957+
22958+ return failure();
22959+ }
22960+ };
22961+
2281122962struct GatherConstProp final
2281222963 : public CheckedOpRewritePattern<stablehlo::GatherOp, GatherConstProp> {
2281322964 using CheckedOpRewritePattern<stablehlo::GatherOp,
@@ -29108,8 +29259,8 @@ struct EnzymeHLOOptPass
2910829259 CSE<stablehlo::ConcatenateOp>, CSE<stablehlo::MaxOp>,
2910929260 CSE<stablehlo::NegOp>, CSE<stablehlo::AbsOp>,
2911029261 CSE<enzymexla::RotateOp>, CSE<enzymexla::WrapOp>,
29111- CSE<enzymexla::ExtendOp>, CSEIota>(context ,
29112- PatternBenefit(65000));
29262+ CSE<enzymexla::ExtendOp>, CSEIota, CSE<stablehlo::GatherOp> ,
29263+ CSE<stablehlo::ScatterOp>>(context, PatternBenefit(65000));
2911329264 }
2911429265
2911529266 if (passses & 256)
@@ -29282,6 +29433,9 @@ struct EnzymeHLOOptPass
2928229433 ConjComplexSimplify,
2928329434 SplitConvolutionIntoReverseConvolution,
2928429435 ScatterMultiplySimplify,
29436+ ScatterDivSimplify,
29437+ ScatterAddSimplify,
29438+ ScatterSubSimplify,
2928529439 UnaryElementwiseScatterSimplify,
2928629440 GatherElementwise,
2928729441 ElementwisePad,
0 commit comments