@@ -22704,36 +22704,107 @@ struct ScatterMultiplySimplify final
2270422704 }
2270522705 }
2270622706
22707- auto status =
22708- detectConstantSetindexScatterOp(scatterOp, /*allowedMultipleUses*/
22709- false,
22710- /*onlyConstantZerosAllowed*/
22711- true, nullptr);
22712- if (!status.ok())
22707+ if (!scatterOp.getUniqueIndices()) {
22708+ return failure();
22709+ }
22710+
22711+ bool isAllZeros = false, isAllOnes = false;
22712+
22713+ SplatElementsAttr constSetIndexValue = nullptr;
22714+ auto status = detectConstantSetindexScatterOp(
22715+ scatterOp, /*allowedMultipleUses*/ false,
22716+ [&isAllZeros, &isAllOnes](mlir::Value input) {
22717+ isAllZeros = matchPattern(input, m_AnyZeroFloat()) ||
22718+ matchPattern(input, m_Zero());
22719+ if (!isAllZeros) {
22720+ isAllOnes = matchPattern(input, m_OneFloat()) ||
22721+ matchPattern(input, m_One());
22722+ }
22723+ return isAllZeros || isAllOnes;
22724+ },
22725+ constSetIndexValue);
22726+ if (!status.ok()) {
2271322727 return rewriter.notifyMatchFailure(op, status.message());
22728+ }
2271422729
2271522730 SmallVector<int64_t> sliceSizes = computeGatherSliceSizes(scatterOp);
2271622731
22717- auto gatheredValues = stablehlo::GatherOp::create(
22718- rewriter, op.getLoc(), otherValue, scatterOp.getScatterIndices(),
22719- getGatherDims(rewriter.getContext(),
22720- scatterOp.getScatterDimensionNumbers()),
22721- rewriter.getDenseI64ArrayAttr(sliceSizes),
22722- scatterOp.getIndicesAreSortedAttr());
22732+ if (isAllZeros) { // non-scattered values become zeros
22733+ auto gatheredValues = stablehlo::GatherOp::create(
22734+ rewriter, op.getLoc(), otherValue, scatterOp.getScatterIndices(),
22735+ getGatherDims(rewriter.getContext(),
22736+ scatterOp.getScatterDimensionNumbers()),
22737+ rewriter.getDenseI64ArrayAttr(sliceSizes),
22738+ scatterOp.getIndicesAreSortedAttr());
22739+
22740+ Value mulRhs;
22741+ if (constSetIndexValue) {
22742+ mulRhs = stablehlo::ConstantOp::create(
22743+ rewriter, op.getLoc(),
22744+ constSetIndexValue.resizeSplat(
22745+ cast<ShapedType>(gatheredValues.getType())));
22746+ } else {
22747+ mulRhs = scatterOp.getUpdates()[0];
22748+ }
22749+ auto newUpdates =
22750+ stablehlo::MulOpCreate(rewriter, op.getLoc(), gatheredValues, mulRhs);
2272322751
22724- auto newUpdates = stablehlo::MulOp::create(
22725- rewriter, op.getLoc(), gatheredValues, scatterOp.getUpdates()[0]);
22752+ auto newScatterOp = stablehlo::ScatterOp::create(
22753+ rewriter, op.getLoc(), scatterOp.getResultTypes(),
22754+ scatterOp.getInputs(), scatterOp.getScatterIndices(),
22755+ ValueRange(newUpdates), scatterOp.getScatterDimensionNumbersAttr(),
22756+ scatterOp.getIndicesAreSortedAttr(),
22757+ scatterOp.getUniqueIndicesAttr());
2272622758
22727- auto newScatterOp = stablehlo::ScatterOp::create(
22728- rewriter, op.getLoc(), scatterOp.getResultTypes(),
22729- scatterOp.getInputs(), scatterOp.getScatterIndices(),
22730- ValueRange(newUpdates), scatterOp.getScatterDimensionNumbersAttr(),
22731- scatterOp.getIndicesAreSortedAttr(), scatterOp.getUniqueIndicesAttr());
22732- newScatterOp.getUpdateComputation().takeBody(
22733- scatterOp.getUpdateComputation());
22734- rewriter.replaceOp(op, newScatterOp);
22759+ auto &updateRegion = newScatterOp.getUpdateComputation();
22760+ auto *block = rewriter.createBlock(&updateRegion);
22761+ auto elemType = cast<RankedTensorType>(scatterOp.getResultTypes()[0])
22762+ .getElementType();
22763+ auto argType = RankedTensorType::get({}, elemType);
22764+ block->addArgument(argType, op.getLoc());
22765+ block->addArgument(argType, op.getLoc());
22766+ rewriter.setInsertionPointToStart(block);
22767+ stablehlo::ReturnOp::create(rewriter, op.getLoc(), block->getArgument(1));
2273522768
22736- return success();
22769+ rewriter.replaceOp(op, newScatterOp);
22770+ return success();
22771+ }
22772+
22773+ if (isAllOnes) { // non-scattered values stay as is
22774+ auto newScatterOp = stablehlo::ScatterOp::create(
22775+ rewriter, op.getLoc(), scatterOp.getResultTypes(),
22776+ ValueRange(otherValue), scatterOp.getScatterIndices(),
22777+ scatterOp.getUpdates(), scatterOp.getScatterDimensionNumbersAttr(),
22778+ scatterOp.getIndicesAreSortedAttr(),
22779+ scatterOp.getUniqueIndicesAttr());
22780+
22781+ auto &updateRegion = newScatterOp.getUpdateComputation();
22782+ auto *block = rewriter.createBlock(&updateRegion);
22783+ auto elemType = cast<RankedTensorType>(scatterOp.getResultTypes()[0])
22784+ .getElementType();
22785+ auto argType = RankedTensorType::get({}, elemType);
22786+ block->addArgument(argType, op.getLoc());
22787+ block->addArgument(argType, op.getLoc());
22788+ rewriter.setInsertionPointToStart(block);
22789+
22790+ Value mulRhs;
22791+ if (constSetIndexValue) {
22792+ mulRhs = stablehlo::ConstantOp::create(
22793+ rewriter, op.getLoc(),
22794+ constSetIndexValue.resizeSplat(
22795+ RankedTensorType::get({}, elemType)));
22796+ } else {
22797+ mulRhs = block->getArgument(1);
22798+ }
22799+ auto mulOp = stablehlo::MulOp::create(rewriter, op.getLoc(),
22800+ block->getArgument(0), mulRhs);
22801+ stablehlo::ReturnOp::create(rewriter, op.getLoc(), mulOp.getResult());
22802+
22803+ rewriter.replaceOp(op, newScatterOp);
22804+ return success();
22805+ }
22806+
22807+ return failure();
2273722808 }
2273822809};
2273922810
@@ -22807,10 +22878,9 @@ struct UnaryElementwiseScatterSimplify final
2280722878 if (!scatterOp)
2280822879 return rewriter.notifyMatchFailure(op, "not a scatter op");
2280922880
22810- DenseElementsAttr scatterInputAttr;
2281122881 auto status = detectConstantSetindexScatterOp(
22812- scatterOp, false, /*onlyConstantZerosAllowed*/ false,
22813- &scatterInputAttr );
22882+ scatterOp, false,
22883+ [](mlir::Value input) { return matchPattern(input, m_Constant()); } );
2281422884 if (!status.ok()) {
2281522885 return rewriter.notifyMatchFailure(op, status.message());
2281622886 }
0 commit comments