@@ -3562,6 +3562,51 @@ class DecomposeAtenBernoulliTensorOp
35623562};
35633563} // namespace
35643564
3565+ namespace {
3566+ // Decompose exponential() to do inverse transform sampling.
3567+ // - https://en.wikipedia.org/wiki/Inverse_transform_sampling
3568+ // With the exponential distribution, F(x) = 1 - exp(-lambda * x). Thus,
3569+ // exponential() = - ln(1 - uniform(0, 1)) / lambda.
3570+ class DecomposeAtenExponentialOp : public OpRewritePattern <AtenExponentialOp> {
3571+ public:
3572+ using OpRewritePattern::OpRewritePattern;
3573+ LogicalResult matchAndRewrite (AtenExponentialOp op,
3574+ PatternRewriter &rewriter) const override {
3575+ if (!op.getGenerator ().getType ().isa <Torch::NoneType>())
3576+ return rewriter.notifyMatchFailure (
3577+ op, " The generator has to be None because only global default "
3578+ " generator is supported" );
3579+
3580+ Location loc = op.getLoc ();
3581+ Type resultType = op.getType ();
3582+
3583+ // Create a uniform random op with low and high set to 0.0 and 1.0,
3584+ // respectively.
3585+ Value none = rewriter.create <ConstantNoneOp>(loc);
3586+ Value zero =
3587+ rewriter.create <ConstantFloatOp>(loc, rewriter.getF64FloatAttr (0.0 ));
3588+ Value one =
3589+ rewriter.create <ConstantFloatOp>(loc, rewriter.getF64FloatAttr (1.0 ));
3590+ Value emptyTensor = rewriter.create <AtenFullLikeOp>(
3591+ loc, resultType, op.getSelf (), zero, /* dtype=*/ none, /* layout=*/ none,
3592+ /* device=*/ none, /* pin_memoty=*/ none, /* memory_format=*/ none);
3593+ Value x = rewriter.create <AtenUniformOp>(loc, resultType, emptyTensor,
3594+ /* from=*/ zero, /* to=*/ one,
3595+ /* generator=*/ none);
3596+
3597+ Value negX = rewriter.create <AtenNegOp>(loc, resultType, x);
3598+ Value oneMinusX =
3599+ rewriter.create <AtenAddScalarOp>(loc, resultType, negX, one,
3600+ /* alpha=*/ one);
3601+ Value lnOneMinusX = rewriter.create <AtenLogOp>(loc, resultType, oneMinusX);
3602+ Value negLambda = rewriter.create <AtenNegFloatOp>(loc, op.getLambd ());
3603+ rewriter.replaceOpWithNewOp <AtenDivScalarOp>(op, resultType, lnOneMinusX,
3604+ negLambda);
3605+ return success ();
3606+ }
3607+ };
3608+ } // namespace
3609+
35653610namespace {
35663611template <typename OpTy, typename T1T2Op>
35673612class DecomposeAtenAddCLikeOp : public OpRewritePattern <OpTy> {
@@ -6410,6 +6455,7 @@ class DecomposeComplexOpsPass
64106455 addPatternIfTargetOpIsIllegal<
64116456 DecomposeAtenBernoulliLikeOp<AtenBernoulliPOp>>(patterns);
64126457 addPatternIfTargetOpIsIllegal<DecomposeAtenBernoulliTensorOp>(patterns);
6458+ addPatternIfTargetOpIsIllegal<DecomposeAtenExponentialOp>(patterns);
64136459 addPatternIfTargetOpIsIllegal<DecomposeAtenZeroOp>(patterns);
64146460 addPatternIfTargetOpIsIllegal<DecomposeAtenEyeOp>(patterns);
64156461 addPatternIfTargetOpIsIllegal<DecomposeAtenEyeMOp>(patterns);
0 commit comments