Skip to content

Commit 67baf2f

Browse files
authored
Add LGammaOp (#2361)
* Add lgamma * Update ArithRaising.cpp * CHLO Derivative * Update EnzymeHLOOpt.cpp
1 parent 8ebb08b commit 67baf2f

File tree

12 files changed

+150
-6
lines changed

12 files changed

+150
-6
lines changed

src/enzyme_ad/jax/Dialect/EnzymeXLAOps.td

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1210,6 +1210,20 @@ def TGammaOp: EnzymeXLA_Op<"ml.tgamma", [Pure, SameOperandsAndResultType, Elemen
12101210
}];
12111211
}
12121212

1213+
def LGammaOp: EnzymeXLA_Op<"ml.lgamma", [Pure, SameOperandsAndResultType, Elementwise]> {
1214+
let summary = "Computes the log-gamma function";
1215+
1216+
let arguments = (ins
1217+
FloatLike:$input
1218+
);
1219+
1220+
let results = (outs FloatLike:$result);
1221+
1222+
let assemblyFormat = [{
1223+
$input attr-dict `:` functional-type($input, results)
1224+
}];
1225+
}
1226+
12131227
def SoftplusOp: EnzymeXLA_Op<"ml.softplus", [Pure, SameOperandsAndResultType, Elementwise]> {
12141228
let summary = "Computes the Softplus activation function";
12151229

src/enzyme_ad/jax/Implementations/CHLODerivatives.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,10 @@ def : HLODerivative<"PolygammaOp", (Op $x, $n),
137137
]
138138
>;
139139

140+
def : HLODerivative<"LgammaOp", (Op $x),
141+
[(Mul (DiffeRet), (Digamma $x))]
142+
>;
143+
140144
def : HLODerivative<"SinhOp", (Op $x), [(Mul (DiffeRet), (Cosh $x))]>;
141145

