Skip to content

Commit 756cf4c

Browse files
authored
feat: more scatter gather optimization patterns (#1033)
* feat: simplify mul of scatter
* chore: run fmt
* feat: gather constant propagation
* feat: unary elementwise scatter simplify
1 parent f7a8d0f commit 756cf4c

File tree

5 files changed

+454
-1
lines changed

5 files changed

+454
-1
lines changed

src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp

Lines changed: 277 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19411,6 +19411,278 @@ struct SplitConvolutionIntoReverseConvolution final
1941119411
}
1941219412
};
1941319413

19414+
// Builds the gather dimension numbers that mirror a scatter's dimension
// numbers, so a gather can read exactly the positions a scatter writes:
//   update_window_dims          -> offset_dims
//   inserted_window_dims        -> collapsed_slice_dims
//   input_batching_dims         -> operand_batching_dims
//   scatter_indices_batching_dims -> start_indices_batching_dims
//   scatter_dims_to_operand_dims  -> start_index_map
stablehlo::GatherDimensionNumbersAttr
getGatherDims(mlir::MLIRContext *ctx,
              stablehlo::ScatterDimensionNumbersAttr scatterDims) {
  auto offsetDims = scatterDims.getUpdateWindowDims();
  auto collapsedSliceDims = scatterDims.getInsertedWindowDims();
  auto operandBatchingDims = scatterDims.getInputBatchingDims();
  auto startIndicesBatchingDims = scatterDims.getScatterIndicesBatchingDims();
  auto startIndexMap = scatterDims.getScatterDimsToOperandDims();
  return stablehlo::GatherDimensionNumbersAttr::get(
      ctx, offsetDims, collapsedSliceDims, operandBatchingDims,
      startIndicesBatchingDims, startIndexMap,
      scatterDims.getIndexVectorDim());
}
// Returns true if the scatter's update computation implements pure
// "set index" semantics, i.e. its region is a single block of the form
//   ^bb0(%old, %update): stablehlo.return %update
// that discards the original value and returns the update unchanged.
bool isScatterSetindexOp(stablehlo::ScatterOp &scatterOp) {
  auto &updateComputation = scatterOp.getUpdateComputation();

  if (!updateComputation.hasOneBlock())
    return false;

  auto &block = updateComputation.front();
  if (block.getNumArguments() != 2)
    return false;

  // The block should have exactly one operation (the return).
  if (block.getOperations().size() != 1)
    return false;

  auto &returnOp = block.front();
  auto stablehloReturn = dyn_cast<stablehlo::ReturnOp>(returnOp);
  if (!stablehloReturn)
    return false;

  if (stablehloReturn.getNumOperands() != 1)
    return false;

  // The returned value should be the update value (second block argument).
  return stablehloReturn.getOperand(0) == block.getArgument(1);
}
// Computes the gather slice_sizes needed to read, per scatter index, a window
// of the same shape as the scatter's updates. Every operand dimension
// defaults to a slice size of 1; dimensions named by update_window_dims take
// their extent from the corresponding update-tensor dimension.
SmallVector<int64_t> computeGatherSliceSizes(stablehlo::ScatterOp &scatterOp) {
  auto inputType = cast<ShapedType>(scatterOp.getInputs()[0].getType());
  auto updateType = cast<ShapedType>(scatterOp.getUpdates()[0].getType());
  auto scatterDimNumbers = scatterOp.getScatterDimensionNumbers();

  auto inputShape = inputType.getShape();
  auto updateShape = updateType.getShape();

  SmallVector<int64_t> sliceSizes(inputShape.size(), 1);

  auto updateWindowDims = scatterDimNumbers.getUpdateWindowDims();
  auto scatterIndicesBatchingDims =
      scatterDimNumbers.getScatterIndicesBatchingDims();

  // Map update window dimensions to their corresponding input dimensions.
  // NOTE(review): this assumes the update tensor is laid out as
  // [scatter_indices_batching_dims..., update_window_dims...] -- confirm
  // against the StableHLO scatter spec for non-trivial dimension numbers.
  for (size_t i = 0; i < updateWindowDims.size(); ++i) {
    int64_t inputDim = updateWindowDims[i];

    // Position of this window dim within the update tensor.
    size_t updateDimIndex = scatterIndicesBatchingDims.size() + i;

    if (updateDimIndex < updateShape.size()) {
      sliceSizes[inputDim] = updateShape[updateDimIndex];
    }
  }

  return sliceSizes;
}
struct ScatterMultiplySimplify final
19487+
: public CheckedOpRewritePattern<stablehlo::MulOp,
19488+
ScatterMultiplySimplify> {
19489+
using CheckedOpRewritePattern<
19490+
stablehlo::MulOp, ScatterMultiplySimplify>::CheckedOpRewritePattern;
19491+
19492+
LogicalResult matchAndRewriteImpl(stablehlo::MulOp op,
19493+
PatternRewriter &rewriter) const {
19494+
auto lhs = op.getLhs();
19495+
auto rhs = op.getRhs();
19496+
19497+
stablehlo::ScatterOp scatterOp;
19498+
mlir::Value otherValue;
19499+
19500+
auto lhsScatterOp = lhs.getDefiningOp<stablehlo::ScatterOp>();
19501+
auto rhsScatterOp = rhs.getDefiningOp<stablehlo::ScatterOp>();
19502+
if (!lhsScatterOp && !rhsScatterOp) {
19503+
return failure();
19504+
} else {
19505+
if (lhsScatterOp) {
19506+
scatterOp = lhsScatterOp;
19507+
otherValue = rhs;
19508+
} else {
19509+
scatterOp = rhsScatterOp;
19510+
otherValue = lhs;
19511+
}
19512+
}
19513+
19514+
if (scatterOp.getInputs().size() != 1)
19515+
return rewriter.notifyMatchFailure(
19516+
op, "ScatterOp with more than one input not supported");
19517+
19518+
auto input = scatterOp.getInputs()[0];
19519+
if (!matchPattern(input, m_AnyZeroFloat()) &&
19520+
!matchPattern(input, m_Zero()))
19521+
return rewriter.notifyMatchFailure(op, "ScatterOp with non-zero input");
19522+
19523+
if (!scatterOp.getResult(0).hasOneUse())
19524+
return rewriter.notifyMatchFailure(op, "ScatterOp with multiple uses");
19525+
19526+
if (!isScatterSetindexOp(scatterOp))
19527+
return rewriter.notifyMatchFailure(op, "ScatterOp with non-setindex");
19528+
19529+
auto scatterDimNumbers = scatterOp.getScatterDimensionNumbers();
19530+
19531+
SmallVector<int64_t> sliceSizes = computeGatherSliceSizes(scatterOp);
19532+
19533+
auto gatheredValues = rewriter.create<stablehlo::GatherOp>(
19534+
op.getLoc(), otherValue, scatterOp.getScatterIndices(),
19535+
getGatherDims(rewriter.getContext(), scatterDimNumbers),
19536+
rewriter.getDenseI64ArrayAttr(sliceSizes),
19537+
scatterOp.getIndicesAreSortedAttr());
19538+
19539+
auto newUpdates = rewriter.create<stablehlo::MulOp>(
19540+
op.getLoc(), gatheredValues, scatterOp.getUpdates()[0]);
19541+
19542+
auto newScatterOp = rewriter.create<stablehlo::ScatterOp>(
19543+
op.getLoc(), scatterOp.getResultTypes(), scatterOp.getInputs(),
19544+
scatterOp.getScatterIndices(), ValueRange(newUpdates),
19545+
scatterOp.getScatterDimensionNumbersAttr(),
19546+
scatterOp.getIndicesAreSortedAttr(), scatterOp.getUniqueIndicesAttr());
19547+
newScatterOp.getUpdateComputation().takeBody(
19548+
scatterOp.getUpdateComputation());
19549+
rewriter.replaceOp(op, newScatterOp);
19550+
19551+
return success();
19552+
}
19553+
};
19554+
19555+
struct GatherConstProp final
19556+
: public CheckedOpRewritePattern<stablehlo::GatherOp, GatherConstProp> {
19557+
using CheckedOpRewritePattern<stablehlo::GatherOp,
19558+
GatherConstProp>::CheckedOpRewritePattern;
19559+
19560+
LogicalResult matchAndRewriteImpl(stablehlo::GatherOp op,
19561+
PatternRewriter &rewriter) const {
19562+
DenseElementsAttr operandAttr;
19563+
if (!matchPattern(op.getOperand(), m_Constant(&operandAttr)))
19564+
return rewriter.notifyMatchFailure(op,
19565+
"GatherOp with non-constant input");
19566+
19567+
if (operandAttr.isSplat()) {
19568+
// In this case the indices don't matter and we can construct a new
19569+
// splatted result
19570+
rewriter.replaceOpWithNewOp<stablehlo::ConstantOp>(
19571+
op, op.getType(), operandAttr.resizeSplat(op.getType()));
19572+
return success();
19573+
}
19574+
19575+
DenseElementsAttr startIndicesAttr;
19576+
if (!matchPattern(op.getStartIndices(), m_Constant(&startIndicesAttr))) {
19577+
return rewriter.notifyMatchFailure(
19578+
op, "GatherOp with non-constant start indices and unsplatted input");
19579+
}
19580+
19581+
stablehlo::Tensor operandTensor = stablehlo::constantOp(operandAttr);
19582+
stablehlo::Tensor startIndicesTensor =
19583+
stablehlo::constantOp(startIndicesAttr);
19584+
auto gatherDims = op.getDimensionNumbers();
19585+
19586+
auto sliceSizes = op.getSliceSizes();
19587+
auto elementType = rewriter.getIntegerType(64);
19588+
auto attrType =
19589+
RankedTensorType::get({(int64_t)sliceSizes.size()}, elementType);
19590+
auto sliceSizesAttr = DenseElementsAttr::get(attrType, sliceSizes);
19591+
19592+
auto result = stablehlo::gatherOp(
19593+
operandTensor, startIndicesTensor,
19594+
stablehlo::Axes(gatherDims.getOffsetDims()),
19595+
stablehlo::Axes(gatherDims.getCollapsedSliceDims()),
19596+
stablehlo::Axes(gatherDims.getOperandBatchingDims()),
19597+
stablehlo::Axes(gatherDims.getStartIndicesBatchingDims()),
19598+
stablehlo::Axes(gatherDims.getStartIndexMap()),
19599+
stablehlo::Axis(gatherDims.getIndexVectorDim()),
19600+
stablehlo::makeSizes(stablehlo::constantOp(sliceSizesAttr)),
19601+
op.getIndicesAreSorted(), op.getType());
19602+
rewriter.replaceOpWithNewOp<stablehlo::ConstantOp>(op, op.getType(),
19603+
fromTensor(result));
19604+
return success();
19605+
}
19606+
};
19607+
19608+
struct UnaryElementwiseScatterSimplify final
19609+
: public CheckedOpTraitRewritePattern<OpTrait::Elementwise,
19610+
UnaryElementwiseScatterSimplify> {
19611+
using CheckedOpTraitRewritePattern<
19612+
OpTrait::Elementwise,
19613+
UnaryElementwiseScatterSimplify>::CheckedOpTraitRewritePattern;
19614+
19615+
LogicalResult matchAndRewriteImpl(Operation *op,
19616+
PatternRewriter &rewriter) const {
19617+
if (op->getNumOperands() != 1)
19618+
return rewriter.notifyMatchFailure(op, "not a unary elementwise op");
19619+
19620+
auto input = op->getOperand(0);
19621+
auto scatterOp = input.getDefiningOp<stablehlo::ScatterOp>();
19622+
if (!scatterOp)
19623+
return rewriter.notifyMatchFailure(op, "not a scatter op");
19624+
19625+
if (scatterOp.getInputs().size() != 1)
19626+
return rewriter.notifyMatchFailure(
19627+
op, "ScatterOp with more than one input not supported");
19628+
19629+
if (!scatterOp.getResult(0).hasOneUse())
19630+
return rewriter.notifyMatchFailure(op, "ScatterOp with multiple uses");
19631+
19632+
if (!isScatterSetindexOp(scatterOp))
19633+
return rewriter.notifyMatchFailure(op, "ScatterOp with non-setindex");
19634+
19635+
auto scatterInput = scatterOp.getInputs()[0];
19636+
DenseElementsAttr scatterInputAttr;
19637+
// In this case, we are will definitely increase the compute cost
19638+
if (!matchPattern(scatterInput, m_Constant(&scatterInputAttr)))
19639+
return rewriter.notifyMatchFailure(op,
19640+
"ScatterOp with non-constant input");
19641+
19642+
auto elemType =
19643+
cast<RankedTensorType>(op->getResult(0).getType()).getElementType();
19644+
19645+
// TODO: support convert op. we need to rewrite the update computation to
19646+
// take the converted element type
19647+
if (isa<stablehlo::ConvertOp>(op))
19648+
return rewriter.notifyMatchFailure(op,
19649+
"ConvertOp not supported for now.");
19650+
19651+
// should get constant propagated
19652+
auto scatterInputElem = rewriter.create(
19653+
op->getLoc(), op->getName().getIdentifier(), ValueRange(scatterInput),
19654+
TypeRange{RankedTensorType::get(
19655+
cast<RankedTensorType>(scatterInput.getType()).getShape(),
19656+
elemType)},
19657+
op->getAttrs(), {}, {});
19658+
19659+
auto scatterUpdatesElem = rewriter.create(
19660+
op->getLoc(), op->getName().getIdentifier(),
19661+
ValueRange(scatterOp.getUpdates()),
19662+
TypeRange{RankedTensorType::get(
19663+
cast<RankedTensorType>(scatterOp.getUpdates()[0].getType())
19664+
.getShape(),
19665+
elemType)},
19666+
op->getAttrs(), {}, {});
19667+
19668+
auto resultType = RankedTensorType::get(
19669+
cast<RankedTensorType>(scatterOp.getResultTypes()[0]).getShape(),
19670+
elemType);
19671+
19672+
auto newScatterOp = rewriter.create<stablehlo::ScatterOp>(
19673+
op->getLoc(), TypeRange(resultType),
19674+
ValueRange(scatterInputElem->getResult(0)),
19675+
scatterOp.getScatterIndices(),
19676+
ValueRange(scatterUpdatesElem->getResult(0)),
19677+
scatterOp.getScatterDimensionNumbersAttr(),
19678+
scatterOp.getIndicesAreSortedAttr(), scatterOp.getUniqueIndicesAttr());
19679+
newScatterOp.getUpdateComputation().takeBody(
19680+
scatterOp.getUpdateComputation());
19681+
rewriter.replaceOp(op, newScatterOp->getResult(0));
19682+
return success();
19683+
}
19684+
};
19685+
1941419686
/////////////// End Imported from stablehlo
1941519687

