@@ -6321,6 +6321,70 @@ struct TGammaConstProp final
63216321 }
63226322};
63236323
6324+ struct LGammaConstProp final
6325+ : CheckedOpRewritePattern<enzymexla::LGammaOp, LGammaConstProp> {
6326+ using CheckedOpRewritePattern::CheckedOpRewritePattern;
6327+
6328+ LogicalResult matchAndRewriteImpl(enzymexla::LGammaOp op,
6329+ PatternRewriter &rewriter) const {
6330+ DenseElementsAttr inputAttr;
6331+ if (!matchPattern(op.getOperand(), m_Constant(&inputAttr)))
6332+ return failure();
6333+
6334+ auto resultType = cast<ShapedType>(op.getType());
6335+ auto floatTy = dyn_cast<FloatType>(resultType.getElementType());
6336+ if (!floatTy)
6337+ return failure();
6338+
6339+ const auto &sem = floatTy.getFloatSemantics();
6340+ SmallVector<APFloat> results;
6341+ for (auto val : inputAttr.getValues<APFloat>()) {
6342+ double x = val.convertToDouble();
6343+ double res = std::lgamma(x);
6344+ bool losesInfo;
6345+ APFloat apRes(res);
6346+ apRes.convert(sem, APFloat::rmNearestTiesToEven, &losesInfo);
6347+ results.push_back(apRes);
6348+ }
6349+
6350+ rewriter.replaceOpWithNewOp<stablehlo::ConstantOp>(
6351+ op, DenseElementsAttr::get(resultType, results));
6352+ return success();
6353+ }
6354+ };
6355+
6356+ struct CHLOLGammaConstProp final
6357+ : CheckedOpRewritePattern<chlo::LgammaOp, CHLOLGammaConstProp> {
6358+ using CheckedOpRewritePattern::CheckedOpRewritePattern;
6359+
6360+ LogicalResult matchAndRewriteImpl(chlo::LgammaOp op,
6361+ PatternRewriter &rewriter) const {
6362+ DenseElementsAttr inputAttr;
6363+ if (!matchPattern(op.getOperand(), m_Constant(&inputAttr)))
6364+ return failure();
6365+
6366+ auto resultType = cast<ShapedType>(op.getType());
6367+ auto floatTy = dyn_cast<FloatType>(resultType.getElementType());
6368+ if (!floatTy)
6369+ return failure();
6370+
6371+ const auto &sem = floatTy.getFloatSemantics();
6372+ SmallVector<APFloat> results;
6373+ for (auto val : inputAttr.getValues<APFloat>()) {
6374+ double x = val.convertToDouble();
6375+ double res = std::lgamma(x);
6376+ bool losesInfo;
6377+ APFloat apRes(res);
6378+ apRes.convert(sem, APFloat::rmNearestTiesToEven, &losesInfo);
6379+ results.push_back(apRes);
6380+ }
6381+
6382+ rewriter.replaceOpWithNewOp<stablehlo::ConstantOp>(
6383+ op, DenseElementsAttr::get(resultType, results));
6384+ return success();
6385+ }
6386+ };
6387+
63246388struct DynamicUpdateSliceConstProp final
63256389 : CheckedOpRewritePattern<stablehlo::DynamicUpdateSliceOp,
63266390 DynamicUpdateSliceConstProp> {
@@ -34987,11 +35051,12 @@ struct EnzymeHLOOptPass
3498735051 SliceOfUpdateWithoutCorners, SliceElementwise, SliceReshapeElementwise,
3498835052 DynamicSliceElementwise, SlicePad, SliceReshapePad, ReshapeSliceReshape,
3498935053 DotReshapeDot, ChloInfConstProp, GammaConstProp, TGammaConstProp,
34990- ConcatFuse, ConcatToBroadcast, PadPad, PadReshapePad,
34991- ConcatPushBinop<stablehlo::AddOp>, ConcatPushBinop<stablehlo::MulOp>,
34992- ScatterToDynamicUpdateSlice, ReduceConcat, ConcatSlice, ConcatMultiPad,
34993- ConcatWrap, WidenWrap, WidenExtend, ConcatConcatAxisSwap, SliceConcat,
34994- SliceIf, SliceReshapeConcat, BinBroadcastSplat<stablehlo::AddOp>,
35054+ LGammaConstProp, CHLOLGammaConstProp, ConcatFuse, ConcatToBroadcast,
35055+ PadPad, PadReshapePad, ConcatPushBinop<stablehlo::AddOp>,
35056+ ConcatPushBinop<stablehlo::MulOp>, ScatterToDynamicUpdateSlice,
35057+ ReduceConcat, ConcatSlice, ConcatMultiPad, ConcatWrap, WidenWrap,
35058+ WidenExtend, ConcatConcatAxisSwap, SliceConcat, SliceIf,
35059+ SliceReshapeConcat, BinBroadcastSplat<stablehlo::AddOp>,
3499535060 BinBroadcastSplat<stablehlo::SubtractOp>,
3499635061 BinBroadcastSplat<stablehlo::DivOp>,
3499735062 BinBroadcastSplat<stablehlo::MulOp>, RotatePad, ConjReal>(context);
0 commit comments