Skip to content

Commit 99d2b63

Browse files
authored
feat: more scatter simplifications for add/sub/div (#1890)
1 parent a83eae4 commit 99d2b63

File tree

13 files changed

+866
-511
lines changed

13 files changed

+866
-511
lines changed

src/enzyme_ad/jax/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -744,6 +744,7 @@ cc_library(
744744
":EnzymeXLACanonicalizers",
745745
":EnzymeXLADialectUtils",
746746
":EnzymeXLAOpsIncGen",
747+
"@com_google_absl//absl/status",
747748
"@enzyme//:EnzymeMLIR",
748749
"@llvm-project//llvm:Support",
749750
"@llvm-project//mlir:AffineDialect",

src/enzyme_ad/jax/Implementations/WhileLoopInfo.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
#include "mlir/IR/Matchers.h"
1818

1919
#include "src/enzyme_ad/jax/Implementations/WhileLoopInfo.h"
20-
#include "src/enzyme_ad/jax/Passes/StructuredTensors.h"
2120
#include "src/enzyme_ad/jax/Utils.h"
2221

2322
#include "llvm/ADT/TypeSwitch.h"

src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp

Lines changed: 187 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@
3333
#include "src/enzyme_ad/jax/Implementations/WhileLoopInfo.h"
3434
#include "src/enzyme_ad/jax/Passes/EnzymeHLOPatterns.h"
3535
#include "src/enzyme_ad/jax/Passes/Passes.h"
36-
#include "src/enzyme_ad/jax/Passes/StructuredTensors.h"
3736
#include "src/enzyme_ad/jax/Utils.h"
3837
#include "stablehlo/dialect/Base.h"
3938
#include "stablehlo/dialect/ChloOps.h"
@@ -55,6 +54,7 @@
5554
#include "llvm/ADT/MapVector.h"
5655
#include <cstddef>
5756
#include <iterator>
57+
#include <mlir/IR/BuiltinAttributes.h>
5858
#include <mlir/IR/Value.h>
5959
#include <numeric>
6060
#define DEBUG_TYPE "enzymehloopt"
@@ -22676,29 +22676,30 @@ struct SplitConvolutionIntoReverseConvolution final
2267622676
}
2267722677
};
2267822678