1941619688
// clang-format off
@@ -19697,6 +19969,8 @@ struct EnzymeHLOOptPass
1969719969
BinaryConstProp<stablehlo::SubtractOp, stablehlo::subtractOp>,
1969819970
BinaryConstProp<stablehlo::XorOp, stablehlo::xorOp>>(context);
1969919971

19972+
patterns.add<GatherConstProp>(context);
19973+
1970019974
patterns.add<BinaryOpTransposeSimplify<stablehlo::AddOp>,
1970119975
BinaryOpTransposeSimplify<stablehlo::SubtractOp>,
1970219976
BinaryOpTransposeSimplify<stablehlo::MulOp>,
@@ -19929,7 +20203,9 @@ struct EnzymeHLOOptPass
1992920203
InvolutionSimplify<chlo::ConjOp>,
1993020204
RealConjSimplify,
1993120205
ConjComplexSimplify,
19932-
SplitConvolutionIntoReverseConvolution
20206+
SplitConvolutionIntoReverseConvolution,
20207+
ScatterMultiplySimplify,
20208+
UnaryElementwiseScatterSimplify
1993320209
>(context);
1993420210

1993520211
patterns.add<SumToReduceWindow<stablehlo::AddOp>,

src/enzyme_ad/jax/TransformOps/TransformOps.td

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2015,3 +2015,21 @@ def ApplySplitConvolutionIntoReverseConvolution : EnzymeHLOPatternOp<
20152015
"SplitConvolutionIntoReverseConvolution"
20162016
];
20172017
}
2018+
2019+
// Exposes ScatterMultiplySimplify to the transform dialect under the
// "scatter_multiply_simplify" pattern name.
def ApplyScatterMultiplySimplify : EnzymeHLOPatternOp<
    "scatter_multiply_simplify"> {
  let patterns = ["ScatterMultiplySimplify"];
}
// Exposes GatherConstProp to the transform dialect under the
// "gather_const_prop" pattern name.
def ApplyGatherConstProp : EnzymeHLOPatternOp<"gather_const_prop"> {
  let patterns = ["GatherConstProp"];
}
// Exposes UnaryElementwiseScatterSimplify to the transform dialect under the
// "unary_elementwise_scatter_simplify" pattern name.
def ApplyUnaryElementwiseScatterSimplify : EnzymeHLOPatternOp<
    "unary_elementwise_scatter_simplify"> {
  let patterns = ["UnaryElementwiseScatterSimplify"];
}
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
// RUN: enzymexlamlir-opt %s --enzyme-hlo-opt | FileCheck %s
2+
3+
// Gather reads from an all-zero splat constant, so the indices (built from
// dynamic arguments) are irrelevant: the gather/reshape chain should fold to
// a zero constant of the result shape, leaving only the multiply.
func.func @gather_constprop1(%arg0: tensor<4xi64>, %arg1: tensor<6xi64>, %arg2: tensor<6x4xf64>) -> tensor<6x4xf64> {
  %c = stablehlo.constant dense<1> : tensor<24x2xi64>
  %cst = stablehlo.constant dense<0.000000e+00> : tensor<1024x1024xf64>
  %0 = stablehlo.concatenate %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, dim = 0 : (tensor<4xi64>, tensor<4xi64>, tensor<4xi64>, tensor<4xi64>, tensor<4xi64>, tensor<4xi64>) -> tensor<24xi64>
  %1 = stablehlo.broadcast_in_dim %arg1, dims = [0] : (tensor<6xi64>) -> tensor<6x4xi64>
  %2 = stablehlo.reshape %0 : (tensor<24xi64>) -> tensor<24x1xi64>
  %3 = stablehlo.reshape %1 : (tensor<6x4xi64>) -> tensor<24x1xi64>
  %4 = stablehlo.concatenate %2, %3, dim = 1 : (tensor<24x1xi64>, tensor<24x1xi64>) -> tensor<24x2xi64>
  %5 = stablehlo.subtract %4, %c : tensor<24x2xi64>
  %6 = "stablehlo.gather"(%cst, %5) <{dimension_numbers = #stablehlo.gather<collapsed_slice_dims = [0, 1], start_index_map = [0, 1], index_vector_dim = 1>, indices_are_sorted = false, slice_sizes = array<i64: 1, 1>}> : (tensor<1024x1024xf64>, tensor<24x2xi64>) -> tensor<24xf64>
  %7 = stablehlo.reshape %6 : (tensor<24xf64>) -> tensor<6x4xf64>
  %8 = stablehlo.multiply %7, %arg2 : tensor<6x4xf64>
  return %8 : tensor<6x4xf64>
}

// CHECK: func.func @gather_constprop1(%arg0: tensor<4xi64>, %arg1: tensor<6xi64>, %arg2: tensor<6x4xf64>) -> tensor<6x4xf64> {
// CHECK-NEXT: %cst = stablehlo.constant dense<0.000000e+00> : tensor<6x4xf64>
// CHECK-NEXT: %0 = stablehlo.multiply %cst, %arg2 : tensor<6x4xf64>
// CHECK-NEXT: return %0 : tensor<6x4xf64>
// CHECK-NEXT: }
// Splat (all-ones) gather operand with constant indices: the splat fast path
// applies, and the gather plus reshape fold to a 1.0 splat of the result
// shape.
func.func @gather_constprop2(%arg0: tensor<6x4xf64>) -> tensor<6x4xf64> {
  %cst = stablehlo.constant dense<1.000000e+00> : tensor<1024x1024xf64>
  %c = stablehlo.constant dense<[[0, 0], [3, 0], [5, 0], [11, 0], [0, 31], [3, 31], [5, 31], [11, 31], [0, 32], [3, 32], [5, 32], [11, 32], [0, 34], [3, 34], [5, 34], [11, 34], [0, 11], [3, 11], [5, 11], [11, 11], [0, 110], [3, 110], [5, 110], [11, 110]]> : tensor<24x2xi64>
  %0 = "stablehlo.gather"(%cst, %c) <{dimension_numbers = #stablehlo.gather<collapsed_slice_dims = [0, 1], start_index_map = [0, 1], index_vector_dim = 1>, indices_are_sorted = false, slice_sizes = array<i64: 1, 1>}> : (tensor<1024x1024xf64>, tensor<24x2xi64>) -> tensor<24xf64>
  %1 = stablehlo.reshape %0 : (tensor<24xf64>) -> tensor<6x4xf64>
  return %1 : tensor<6x4xf64>
}

// CHECK: func.func @gather_constprop2(%arg0: tensor<6x4xf64>) -> tensor<6x4xf64> {
// CHECK-NEXT: %cst = stablehlo.constant dense<1.000000e+00> : tensor<6x4xf64>
// CHECK-NEXT: return %cst : tensor<6x4xf64>
// CHECK-NEXT: }
// Non-splat f32 constant operand with constant indices: the gather is
// evaluated by the reference interpreter and folded, together with the
// convert and reshape, into a dense f64 constant; only the multiply remains.
func.func @gather_constprop3(%arg0: tensor<6x4xf64>) -> tensor<6x4xf64> {
  %cst = stablehlo.constant dense<"0x0000803F000050410000C841000014420000444200007442000092420000AA420000C2420000DA420000F2420000054300000040000060410000D041000018420000484200007842000094420000AC420000C4420000DC420000F4420000064300004040000070410000D84100001C4200004C4200007C42000096420000AE420000C6420000DE420000F6420000074300008040000080410000E041000020420000504200008042000098420000B0420000C8420000E0420000F842000008430000A040000088410000E84100002442000054420000824200009A420000B2420000CA420000E2420000FA42000009430000C040000090410000F04100002842000058420000844200009C420000B4420000CC420000E4420000FC4200000A430000E040000098410000F84100002C4200005C420000864200009E420000B6420000CE420000E6420000FE4200000B43000000410000A041000000420000304200006042000088420000A0420000B8420000D0420000E8420000004300000C43000010410000A84100000442000034420000644200008A420000A2420000BA420000D2420000EA420000014300000D43000020410000B04100000842000038420000684200008C420000A4420000BC420000D4420000EC420000024300000E43000030410000B84100000C4200003C4200006C4200008E420000A6420000BE420000D6420000EE420000034300000F43000040410000C041000010420000404200007042000090420000A8420000C0420000D8420000F0420000044300001043"> : tensor<12x12xf32>
  %c = stablehlo.constant dense<[[0, 0], [3, 0], [5, 0], [11, 0], [0, 1], [3, 1], [5, 1], [11, 1], [0, 2], [3, 2], [5, 2], [11, 2], [0, 4], [3, 4], [5, 4], [11, 4], [0, 11], [3, 11], [5, 11], [11, 11], [0, 10], [3, 10], [5, 10], [11, 10]]> : tensor<24x2xi64>
  %0 = "stablehlo.gather"(%cst, %c) <{dimension_numbers = #stablehlo.gather<collapsed_slice_dims = [0, 1], start_index_map = [0, 1], index_vector_dim = 1>, indices_are_sorted = false, slice_sizes = array<i64: 1, 1>}> : (tensor<12x12xf32>, tensor<24x2xi64>) -> tensor<24xf32>
  %1 = stablehlo.convert %0 : (tensor<24xf32>) -> tensor<24xf64>
  %2 = stablehlo.reshape %1 : (tensor<24xf64>) -> tensor<6x4xf64>
  %3 = stablehlo.multiply %2, %arg0 : tensor<6x4xf64>
  return %3 : tensor<6x4xf64>
}

// CHECK: func.func @gather_constprop3(%arg0: tensor<6x4xf64>) -> tensor<6x4xf64> {
// CHECK-NEXT{LITERAL}: %cst = stablehlo.constant dense<[[1.000000e+00, 4.000000e+00, 6.000000e+00, 1.200000e+01], [1.300000e+01, 1.600000e+01, 1.800000e+01, 2.400000e+01], [2.500000e+01, 2.800000e+01, 3.000000e+01, 3.600000e+01], [4.900000e+01, 5.200000e+01, 5.400000e+01, 6.000000e+01], [1.330000e+02, 1.360000e+02, 1.380000e+02, 1.440000e+02], [1.210000e+02, 1.240000e+02, 1.260000e+02, 1.320000e+02]]> : tensor<6x4xf64>
// CHECK-NEXT: %0 = stablehlo.multiply %cst, %arg0 : tensor<6x4xf64>
// CHECK-NEXT: return %0 : tensor<6x4xf64>
// CHECK-NEXT: }

0 commit comments

Comments
 (0)