@@ -10553,6 +10553,82 @@ class DecomposeAtenNllLossForwardOp
10553
10553
};
10554
10554
} // namespace
10555
10555
10556
+ namespace {
10557
+ class DecomposeAtenPoissonNllLossOp
10558
+ : public OpRewritePattern<AtenPoissonNllLossOp> {
10559
+ public:
10560
+ using OpRewritePattern::OpRewritePattern;
10561
+ LogicalResult matchAndRewrite(AtenPoissonNllLossOp op,
10562
+ PatternRewriter &rewriter) const override {
10563
+ Location loc = op.getLoc();
10564
+ Value input = op.getInput();
10565
+ Value target = op.getTarget();
10566
+ Value logInput = op.getLogInput();
10567
+ Value full = op.getFull();
10568
+ Value reduction = op.getReduction();
10569
+ Value eps = op.getEps();
10570
+
10571
+ bool logInVal, fullVal;
10572
+ if (!matchPattern(logInput, m_TorchConstantBool(&logInVal)))
10573
+ return rewriter.notifyMatchFailure(
10574
+ op, "expected logInput argument to be constant bool");
10575
+ if (!matchPattern(full, m_TorchConstantBool(&fullVal)))
10576
+ return rewriter.notifyMatchFailure(
10577
+ op, "expected full argument to be constant bool");
10578
+
10579
+ int64_t reductionInt;
10580
+ if (!matchPattern(reduction, m_TorchConstantInt(&reductionInt)))
10581
+ return rewriter.notifyMatchFailure(op, "expected constant reduction");
10582
+
10583
+ double epsFloat;
10584
+ if (!matchPattern(eps, m_TorchConstantFloat(&epsFloat))) {
10585
+ return rewriter.notifyMatchFailure(op, "expected constant eps");
10586
+ }
10587
+ // TODO: add support for full=true (Stirling approximation)
10588
+ if (fullVal)
10589
+ return rewriter.notifyMatchFailure(
10590
+ op, "Unimplemented: full loss computation is not supported");
10591
+
10592
+ Value one =
10593
+ rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1.0));
10594
+ Value epsConst = rewriter.create<ConstantFloatOp>(
10595
+ loc, rewriter.getF64FloatAttr(epsFloat));
10596
+
10597
+ Value safeInput = rewriter.create<AtenAddScalarOp>(loc, input.getType(),
10598
+ input, epsConst, one);
10599
+
10600
+ Value loss;
10601
+ if (logInVal) {
10602
+ Value expIn = rewriter.create<AtenExpOp>(loc, input.getType(), input);
10603
+ Value targetMulInput =
10604
+ rewriter.create<AtenMulTensorOp>(loc, input.getType(), target, input);
10605
+ loss = rewriter.create<AtenSubTensorOp>(loc, input.getType(), expIn,
10606
+ targetMulInput, one);
10607
+ } else {
10608
+ Value logSafeInput =
10609
+ rewriter.create<AtenLogOp>(loc, input.getType(), safeInput);
10610
+ Value targetMulLog = rewriter.create<AtenMulTensorOp>(
10611
+ loc, input.getType(), target, logSafeInput);
10612
+ loss = rewriter.create<AtenSubTensorOp>(loc, input.getType(), input,
10613
+ targetMulLog, one);
10614
+ }
10615
+
10616
+ Value result = loss;
10617
+ if (reductionInt == 1) {
10618
+ // Case 1: Mean Reduction
10619
+ result = rewriter.create<AtenMeanOp>(
10620
+ loc, op.getType(), loss, rewriter.create<ConstantNoneOp>(loc));
10621
+ } else if (reductionInt == 2) {
10622
+ // Case 2: Sum Reduction
10623
+ result = rewriter.create<AtenSumOp>(loc, op.getType(), loss,
10624
+ rewriter.create<ConstantNoneOp>(loc));
10625
+ }
10626
+ rewriter.replaceOp(op, result);
10627
+ return success();
10628
+ }
10629
+ };
10630
+ } // namespace
10631
+
10556
10632
namespace {
10557
10633
class DecomposeAtenBinaryCrossEntropyWithLogitsOp
10558
10634
: public OpRewritePattern<AtenBinaryCrossEntropyWithLogitsOp> {
@@ -12467,6 +12543,7 @@ class DecomposeComplexOpsPass
12467
12543
addPatternIfTargetOpIsIllegal<DecomposeAtenOneHotOp>(patterns);
12468
12544
addPatternIfTargetOpIsIllegal<DecomposeAtenCrossEntropyLossOp>(patterns);
12469
12545
addPatternIfTargetOpIsIllegal<DecomposeAtenNllLossForwardOp>(patterns);
12546
+ addPatternIfTargetOpIsIllegal<DecomposeAtenPoissonNllLossOp>(patterns);
12470
12547
addPatternIfTargetOpIsIllegal<DecomposeAtenBinaryCrossEntropyWithLogitsOp>(
12471
12548
patterns);
12472
12549
addPatternIfTargetOpIsIllegal<DecomposeAtenVarMeanDimOp>(patterns);
0 commit comments