142146
def : HLODerivative<"TanOp", (Op $x), [

src/enzyme_ad/jax/Implementations/EnzymeXLADerivatives.td

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ def Pad : Inst<"PadOp", "stablehlo">;
5151
def Rotate : Inst<"RotateOp", "enzymexla">;
5252
def Extend : Inst<"ExtendOp", "enzymexla">;
5353
def CHLO_Polygamma: Inst<"PolygammaOp", "chlo">;
54+
def CHLO_Digamma: Inst<"DigammaOp", "chlo">;
5455

5556
def GT : GlobalExpr</*needsprimal*/0, /*needsshadow*/0, "ComparisonDirection::GT">;
5657
def EQ : GlobalExpr</*needsprimal*/0, /*needsshadow*/0, "ComparisonDirection::EQ">;
@@ -92,6 +93,12 @@ def : EnzymeXLADerivative<"SoftplusOp", (Op $input),
9293
(Softplus (Shadow $input))
9394
>;
9495

96+
def : EnzymeXLADerivative<"LGammaOp", (Op $input),
97+
[
98+
(CheckedMul (DiffeRet), (CHLO_Digamma $input))
99+
]
100+
>;
101+
95102
def : EnzymeXLADerivative<"TGammaOp", (Op $input),
96103
[
97104
(CheckedMul (DiffeRet), (CheckedMul (TGamma $input), (CHLO_Polygamma (HLOConstantFP<"0"> $input), $input)))

src/enzyme_ad/jax/Passes/AffineToStableHLORaising.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2858,7 +2858,7 @@ tryRaisingOpToStableHLO(Operation *op, IRMapping &mapping, OpBuilder &builder,
28582858
arith::FPToUIOp, arith::TruncFOp, arith::ExtFOp, math::SqrtOp,
28592859
math::RsqrtOp, math::CbrtOp, math::LogOp, math::ExpOp, math::AbsFOp,
28602860
math::AbsIOp, math::IsNaNOp, math::AtanOp, arith::BitcastOp,
2861-
enzymexla::TGammaOp, math::ErfOp>(op)) {
2861+
enzymexla::TGammaOp, enzymexla::LGammaOp, math::ErfOp>(op)) {
28622862
assert(op->getNumOperands() == 1 && op->getNumResults() == 1);
28632863

28642864
auto operand = op->getOperand(0);

src/enzyme_ad/jax/Passes/ArithRaising.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,7 @@ struct ArithRaisingPass
124124
RAISE_UNARY(math::FloorOp, stablehlo::FloorOp, mhlo::FloorOp);
125125
RAISE_UNARY(math::ErfOp, chlo::ErfOp, chlo::ErfOp);
126126
RAISE_UNARY(arith::NegFOp, stablehlo::NegOp, mhlo::NegOp);
127+
RAISE_UNARY(enzymexla::LGammaOp, chlo::LgammaOp, chlo::LgammaOp);
127128

128129
#undef RAISE_UNARY
129130

src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp

Lines changed: 70 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6321,6 +6321,70 @@ struct TGammaConstProp final
63216321
}
63226322
};
63236323

6324+
struct LGammaConstProp final
6325+
: CheckedOpRewritePattern<enzymexla::LGammaOp, LGammaConstProp> {
6326+
using CheckedOpRewritePattern::CheckedOpRewritePattern;
6327+
6328+
LogicalResult matchAndRewriteImpl(enzymexla::LGammaOp op,
6329+
PatternRewriter &rewriter) const {
6330+
DenseElementsAttr inputAttr;
6331+
if (!matchPattern(op.getOperand(), m_Constant(&inputAttr)))
6332+
return failure();
6333+
6334+
auto resultType = cast<ShapedType>(op.getType());
6335+
auto floatTy = dyn_cast<FloatType>(resultType.getElementType());
6336+
if (!floatTy)
6337+
return failure();
6338+
6339+
const auto &sem = floatTy.getFloatSemantics();
6340+
SmallVector<APFloat> results;
6341+
for (auto val : inputAttr.getValues<APFloat>()) {
6342+
double x = val.convertToDouble();
6343+
double res = std::lgamma(x);
6344+
bool losesInfo;
6345+
APFloat apRes(res);
6346+
apRes.convert(sem, APFloat::rmNearestTiesToEven, &losesInfo);
6347+
results.push_back(apRes);
6348+
}
6349+
6350+
rewriter.replaceOpWithNewOp<stablehlo::ConstantOp>(
6351+
op, DenseElementsAttr::get(resultType, results));
6352+
return success();
6353+
}
6354+
};
6355+
6356+
struct CHLOLGammaConstProp final
6357+
: CheckedOpRewritePattern<chlo::LgammaOp, CHLOLGammaConstProp> {
6358+
using CheckedOpRewritePattern::CheckedOpRewritePattern;
6359+
6360+
LogicalResult matchAndRewriteImpl(chlo::LgammaOp op,
6361+
PatternRewriter &rewriter) const {
6362+
DenseElementsAttr inputAttr;
6363+
if (!matchPattern(op.getOperand(), m_Constant(&inputAttr)))
6364+
return failure();
6365+
6366+
auto resultType = cast<ShapedType>(op.getType());
6367+
auto floatTy = dyn_cast<FloatType>(resultType.getElementType());
6368+
if (!floatTy)
6369+
return failure();
6370+
6371+
const auto &sem = floatTy.getFloatSemantics();
6372+
SmallVector<APFloat> results;
6373+
for (auto val : inputAttr.getValues<APFloat>()) {
6374+
double x = val.convertToDouble();
6375+
double res = std::lgamma(x);
6376+
bool losesInfo;
6377+
APFloat apRes(res);
6378+
apRes.convert(sem, APFloat::rmNearestTiesToEven, &losesInfo);
6379+
results.push_back(apRes);
6380+
}
6381+
6382+
rewriter.replaceOpWithNewOp<stablehlo::ConstantOp>(
6383+
op, DenseElementsAttr::get(resultType, results));
6384+
return success();
6385+
}
6386+
};
6387+
63246388
struct DynamicUpdateSliceConstProp final
63256389
: CheckedOpRewritePattern<stablehlo::DynamicUpdateSliceOp,
63266390
DynamicUpdateSliceConstProp> {
@@ -34987,11 +35051,12 @@ struct EnzymeHLOOptPass
3498735051
SliceOfUpdateWithoutCorners, SliceElementwise, SliceReshapeElementwise,
3498835052
DynamicSliceElementwise, SlicePad, SliceReshapePad, ReshapeSliceReshape,
3498935053
DotReshapeDot, ChloInfConstProp, GammaConstProp, TGammaConstProp,
34990-
ConcatFuse, ConcatToBroadcast, PadPad, PadReshapePad,
34991-
ConcatPushBinop<stablehlo::AddOp>, ConcatPushBinop<stablehlo::MulOp>,
34992-
ScatterToDynamicUpdateSlice, ReduceConcat, ConcatSlice, ConcatMultiPad,
34993-
ConcatWrap, WidenWrap, WidenExtend, ConcatConcatAxisSwap, SliceConcat,
34994-
SliceIf, SliceReshapeConcat, BinBroadcastSplat<stablehlo::AddOp>,
35054+
LGammaConstProp, CHLOLGammaConstProp, ConcatFuse, ConcatToBroadcast,
35055+
PadPad, PadReshapePad, ConcatPushBinop<stablehlo::AddOp>,
35056+
ConcatPushBinop<stablehlo::MulOp>, ScatterToDynamicUpdateSlice,
35057+
ReduceConcat, ConcatSlice, ConcatMultiPad, ConcatWrap, WidenWrap,
35058+
WidenExtend, ConcatConcatAxisSwap, SliceConcat, SliceIf,
35059+
SliceReshapeConcat, BinBroadcastSplat<stablehlo::AddOp>,
3499535060
BinBroadcastSplat<stablehlo::SubtractOp>,
3499635061
BinBroadcastSplat<stablehlo::DivOp>,
3499735062
BinBroadcastSplat<stablehlo::MulOp>, RotatePad, ConjReal>(context);

src/enzyme_ad/jax/Passes/LibDeviceFuncsRaisingPass.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1050,6 +1050,8 @@ void mlir::enzyme::populateLibDeviceFuncsToOpsPatterns(
10501050
"__nv_truncf");
10511051
populateOpPatterns<enzymexla::TGammaOp>(converter, patterns, "__nv_tgamma",
10521052
"__nv_tgammaf");
1053+
populateOpPatterns<enzymexla::LGammaOp>(converter, patterns, "__nv_lgamma",
1054+
"__nv_lgammaf");
10531055
}
10541056

10551057
void populateLLVMToMathPatterns(MLIRContext *context,

src/enzyme_ad/jax/Passes/LowerEnzymeXLAML.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,19 @@ struct LowerTGammaOpToStablehlo : public OpRewritePattern<enzymexla::TGammaOp> {
8080
}
8181
};
8282

83+
struct LowerLGammaOpToStablehlo : public OpRewritePattern<enzymexla::LGammaOp> {
84+
using OpRewritePattern<enzymexla::LGammaOp>::OpRewritePattern;
85+
86+
LogicalResult matchAndRewrite(enzymexla::LGammaOp op,
87+
PatternRewriter &rewriter) const override {
88+
auto loc = op.getLoc();
89+
auto operand = op.getOperand();
90+
auto result = stablehlo::materializeLgamma(rewriter, loc, operand);
91+
rewriter.replaceOp(op, result);
92+
return success();
93+
}
94+
};
95+
8396
struct LowerGeluOpToStablehlo : public OpRewritePattern<enzymexla::GeluOp> {
8497
using OpRewritePattern<enzymexla::GeluOp>::OpRewritePattern;
8598

@@ -230,6 +243,7 @@ struct LowerEnzymeXLAMLPass
230243
patterns.add<LowerGeluOpToStablehlo>(context);
231244
patterns.add<LowerSoftplusOpToStablehlo>(context);
232245
patterns.add<LowerTGammaOpToStablehlo>(context);
246+
patterns.add<LowerLGammaOpToStablehlo>(context);
233247

234248
GreedyRewriteConfig config;
235249
config.enableFolding();

src/enzyme_ad/jax/TransformOps/TransformOps.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -257,6 +257,10 @@ def ApplyTGammaConstProp : EnzymeHLOPatternOp<
257257
"tgamma_const_prop">{
258258
let patterns = ["TGammaConstProp"];
259259
}
260+
def ApplyLGammaConstProp : EnzymeHLOPatternOp<
261+
"lgamma_const_prop">{
262+
let patterns = ["LGammaConstProp", "CHLOLGammaConstProp"];
263+
}
260264
def ApplySoftplusConstProp : EnzymeHLOPatternOp<
261265
"softplus_const_prop">{
262266
let patterns = ["UnaryConstProp<enzymexla::SoftplusOp,softplusOp>"];

src/enzyme_ad/jax/primitives.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -393,6 +393,7 @@ def optimization_passes(
393393
"relu_const_prop",
394394
"gelu_const_prop",
395395
"tgamma_const_prop",
396+
"lgamma_const_prop",
396397
"softplus_const_prop",
397398
# binary constant propagation
398399
"add_const_prop",

0 commit comments

Comments
 (0)