Skip to content

Commit 0e5579b

Browse files
committed
feat: expand mul scatter simplify to support setindex ones
1 parent 418f80e commit 0e5579b

File tree

8 files changed

+202
-55
lines changed

8 files changed

+202
-55
lines changed

src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp

Lines changed: 96 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -22704,36 +22704,107 @@ struct ScatterMultiplySimplify final
2270422704
}
2270522705
}
2270622706

22707-
auto status =
22708-
detectConstantSetindexScatterOp(scatterOp, /*allowedMultipleUses*/
22709-
false,
22710-
/*onlyConstantZerosAllowed*/
22711-
true, nullptr);
22712-
if (!status.ok())
22707+
if (!scatterOp.getUniqueIndices()) {
22708+
return failure();
22709+
}
22710+
22711+
bool isAllZeros = false, isAllOnes = false;
22712+
22713+
SplatElementsAttr constSetIndexValue = nullptr;
22714+
auto status = detectConstantSetindexScatterOp(
22715+
scatterOp, /*allowedMultipleUses*/ false,
22716+
[&isAllZeros, &isAllOnes](mlir::Value input) {
22717+
isAllZeros = matchPattern(input, m_AnyZeroFloat()) ||
22718+
matchPattern(input, m_Zero());
22719+
if (!isAllZeros) {
22720+
isAllOnes = matchPattern(input, m_OneFloat()) ||
22721+
matchPattern(input, m_One());
22722+
}
22723+
return isAllZeros || isAllOnes;
22724+
},
22725+
constSetIndexValue);
22726+
if (!status.ok()) {
2271322727
return rewriter.notifyMatchFailure(op, status.message());
22728+
}
2271422729

2271522730
SmallVector<int64_t> sliceSizes = computeGatherSliceSizes(scatterOp);
2271622731