22679-
struct ScatterMultiplySimplify final
22680-
: public CheckedOpRewritePattern<stablehlo::MulOp,
22681-
ScatterMultiplySimplify> {
22682-
using CheckedOpRewritePattern<
22683-
stablehlo::MulOp, ScatterMultiplySimplify>::CheckedOpRewritePattern;
22679+
template <typename OpTy, typename Child>
22680+
struct ScatterBinaryOpSimplifyBase
22681+
: public CheckedOpRewritePattern<OpTy, Child> {
22682+
using CheckedOpRewritePattern<OpTy, Child>::CheckedOpRewritePattern;
2268422683

22685-
LogicalResult matchAndRewriteImpl(stablehlo::MulOp op,
22686-
PatternRewriter &rewriter) const {
22684+
LogicalResult matchAndRewriteImpl(OpTy op, PatternRewriter &rewriter) const {
2268722685
auto lhs = op.getLhs();
2268822686
auto rhs = op.getRhs();
2268922687

2269022688
stablehlo::ScatterOp scatterOp;
2269122689
mlir::Value otherValue;
2269222690

22693-
auto lhsScatterOp = lhs.getDefiningOp<stablehlo::ScatterOp>();
22694-
auto rhsScatterOp = rhs.getDefiningOp<stablehlo::ScatterOp>();
22691+
auto lhsScatterOp = lhs.template getDefiningOp<stablehlo::ScatterOp>();
22692+
auto rhsScatterOp = rhs.template getDefiningOp<stablehlo::ScatterOp>();
22693+
bool lhsIsScatter;
2269522694
if (!lhsScatterOp && !rhsScatterOp) {
2269622695
return failure();
2269722696
} else {
2269822697
if (lhsScatterOp) {
22698+
lhsIsScatter = true;
2269922699
scatterOp = lhsScatterOp;
2270022700
otherValue = rhs;
2270122701
} else {
22702+
lhsIsScatter = false;
2270222703
scatterOp = rhsScatterOp;
2270322704
otherValue = lhs;
2270422705
}
@@ -22729,30 +22730,46 @@ struct ScatterMultiplySimplify final
2272922730

2273022731
SmallVector<int64_t> sliceSizes = computeGatherSliceSizes(scatterOp);
2273122732

22732-
if (isAllZeros) { // non scattered values before zeros
22733-
auto gatheredValues = stablehlo::GatherOp::create(
22734-
rewriter, op.getLoc(), otherValue, scatterOp.getScatterIndices(),
22735-
getGatherDims(rewriter.getContext(),
22736-
scatterOp.getScatterDimensionNumbers()),
22737-
rewriter.getDenseI64ArrayAttr(sliceSizes),
22738-
scatterOp.getIndicesAreSortedAttr());
22733+
return ((Child *)this)
22734+
->rewriteScatterElementwise(op, rewriter, scatterOp, otherValue,
22735+
isAllZeros, isAllOnes, lhsIsScatter,
22736+
constSetIndexValue, sliceSizes);
22737+
}
22738+
22739+
// if !fuseIntoScatter
22740+
// %g = gather(%other_value, %scatter_indices)
22741+
// %tmp = elementwise_op(%g, %updates) if !lhsIsScatter else
22742+
// elementwise_op(%updates, %g)
22743+
// %new_scatter_op = scatter(%input, %scatter_indices, %tmp)
22744+
// else
22745+
// %new_scatter_op = scatter(%input, %scatter_indices, %updates) {
22746+
// elementwise_op(%arg1, %arg2) // based on lhsIsScatter
22747+
// }
22748+
template <auto CreateOpFn>
22749+
void gatherElementwiseSetIndex(OpTy op, PatternRewriter &rewriter,
22750+
stablehlo::ScatterOp scatterOp,
22751+
Value otherValue, Value scatterInput,
22752+
bool lhsIsScatter,
22753+
SplatElementsAttr constSetIndexValue,
22754+
SmallVectorImpl<int64_t> &sliceSizes,
22755+
bool fuseIntoScatter) const {
22756+
Value updateVal;
22757+
if (constSetIndexValue) {
22758+
updateVal = stablehlo::ConstantOp::create(
22759+
rewriter, scatterOp.getLoc(),
22760+
constSetIndexValue.resizeSplat(
22761+
cast<ShapedType>(scatterOp.getUpdates()[0].getType())));
22762+
} else {
22763+
updateVal = scatterOp.getUpdates()[0];
22764+
}
2273922765

22740-
Value mulRhs;
22741-
if (constSetIndexValue) {
22742-
mulRhs = stablehlo::ConstantOp::create(
22743-
rewriter, op.getLoc(),
22744-
constSetIndexValue.resizeSplat(
22745-
cast<ShapedType>(gatheredValues.getType())));
22746-
} else {
22747-
mulRhs = scatterOp.getUpdates()[0];
22748-
}
22749-
auto newUpdates =
22750-
stablehlo::MulOpCreate(rewriter, op.getLoc(), gatheredValues, mulRhs);
22766+
if (fuseIntoScatter) {
22767+
assert(scatterInput == otherValue);
2275122768

2275222769
auto newScatterOp = stablehlo::ScatterOp::create(
2275322770
rewriter, op.getLoc(), scatterOp.getResultTypes(),
22754-
scatterOp.getInputs(), scatterOp.getScatterIndices(),
22755-
ValueRange(newUpdates), scatterOp.getScatterDimensionNumbersAttr(),
22771+
ValueRange(scatterInput), scatterOp.getScatterIndices(),
22772+
ValueRange(updateVal), scatterOp.getScatterDimensionNumbersAttr(),
2275622773
scatterOp.getIndicesAreSortedAttr(),
2275722774
scatterOp.getUniqueIndicesAttr());
2275822775

@@ -22764,9 +22781,67 @@ struct ScatterMultiplySimplify final
2276422781
block->addArgument(argType, op.getLoc());
2276522782
block->addArgument(argType, op.getLoc());
2276622783
rewriter.setInsertionPointToStart(block);
22767-
stablehlo::ReturnOp::create(rewriter, op.getLoc(), block->getArgument(1));
22784+
22785+
stablehlo::ReturnOp::create(
22786+
rewriter, op.getLoc(),
22787+
CreateOpFn(rewriter, op.getLoc(), block->getArgument(lhsIsScatter),
22788+
block->getArgument(!lhsIsScatter), std::nullopt));
2276822789

2276922790
rewriter.replaceOp(op, newScatterOp);
22791+
22792+
return;
22793+
}
22794+
22795+
auto gatheredValues =
22796+
stablehlo::GatherOp::create(
22797+
rewriter, scatterOp.getLoc(), otherValue,
22798+
scatterOp.getScatterIndices(),
22799+
getGatherDims(rewriter.getContext(),
22800+
scatterOp.getScatterDimensionNumbers()),
22801+
rewriter.getDenseI64ArrayAttr(sliceSizes),
22802+
scatterOp.getIndicesAreSortedAttr())
22803+
.getResult();
22804+
22805+
Value newUpdates = CreateOpFn(
22806+
rewriter, scatterOp.getLoc(), lhsIsScatter ? updateVal : gatheredValues,
22807+
lhsIsScatter ? gatheredValues : updateVal, std::nullopt);
22808+
22809+
auto newScatterOp = stablehlo::ScatterOp::create(
22810+
rewriter, op.getLoc(), scatterOp.getResultTypes(),
22811+
ValueRange(scatterInput), scatterOp.getScatterIndices(),
22812+
ValueRange(newUpdates), scatterOp.getScatterDimensionNumbersAttr(),
22813+
scatterOp.getIndicesAreSortedAttr(), scatterOp.getUniqueIndicesAttr());
22814+
22815+
auto &updateRegion = newScatterOp.getUpdateComputation();
22816+
auto *block = rewriter.createBlock(&updateRegion);
22817+
auto elemType =
22818+
cast<RankedTensorType>(scatterOp.getResultTypes()[0]).getElementType();
22819+
auto argType = RankedTensorType::get({}, elemType);
22820+
block->addArgument(argType, op.getLoc());
22821+
block->addArgument(argType, op.getLoc());
22822+
rewriter.setInsertionPointToStart(block);
22823+
stablehlo::ReturnOp::create(rewriter, op.getLoc(), block->getArgument(1));
22824+
22825+
rewriter.replaceOp(op, newScatterOp);
22826+
}
22827+
};
22828+
22829+
struct ScatterMultiplySimplify final
22830+
: public ScatterBinaryOpSimplifyBase<stablehlo::MulOp,
22831+
ScatterMultiplySimplify> {
22832+
using ScatterBinaryOpSimplifyBase<
22833+
stablehlo::MulOp, ScatterMultiplySimplify>::ScatterBinaryOpSimplifyBase;
22834+
22835+
LogicalResult
22836+
rewriteScatterElementwise(stablehlo::MulOp op, PatternRewriter &rewriter,
22837+
stablehlo::ScatterOp scatterOp, Value otherValue,
22838+
bool isAllZeros, bool isAllOnes, bool lhsIsScatter,
22839+
SplatElementsAttr constSetIndexValue,
22840+
SmallVectorImpl<int64_t> &sliceSizes) const {
22841+
if (isAllZeros) { // non scattered values before zeros
22842+
gatherElementwiseSetIndex<stablehlo::MulOpCreate>(
22843+
op, rewriter, scatterOp, otherValue, scatterOp.getInputs()[0],
22844+
lhsIsScatter, constSetIndexValue, sliceSizes, false);
2277022845
return success();
2277122846
}
2277222847

@@ -22808,6 +22883,82 @@ struct ScatterMultiplySimplify final
2280822883
}
2280922884
};
2281022885

22886+
struct ScatterDivSimplify final
22887+
: public ScatterBinaryOpSimplifyBase<stablehlo::DivOp, ScatterDivSimplify> {
22888+
using ScatterBinaryOpSimplifyBase<
22889+
stablehlo::DivOp, ScatterDivSimplify>::ScatterBinaryOpSimplifyBase;
22890+
22891+
LogicalResult
22892+
rewriteScatterElementwise(stablehlo::DivOp op, PatternRewriter &rewriter,
22893+
stablehlo::ScatterOp scatterOp, Value otherValue,
22894+
bool isAllZeros, bool isAllOnes, bool lhsIsScatter,
22895+
SplatElementsAttr constSetIndexValue,
22896+
SmallVectorImpl<int64_t> &sliceSizes) const {
22897+
if (isAllOnes && !lhsIsScatter) {
22898+
// x / 1 -> setindex into x
22899+
gatherElementwiseSetIndex<stablehlo::DivOpCreate>(
22900+
op, rewriter, scatterOp, otherValue, otherValue, lhsIsScatter,
22901+
constSetIndexValue, sliceSizes, true);
22902+
return success();
22903+
}
22904+
22905+
return failure();
22906+
}
22907+
};
22908+
22909+
struct ScatterAddSimplify final
22910+
: public ScatterBinaryOpSimplifyBase<stablehlo::AddOp, ScatterAddSimplify> {
22911+
using ScatterBinaryOpSimplifyBase<
22912+
stablehlo::AddOp, ScatterAddSimplify>::ScatterBinaryOpSimplifyBase;
22913+
22914+
LogicalResult
22915+
rewriteScatterElementwise(stablehlo::AddOp op, PatternRewriter &rewriter,
22916+
stablehlo::ScatterOp scatterOp, Value otherValue,
22917+
bool isAllZeros, bool isAllOnes, bool lhsIsScatter,
22918+
SplatElementsAttr constSetIndexValue,
22919+
SmallVectorImpl<int64_t> &sliceSizes) const {
22920+
if (isAllZeros) {
22921+
gatherElementwiseSetIndex<stablehlo::AddOpCreate>(
22922+
op, rewriter, scatterOp, otherValue, otherValue, lhsIsScatter,
22923+
constSetIndexValue, sliceSizes, true);
22924+
return success();
22925+
}
22926+
22927+
return failure();
22928+
}
22929+
};
22930+
22931+
struct ScatterSubSimplify final
22932+
: public ScatterBinaryOpSimplifyBase<stablehlo::SubtractOp,
22933+
ScatterSubSimplify> {
22934+
using ScatterBinaryOpSimplifyBase<
22935+
stablehlo::SubtractOp, ScatterSubSimplify>::ScatterBinaryOpSimplifyBase;
22936+
22937+
LogicalResult
22938+
rewriteScatterElementwise(stablehlo::SubtractOp op, PatternRewriter &rewriter,
22939+
stablehlo::ScatterOp scatterOp, Value otherValue,
22940+
bool isAllZeros, bool isAllOnes, bool lhsIsScatter,
22941+
SplatElementsAttr constSetIndexValue,
22942+
SmallVectorImpl<int64_t> &sliceSizes) const {
22943+
if (isAllZeros) {
22944+
if (lhsIsScatter) {
22945+
otherValue =
22946+
stablehlo::NegOp::create(rewriter, op.getLoc(), otherValue);
22947+
gatherElementwiseSetIndex<stablehlo::AddOpCreate>(
22948+
op, rewriter, scatterOp, otherValue, otherValue, lhsIsScatter,
22949+
constSetIndexValue, sliceSizes, true);
22950+
} else {
22951+
gatherElementwiseSetIndex<stablehlo::SubtractOpCreate>(
22952+
op, rewriter, scatterOp, otherValue, otherValue, lhsIsScatter,
22953+
constSetIndexValue, sliceSizes, true);
22954+
}
22955+
return success();
22956+
}
22957+
22958+
return failure();
22959+
}
22960+
};
22961+
2281122962
struct GatherConstProp final
2281222963
: public CheckedOpRewritePattern<stablehlo::GatherOp, GatherConstProp> {
2281322964
using CheckedOpRewritePattern<stablehlo::GatherOp,
@@ -29108,8 +29259,8 @@ struct EnzymeHLOOptPass
2910829259
CSE<stablehlo::ConcatenateOp>, CSE<stablehlo::MaxOp>,
2910929260
CSE<stablehlo::NegOp>, CSE<stablehlo::AbsOp>,
2911029261
CSE<enzymexla::RotateOp>, CSE<enzymexla::WrapOp>,
29111-
CSE<enzymexla::ExtendOp>, CSEIota>(context,
29112-
PatternBenefit(65000));
29262+
CSE<enzymexla::ExtendOp>, CSEIota, CSE<stablehlo::GatherOp>,
29263+
CSE<stablehlo::ScatterOp>>(context, PatternBenefit(65000));
2911329264
}
2911429265

2911529266
if (passses & 256)
@@ -29282,6 +29433,9 @@ struct EnzymeHLOOptPass
2928229433
ConjComplexSimplify,
2928329434
SplitConvolutionIntoReverseConvolution,
2928429435
ScatterMultiplySimplify,
29436+
ScatterDivSimplify,
29437+
ScatterAddSimplify,
29438+
ScatterSubSimplify,
2928529439
UnaryElementwiseScatterSimplify,
2928629440
GatherElementwise,
2928729441
ElementwisePad,

0 commit comments

Comments
 (0)