@@ -5685,6 +5685,52 @@ struct UnaryConstProp final
56855685 }
56865686};
56875687
5688+ struct ClampConstProp final
5689+ : CheckedOpRewritePattern<stablehlo::ClampOp, ClampConstProp> {
5690+ using CheckedOpRewritePattern::CheckedOpRewritePattern;
5691+
5692+ LogicalResult matchAndRewriteImpl(stablehlo::ClampOp op,
5693+ PatternRewriter &rewriter) const {
5694+ DenseElementsAttr minAttr, inputAttr, maxAttr;
5695+ if (!matchPattern(op.getMin(), m_Constant(&minAttr)) ||
5696+ !matchPattern(op.getOperand(), m_Constant(&inputAttr)) ||
5697+ !matchPattern(op.getMax(), m_Constant(&maxAttr)))
5698+ return failure();
5699+
5700+ // TODO: for only min or max with input being constant we can convert this
5701+ // to a min/max op
5702+ stablehlo::Tensor minTen, maxTen, inputTen;
5703+ bool splattedVersion = false;
5704+ RankedTensorType ty = cast<RankedTensorType>(op->getResultTypes()[0]);
5705+ if (minAttr.isSplat() && maxAttr.isSplat() && inputAttr.isSplat()) {
5706+ splattedVersion = true;
5707+ ty = RankedTensorType::get(
5708+ {}, cast<ShapedType>(op->getResultTypes()[0]).getElementType());
5709+ auto inputTy = RankedTensorType::get(
5710+ {}, cast<ShapedType>(op->getOperand(0).getType()).getElementType());
5711+ minTen = stablehlo::makeTensor(minAttr.resizeSplat(inputTy));
5712+ maxTen = stablehlo::makeTensor(maxAttr.resizeSplat(inputTy));
5713+ inputTen = stablehlo::makeTensor(inputAttr.resizeSplat(inputTy));
5714+ } else {
5715+ minTen = stablehlo::constantOp(minAttr);
5716+ maxTen = stablehlo::constantOp(maxAttr);
5717+ inputTen = stablehlo::constantOp(inputAttr);
5718+ }
5719+
5720+ auto out =
5721+ fromTensor(clampOp(inputTen, minTen, maxTen, cast<ShapedType>(ty)));
5722+
5723+ if (splattedVersion) {
5724+ out = out.resizeSplat(cast<ShapedType>(op->getResultTypes()[0]));
5725+ }
5726+ // Replace with new constant op containing the computed result
5727+ rewriter.replaceOpWithNewOp<stablehlo::ConstantOp>(
5728+ op, op->getResultTypes()[0], out);
5729+
5730+ return success();
5731+ }
5732+ };
5733+
56885734struct ChloInfConstProp final
56895735 : CheckedOpRewritePattern<chlo::IsInfOp, ChloInfConstProp> {
56905736 using CheckedOpRewritePattern::CheckedOpRewritePattern;
@@ -23136,6 +23182,156 @@ struct CaseToIf : public CheckedOpRewritePattern<stablehlo::CaseOp, CaseToIf> {
2313623182 }
2313723183};
2313823184
23185+ struct DUSToDynamicPad
23186+ : public CheckedOpRewritePattern<stablehlo::DynamicUpdateSliceOp,
23187+ DUSToDynamicPad> {
23188+ using CheckedOpRewritePattern<stablehlo::DynamicUpdateSliceOp,
23189+ DUSToDynamicPad>::CheckedOpRewritePattern;
23190+
23191+ LogicalResult matchAndRewriteImpl(stablehlo::DynamicUpdateSliceOp op,
23192+ PatternRewriter &rewriter) const {
23193+ auto operand = op.getOperand();
23194+ auto update = op.getUpdate();
23195+ auto indices = op.getStartIndices();
23196+
23197+ for (auto [i, index] : llvm::enumerate(indices)) {
23198+ if (!matchPattern(index, m_Constant())) {
23199+ return rewriter.notifyMatchFailure(
23200+ op, "not all indices are constant. currently we don't support this "
23201+ "case");
23202+ }
23203+ }
23204+
23205+ Value scalarOperand = getScalarPadValue(rewriter, operand);
23206+ if (!scalarOperand)
23207+ return rewriter.notifyMatchFailure(op, "operand is not a scalar pad");
23208+
23209+ auto updateShape = cast<RankedTensorType>(update.getType()).getShape();
23210+ auto operandShape = cast<RankedTensorType>(operand.getType()).getShape();
23211+
23212+ SmallVector<Value> edgePaddingLowValues, edgePaddingHighValues;
23213+ for (auto [i, index] : llvm::enumerate(indices)) {
23214+ auto cType = RankedTensorType::get(
23215+ {}, cast<RankedTensorType>(index.getType()).getElementType());
23216+ auto clampedIndex = rewriter.create<stablehlo::ClampOp>(
23217+ op.getLoc(),
23218+ rewriter.create<stablehlo::ConstantOp>(
23219+ op.getLoc(), cType, cast<ElementsAttr>(makeAttr(cType, 0))),
23220+ index,
23221+ rewriter.create<stablehlo::ConstantOp>(
23222+ op.getLoc(), cType,
23223+ cast<ElementsAttr>(
23224+ makeAttr(cType, operandShape[i] - updateShape[i]))));
23225+
23226+ auto reshapedIndex = rewriter.create<stablehlo::ReshapeOp>(
23227+ op.getLoc(),
23228+ RankedTensorType::get(
23229+ {1}, cast<RankedTensorType>(index.getType()).getElementType()),
23230+ clampedIndex);
23231+ edgePaddingLowValues.push_back(reshapedIndex.getResult());
23232+
23233+ auto iType = RankedTensorType::get(
23234+ {1}, cast<RankedTensorType>(index.getType()).getElementType());
23235+ auto tmp = rewriter.create<stablehlo::ConstantOp>(
23236+ op.getLoc(), iType,
23237+ cast<ElementsAttr>(
23238+ makeAttr(iType, operandShape[i] - updateShape[i])));
23239+ auto paddingHigh = rewriter.create<stablehlo::SubtractOp>(
23240+ op.getLoc(), tmp, reshapedIndex);
23241+ edgePaddingHighValues.push_back(paddingHigh);
23242+ }
23243+
23244+ auto edgePaddingLow = rewriter.create<stablehlo::ConcatenateOp>(
23245+ op.getLoc(), edgePaddingLowValues, 0);
23246+ auto edgePaddingHigh = rewriter.create<stablehlo::ConcatenateOp>(
23247+ op.getLoc(), edgePaddingHighValues, 0);
23248+ auto interiorPadding = rewriter.create<stablehlo::ConstantOp>(
23249+ op.getLoc(), edgePaddingLow.getType(),
23250+ cast<ElementsAttr>(makeAttr(edgePaddingLow.getType(), 0)));
23251+
23252+ rewriter.replaceOpWithNewOp<stablehlo::DynamicPadOp>(
23253+ op, op.getType(), update, scalarOperand, edgePaddingLow,
23254+ edgePaddingHigh, interiorPadding);
23255+ return success();
23256+ }
23257+
23258+ private:
23259+ Value getScalarPadValue(PatternRewriter &rewriter, Value operand) const {
23260+ Value scalarOperand = getScalarPadValueViaBcastInDim(rewriter, operand);
23261+ if (scalarOperand)
23262+ return scalarOperand;
23263+
23264+ scalarOperand = getScalarPadValueViaSplattedConstant(rewriter, operand);
23265+ if (scalarOperand)
23266+ return scalarOperand;
23267+
23268+ return nullptr;
23269+ }
23270+
23271+ Value getScalarPadValueViaBcastInDim(PatternRewriter &rewriter,
23272+ Value operand) const {
23273+ auto bcastInDimOp = operand.getDefiningOp<stablehlo::BroadcastInDimOp>();
23274+ if (!bcastInDimOp)
23275+ return nullptr;
23276+
23277+ auto bcastOperand = bcastInDimOp.getOperand();
23278+ auto bcastOperandType = cast<RankedTensorType>(bcastOperand.getType());
23279+ if (bcastOperandType.getRank() != 0)
23280+ return nullptr;
23281+
23282+ return bcastOperand;
23283+ }
23284+
23285+ Value getScalarPadValueViaSplattedConstant(PatternRewriter &rewriter,
23286+ Value operand) const {
23287+ SplatElementsAttr splatAttr;
23288+ if (!matchPattern(operand, m_Constant(&splatAttr)))
23289+ return nullptr;
23290+
23291+ return rewriter.create<stablehlo::ConstantOp>(
23292+ operand.getLoc(), splatAttr.getSplatValue<Attribute>());
23293+ }
23294+ };
23295+
23296+ struct DynamicPadToPad
23297+ : public CheckedOpRewritePattern<stablehlo::DynamicPadOp, DynamicPadToPad> {
23298+ using CheckedOpRewritePattern<stablehlo::DynamicPadOp,
23299+ DynamicPadToPad>::CheckedOpRewritePattern;
23300+
23301+ LogicalResult matchAndRewriteImpl(stablehlo::DynamicPadOp op,
23302+ PatternRewriter &rewriter) const {
23303+ auto operand = op.getOperand();
23304+ auto paddingValue = op.getPaddingValue();
23305+ auto edgePaddingLow = op.getEdgePaddingLow();
23306+ auto edgePaddingHigh = op.getEdgePaddingHigh();
23307+ auto interiorPadding = op.getInteriorPadding();
23308+
23309+ DenseIntElementsAttr edgePaddingLowAttr, edgePaddingHighAttr,
23310+ interiorPaddingAttr;
23311+ if (!matchPattern(edgePaddingLow, m_Constant(&edgePaddingLowAttr)) ||
23312+ !matchPattern(edgePaddingHigh, m_Constant(&edgePaddingHighAttr)) ||
23313+ !matchPattern(interiorPadding, m_Constant(&interiorPaddingAttr)))
23314+ return rewriter.notifyMatchFailure(op, "edge padding is not a constant");
23315+
23316+ rewriter.replaceOpWithNewOp<stablehlo::PadOp>(
23317+ op, op.getType(), operand, paddingValue,
23318+ convertToDenseI64ArrayAttr(edgePaddingLowAttr),
23319+ convertToDenseI64ArrayAttr(edgePaddingHighAttr),
23320+ convertToDenseI64ArrayAttr(interiorPaddingAttr));
23321+ return success();
23322+ }
23323+
23324+ private:
23325+ DenseI64ArrayAttr
23326+ convertToDenseI64ArrayAttr(DenseIntElementsAttr attr) const {
23327+ auto values = attr.getValues<APInt>();
23328+ llvm::SmallVector<int64_t> denseValues;
23329+ for (auto value : values)
23330+ denseValues.push_back(value.getSExtValue());
23331+ return DenseI64ArrayAttr::get(attr.getContext(), denseValues);
23332+ }
23333+ };
23334+
2313923335/////////////// End Imported from stablehlo
2314023336
2314123337// clang-format off
@@ -23484,7 +23680,7 @@ struct EnzymeHLOOptPass
2348423680 BinaryConstProp<stablehlo::SubtractOp, stablehlo::subtractOp>,
2348523681 BinaryConstProp<stablehlo::XorOp, stablehlo::xorOp>>(context);
2348623682
23487- patterns.add<GatherConstProp>(context);
23683+ patterns.add<GatherConstProp, ClampConstProp >(context);
2348823684
2348923685 patterns.add<BinaryOpTransposeSimplify<stablehlo::AddOp>,
2349023686 BinaryOpTransposeSimplify<stablehlo::SubtractOp>,
@@ -23747,7 +23943,9 @@ struct EnzymeHLOOptPass
2374723943 MulReduceSliceFusion,
2374823944 MinReduceSliceFusion,
2374923945 MaxReduceSliceFusion,
23750- CaseToIf
23946+ CaseToIf,
23947+ DUSToDynamicPad,
23948+ DynamicPadToPad
2375123949 >(context);
2375223950
2375323951 patterns.add<
0 commit comments