22717-
auto gatheredValues = stablehlo::GatherOp::create(
22718-
rewriter, op.getLoc(), otherValue, scatterOp.getScatterIndices(),
22719-
getGatherDims(rewriter.getContext(),
22720-
scatterOp.getScatterDimensionNumbers()),
22721-
rewriter.getDenseI64ArrayAttr(sliceSizes),
22722-
scatterOp.getIndicesAreSortedAttr());
22732+
if (isAllZeros) { // non scattered values before zeros
22733+
auto gatheredValues = stablehlo::GatherOp::create(
22734+
rewriter, op.getLoc(), otherValue, scatterOp.getScatterIndices(),
22735+
getGatherDims(rewriter.getContext(),
22736+
scatterOp.getScatterDimensionNumbers()),
22737+
rewriter.getDenseI64ArrayAttr(sliceSizes),
22738+
scatterOp.getIndicesAreSortedAttr());
22739+
22740+
Value mulRhs;
22741+
if (constSetIndexValue) {
22742+
mulRhs = stablehlo::ConstantOp::create(
22743+
rewriter, op.getLoc(),
22744+
constSetIndexValue.resizeSplat(
22745+
cast<ShapedType>(gatheredValues.getType())));
22746+
} else {
22747+
mulRhs = scatterOp.getUpdates()[0];
22748+
}
22749+
auto newUpdates =
22750+
stablehlo::MulOp::create(rewriter, op.getLoc(), gatheredValues, mulRhs);
2272322751

22724-
auto newUpdates = stablehlo::MulOp::create(
22725-
rewriter, op.getLoc(), gatheredValues, scatterOp.getUpdates()[0]);
22752+
auto newScatterOp = stablehlo::ScatterOp::create(
22753+
rewriter, op.getLoc(), scatterOp.getResultTypes(),
22754+
scatterOp.getInputs(), scatterOp.getScatterIndices(),
22755+
ValueRange(newUpdates), scatterOp.getScatterDimensionNumbersAttr(),
22756+
scatterOp.getIndicesAreSortedAttr(),
22757+
scatterOp.getUniqueIndicesAttr());
2272622758

22727-
auto newScatterOp = stablehlo::ScatterOp::create(
22728-
rewriter, op.getLoc(), scatterOp.getResultTypes(),
22729-
scatterOp.getInputs(), scatterOp.getScatterIndices(),
22730-
ValueRange(newUpdates), scatterOp.getScatterDimensionNumbersAttr(),
22731-
scatterOp.getIndicesAreSortedAttr(), scatterOp.getUniqueIndicesAttr());
22732-
newScatterOp.getUpdateComputation().takeBody(
22733-
scatterOp.getUpdateComputation());
22734-
rewriter.replaceOp(op, newScatterOp);
22759+
auto &updateRegion = newScatterOp.getUpdateComputation();
22760+
auto *block = rewriter.createBlock(&updateRegion);
22761+
auto elemType = cast<RankedTensorType>(scatterOp.getResultTypes()[0])
22762+
.getElementType();
22763+
auto argType = RankedTensorType::get({}, elemType);
22764+
block->addArgument(argType, op.getLoc());
22765+
block->addArgument(argType, op.getLoc());
22766+
rewriter.setInsertionPointToStart(block);
22767+
stablehlo::ReturnOp::create(rewriter, op.getLoc(), block->getArgument(1));
2273522768

22736-
return success();
22769+
rewriter.replaceOp(op, newScatterOp);
22770+
return success();
22771+
}
22772+
22773+
if (isAllOnes) { // non-scattered values stay as is
22774+
auto newScatterOp = stablehlo::ScatterOp::create(
22775+
rewriter, op.getLoc(), scatterOp.getResultTypes(),
22776+
ValueRange(otherValue), scatterOp.getScatterIndices(),
22777+
scatterOp.getUpdates(), scatterOp.getScatterDimensionNumbersAttr(),
22778+
scatterOp.getIndicesAreSortedAttr(),
22779+
scatterOp.getUniqueIndicesAttr());
22780+
22781+
auto &updateRegion = newScatterOp.getUpdateComputation();
22782+
auto *block = rewriter.createBlock(&updateRegion);
22783+
auto elemType = cast<RankedTensorType>(scatterOp.getResultTypes()[0])
22784+
.getElementType();
22785+
auto argType = RankedTensorType::get({}, elemType);
22786+
block->addArgument(argType, op.getLoc());
22787+
block->addArgument(argType, op.getLoc());
22788+
rewriter.setInsertionPointToStart(block);
22789+
22790+
Value mulRhs;
22791+
if (constSetIndexValue) {
22792+
mulRhs = stablehlo::ConstantOp::create(
22793+
rewriter, op.getLoc(),
22794+
constSetIndexValue.resizeSplat(
22795+
RankedTensorType::get({}, elemType)));
22796+
} else {
22797+
mulRhs = block->getArgument(1);
22798+
}
22799+
auto mulOp = stablehlo::MulOp::create(rewriter, op.getLoc(),
22800+
block->getArgument(0), mulRhs);
22801+
stablehlo::ReturnOp::create(rewriter, op.getLoc(), mulOp.getResult());
22802+
22803+
rewriter.replaceOp(op, newScatterOp);
22804+
return success();
22805+
}
22806+
22807+
return failure();
2273722808
}
2273822809
};
2273922810

@@ -22807,10 +22878,9 @@ struct UnaryElementwiseScatterSimplify final
2280722878
if (!scatterOp)
2280822879
return rewriter.notifyMatchFailure(op, "not a scatter op");
2280922880

22810-
DenseElementsAttr scatterInputAttr;
2281122881
auto status = detectConstantSetindexScatterOp(
22812-
scatterOp, false, /*onlyConstantZerosAllowed*/ false,
22813-
&scatterInputAttr);
22882+
scatterOp, false,
22883+
[](mlir::Value input) { return matchPattern(input, m_Constant()); });
2281422884
if (!status.ok()) {
2281522885
return rewriter.notifyMatchFailure(op, status.message());
2281622886
}

src/enzyme_ad/jax/Passes/StructuredTensors.cpp

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,16 @@ namespace enzyme {
1212

1313
absl::Status detectConstantSetindexScatterOp(stablehlo::ScatterOp scatterOp,
1414
bool allowedMultipleUses,
15-
bool onlyConstantZerosAllowed,
16-
DenseElementsAttr *constAttr) {
15+
InputValidatorFn inputValidator) {
16+
SplatElementsAttr constSetIndexValue = nullptr;
17+
return detectConstantSetindexScatterOp(scatterOp, allowedMultipleUses,
18+
inputValidator, constSetIndexValue);
19+
}
20+
21+
absl::Status detectConstantSetindexScatterOp(stablehlo::ScatterOp scatterOp,
22+
bool allowedMultipleUses,
23+
InputValidatorFn inputValidator,
24+
SplatElementsAttr &constSetIndexValue) {
1725
if (scatterOp.getInputs().size() != 1) {
1826
return absl::UnimplementedError(
1927
"Detection not implemented for scatter op with >1 input.");
@@ -26,20 +34,18 @@ absl::Status detectConstantSetindexScatterOp(stablehlo::ScatterOp scatterOp,
2634

2735
auto checkCommonScatterOp = mlir::stablehlo::CheckCommonScatterOp(scatterOp);
2836

29-
if (!checkCommonScatterOp.isSetindexScatter) {
37+
if (!checkCommonScatterOp.isSetindexScatter &&
38+
!checkCommonScatterOp.isConstantSetindexScatter) {
3039
return absl::InvalidArgumentError("ScatterOp is not a setindex op.");
3140
}
3241

42+
if (checkCommonScatterOp.isConstantSetindexScatter) {
43+
constSetIndexValue = checkCommonScatterOp.constant;
44+
}
45+
3346
auto input = scatterOp.getInputs()[0];
34-
if (onlyConstantZerosAllowed) {
35-
if (matchPattern(input, m_AnyZeroFloat()) ||
36-
matchPattern(input, m_Zero())) {
37-
return absl::OkStatus();
38-
}
39-
} else {
40-
if (matchPattern(input, m_Constant(constAttr))) {
41-
return absl::OkStatus();
42-
}
47+
if (inputValidator(input)) {
48+
return absl::OkStatus();
4349
}
4450

4551
return absl::InvalidArgumentError(
@@ -49,7 +55,11 @@ absl::Status detectConstantSetindexScatterOp(stablehlo::ScatterOp scatterOp,
4955
// TODO: detect batched diagonal tensors
5056
absl::Status detectDiagonalTensor(stablehlo::ScatterOp scatterOp,
5157
mlir::Value *outUpdates) {
52-
auto status = detectConstantSetindexScatterOp(scatterOp, true, true, nullptr);
58+
auto status =
59+
detectConstantSetindexScatterOp(scatterOp, true, [](mlir::Value input) {
60+
return matchPattern(input, m_AnyZeroFloat()) ||
61+
matchPattern(input, m_Zero());
62+
});
5363
if (!status.ok())
5464
return status;
5565

src/enzyme_ad/jax/Passes/StructuredTensors.h

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,23 @@
77

88
#include "llvm/ADT/SetVector.h"
99

10+
#include <functional>
1011
#include <optional>
1112

1213
namespace mlir {
1314
namespace enzyme {
1415

16+
using InputValidatorFn = std::function<bool(mlir::Value)>;
17+
18+
19+
absl::Status detectConstantSetindexScatterOp(stablehlo::ScatterOp scatterOp,
20+
bool allowedMultipleUses,
21+
InputValidatorFn inputValidator,
22+
SplatElementsAttr &constSetIndexValue);
23+
1524
absl::Status detectConstantSetindexScatterOp(stablehlo::ScatterOp scatterOp,
1625
bool allowedMultipleUses,
17-
bool onlyConstantZerosAllowed,
18-
DenseElementsAttr *constAttr);
26+
InputValidatorFn inputValidator);
1927

2028
absl::Status detectDiagonalTensor(stablehlo::ScatterOp scatterOp,
2129
mlir::Value *outUpdates);

src/enzyme_ad/jax/Utils.cpp

Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1422,26 +1422,44 @@ getGatherDims(mlir::MLIRContext *ctx,
14221422
scatterDimNumbers.getIndexVectorDim());
14231423
}
14241424

1425-
bool isSetindexBlock(mlir::Block *block) {
1426-
if (block->getNumArguments() != 2)
1425+
bool isSetindexBlockHelper(
1426+
mlir::Block *block,
1427+
std::function<bool(stablehlo::ReturnOp retOp, Value updateValue)> fn) {
1428+
if (block->getNumArguments() != 2) {
14271429
return false;
1428-
1429-
auto updateValue = block->getArgument(1);
1430+
}
14301431

14311432
// The block should have exactly one operation (the return)
1432-
if (block->getOperations().size() != 1)
1433+
if (block->getOperations().size() != 1) {
14331434
return false;
1435+
}
14341436

14351437
auto &returnOp = block->front();
14361438
auto stablehloReturnOp = dyn_cast<stablehlo::ReturnOp>(returnOp);
1437-
if (!stablehloReturnOp)
1439+
if (!stablehloReturnOp) {
14381440
return false;
1441+
}
14391442

1440-
if (stablehloReturnOp.getNumOperands() != 1)
1443+
if (stablehloReturnOp.getNumOperands() != 1) {
14411444
return false;
1445+
}
1446+
1447+
return fn(stablehloReturnOp, block->getArgument(1));
1448+
}
14421449

1443-
// The returned value should be the update value (second argument)
1444-
return stablehloReturnOp.getOperand(0) == updateValue;
1450+
bool isSetindexBlock(mlir::Block *block) {
1451+
return isSetindexBlockHelper(
1452+
block, [](stablehlo::ReturnOp retOp, Value updateValue) {
1453+
return retOp.getOperand(0) == updateValue;
1454+
});
1455+
}
1456+
1457+
bool isConstantSetindexBlock(mlir::Block *block,
1458+
mlir::SplatElementsAttr &constant) {
1459+
return isSetindexBlockHelper(
1460+
block, [&constant](stablehlo::ReturnOp retOp, Value updateValue) {
1461+
return matchPattern(retOp.getOperand(0), m_Constant(&constant));
1462+
});
14451463
}
14461464

14471465
SmallVector<int64_t> computeGatherSliceSizes(stablehlo::ScatterOp &scatterOp) {

src/enzyme_ad/jax/Utils.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -942,6 +942,8 @@ getGatherDims(mlir::MLIRContext *ctx,
942942
stablehlo::ScatterDimensionNumbersAttr scatterDimNumbers);
943943

944944
bool isSetindexBlock(mlir::Block *block);
945+
bool isConstantSetindexBlock(mlir::Block *block,
946+
mlir::SplatElementsAttr &constant);
945947

946948
// rhs is only considered if commutative is false
947949
template <typename T, bool commutative, bool rhs>
@@ -1067,6 +1069,8 @@ struct CheckCommonReduceOp {
10671069
struct CheckCommonScatterOp {
10681070
public:
10691071
bool isSetindexScatter;
1072+
bool isConstantSetindexScatter;
1073+
10701074
bool isAddScatter;
10711075
bool isMinScatter;
10721076
bool isMaxScatter;
@@ -1087,6 +1091,7 @@ struct CheckCommonScatterOp {
10871091

10881092
if (!updateComputation.hasOneBlock()) {
10891093
isSetindexScatter = false;
1094+
isConstantSetindexScatter = false;
10901095
isAddScatter = false;
10911096
isMinScatter = false;
10921097
isMaxScatter = false;
@@ -1105,6 +1110,7 @@ struct CheckCommonScatterOp {
11051110

11061111
auto &block = updateComputation.front();
11071112
isSetindexScatter = isSetindexBlock(&block);
1113+
isConstantSetindexScatter = isConstantSetindexBlock(&block, constant);
11081114
isAddScatter = isOnlyOpBlock<stablehlo::AddOp, true, false>(&block);
11091115
isMulScatter = isOnlyOpBlock<stablehlo::MulOp, true, false>(&block);
11101116
isMinScatter = isOnlyOpBlock<stablehlo::MinOp, true, false>(&block);

test/lit_tests/mulscatter.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ func.func @main(%arg0: tensor<4xi64>, %arg1: tensor<6xi64>, %arg2: tensor<1024x1
1111
%4 = stablehlo.reshape %2 : (tensor<6x4xi64>) -> tensor<24x1xi64>
1212
%5 = stablehlo.concatenate %3, %4, dim = 1 : (tensor<24x1xi64>, tensor<24x1xi64>) -> tensor<24x2xi64>
1313
%6 = stablehlo.subtract %5, %c : tensor<24x2xi64>
14-
%7 = "stablehlo.scatter"(%cst_0, %6, %cst) <{scatter_dimension_numbers = #stablehlo.scatter<inserted_window_dims = [0, 1], scatter_dims_to_operand_dims = [0, 1], index_vector_dim = 1>}> ({
14+
%7 = "stablehlo.scatter"(%cst_0, %6, %cst) <{scatter_dimension_numbers = #stablehlo.scatter<inserted_window_dims = [0, 1], scatter_dims_to_operand_dims = [0, 1], index_vector_dim = 1>, unique_indices = true}> ({
1515
^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
1616
stablehlo.return %arg4 : tensor<f32>
1717
}) : (tensor<1024x1024xf32>, tensor<24x2xi64>, tensor<24xf32>) -> tensor<1024x1024xf32>
@@ -27,7 +27,7 @@ func.func @main(%arg0: tensor<4xi64>, %arg1: tensor<6xi64>, %arg2: tensor<1024x1
2727
// CHECK: %[[arg2_T:.*]] = stablehlo.transpose %arg2, dims = [1, 0] : (tensor<1024x1024xf32>) -> tensor<1024x1024xf32>
2828
// CHECK: %[[GATHER:.*]] = "stablehlo.gather"(%[[arg2_T]], %[[SCATTER_INDICES:.*]]) <{dimension_numbers = #stablehlo.gather<collapsed_slice_dims = [0, 1], start_index_map = [0, 1], index_vector_dim = 1>, slice_sizes = array<i64: 1, 1>}> : (tensor<1024x1024xf32>, tensor<24x2xi64>) -> tensor<24xf32>
2929
// CHECK: %[[MUL:.*]] = stablehlo.multiply %[[GATHER]], %[[CST]] : tensor<24xf32>
30-
// CHECK: %[[SCATTER:.*]] = "stablehlo.scatter"(%[[CST_1]], %[[SCATTER_INDICES]], %[[MUL]]) <{scatter_dimension_numbers = #stablehlo.scatter<inserted_window_dims = [0, 1], scatter_dims_to_operand_dims = [0, 1], index_vector_dim = 1>}> ({
30+
// CHECK: %[[SCATTER:.*]] = "stablehlo.scatter"(%[[CST_1]], %[[SCATTER_INDICES]], %[[MUL]]) <{scatter_dimension_numbers = #stablehlo.scatter<inserted_window_dims = [0, 1], scatter_dims_to_operand_dims = [0, 1], index_vector_dim = 1>, unique_indices = true}> ({
3131
// CHECK: %[[RESULT:.*]] = stablehlo.transpose %[[SCATTER]], dims = [1, 0] : (tensor<1024x1024xf32>) -> tensor<1024x1024xf32>
3232
// CHECK: return %[[RESULT]] : tensor<1024x1024xf32>
3333
// CHECK: }

test/lit_tests/mulscatterones.mlir

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
// RUN: enzymexlamlir-opt %s --enzyme-hlo-opt | FileCheck %s
2+
3+
func.func @main(%arg0: tensor<4xi64>, %arg1: tensor<6xi64>, %arg2: tensor<1024x1024xf32>) -> tensor<1024x1024xf32> {
4+
%cst = stablehlo.constant dense<2.000000e+00> : tensor<24xf32>
5+
%c = stablehlo.constant dense<1> : tensor<24x2xi64>
6+
%cst_0 = stablehlo.constant dense<1.000000e+00> : tensor<1024x1024xf32>
7+
%0 = stablehlo.transpose %arg2, dims = [1, 0] : (tensor<1024x1024xf32>) -> tensor<1024x1024xf32>
8+
%1 = stablehlo.concatenate %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, dim = 0 : (tensor<4xi64>, tensor<4xi64>, tensor<4xi64>, tensor<4xi64>, tensor<4xi64>, tensor<4xi64>) -> tensor<24xi64>
9+
%2 = stablehlo.broadcast_in_dim %arg1, dims = [0] : (tensor<6xi64>) -> tensor<6x4xi64>
10+
%3 = stablehlo.reshape %1 : (tensor<24xi64>) -> tensor<24x1xi64>
11+
%4 = stablehlo.reshape %2 : (tensor<6x4xi64>) -> tensor<24x1xi64>
12+
%5 = stablehlo.concatenate %3, %4, dim = 1 : (tensor<24x1xi64>, tensor<24x1xi64>) -> tensor<24x2xi64>
13+
%6 = stablehlo.subtract %5, %c : tensor<24x2xi64>
14+
%7 = "stablehlo.scatter"(%cst_0, %6, %cst) <{scatter_dimension_numbers = #stablehlo.scatter<inserted_window_dims = [0, 1], scatter_dims_to_operand_dims = [0, 1], index_vector_dim = 1>, unique_indices = true}> ({
15+
^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
16+
stablehlo.return %arg4 : tensor<f32>
17+
}) : (tensor<1024x1024xf32>, tensor<24x2xi64>, tensor<24xf32>) -> tensor<1024x1024xf32>
18+
%8 = stablehlo.multiply %7, %0 : tensor<1024x1024xf32>
19+
%9 = stablehlo.transpose %8, dims = [1, 0] : (tensor<1024x1024xf32>) -> tensor<1024x1024xf32>
20+
return %9 : tensor<1024x1024xf32>
21+
}
22+
23+
// CHECK: func.func @main(%arg0: tensor<4xi64>, %arg1: tensor<6xi64>, %arg2: tensor<1024x1024xf32>) -> tensor<1024x1024xf32> {
24+
// CHECK: %[[S_CST:.*]] = stablehlo.constant dense<2.000000e+00> : tensor<f32>
25+
// CHECK: %[[CST:.*]] = stablehlo.constant dense<2.000000e+00> : tensor<24xf32>
26+
// CHECK: %[[CST_0:.*]] = stablehlo.constant dense<1> : tensor<24x2xi64>
27+
// CHECK: %[[arg2_T:.*]] = stablehlo.transpose %arg2, dims = [1, 0] : (tensor<1024x1024xf32>) -> tensor<1024x1024xf32>
28+
// CHECK: %[[SCATTER:.*]] = "stablehlo.scatter"(%[[arg2_T]], %[[SCATTER_INDICES:.*]], %[[CST]]) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<inserted_window_dims = [0, 1], scatter_dims_to_operand_dims = [0, 1], index_vector_dim = 1>, unique_indices = true}> ({
29+
// CHECK: ^bb0(%[[ARG3:.*]]: tensor<f32>, %[[ARG4:.*]]: tensor<f32>):
30+
// CHECK: %[[MUL:.*]] = stablehlo.multiply %[[ARG3]], %[[S_CST]] : tensor<f32>
31+
// CHECK: stablehlo.return %[[MUL]] : tensor<f32>
32+
// CHECK: })
33+
// CHECK: %[[RESULT:.*]] = stablehlo.transpose %[[SCATTER]], dims = [1, 0] : (tensor<1024x1024xf32>) -> tensor<1024x1024xf32>
34+
// CHECK: return %[[RESULT]] : tensor<1024x1024xf32>
35+
// CHECK: }

0 commit comments

Comments
 (0)