@@ -10471,6 +10471,89 @@ class DecomposeAtenNllLossForwardOp
10471
10471
};
10472
10472
} // namespace
10473
10473
10474
+ namespace {
10475
+ // Decompostion of aten.hinge_embedding_loss op
10476
+ // Ref:
10477
+ // https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/Loss.cpp#L182
10478
+ // The Hinge Embedding Loss:
10479
+ // | input, if target == 1
10480
+ // loss(x) = |
10481
+ // | max(0, margin - input), if target == -1
10482
+ class DecomposeHingeEmbeddingLoss
10483
+ : public OpRewritePattern<AtenHingeEmbeddingLossOp> {
10484
+ using OpRewritePattern<AtenHingeEmbeddingLossOp>::OpRewritePattern;
10485
+ LogicalResult matchAndRewrite(AtenHingeEmbeddingLossOp op,
10486
+ PatternRewriter &rewriter) const override {
10487
+ Location loc = op.getLoc();
10488
+ auto input = op.getSelf();
10489
+ auto target = op.getTarget();
10490
+
10491
+ auto inputTy = dyn_cast<ValueTensorType>(input.getType());
10492
+ if (!inputTy.hasDtype() || !inputTy.hasSizes())
10493
+ return rewriter.notifyMatchFailure(op, "input must have dtype and size");
10494
+
10495
+ auto targetTy = dyn_cast<ValueTensorType>(target.getType());
10496
+ if (!targetTy.hasDtype() || !targetTy.hasSizes())
10497
+ return rewriter.notifyMatchFailure(op, "target must have dtype and size");
10498
+ auto resultTy = dyn_cast<ValueTensorType>(op.getType());
10499
+ Value minusOne = getConstantWithGivenDtypeAndValue(rewriter, loc, -1,
10500
+ targetTy.getDtype());
10501
+ Value one = getConstantWithGivenDtypeAndValue(rewriter, loc, 1,
10502
+ targetTy.getDtype());
10503
+ Value zero = getConstantWithGivenDtypeAndValue(rewriter, loc, 0,
10504
+ targetTy.getDtype());
10505
+ Value alpha =
10506
+ rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
10507
+ auto boolType = targetTy.getWithSizesAndDtype(targetTy.getSizes(),
10508
+ rewriter.getI1Type());
10509
+ // input - margin
10510
+ auto inputMinusMargin = rewriter.create<AtenSubScalarOp>(
10511
+ loc, inputTy, input, op.getMargin(), alpha);
10512
+ // multiply by -1 to get margin - input
10513
+ auto marginDiff = rewriter.create<AtenMulScalarOp>(
10514
+ loc, inputTy, inputMinusMargin, minusOne);
10515
+ // max(0, margin - input) => clamping the minimum value of margin - input at
10516
+ // 0
10517
+ auto marginClamp =
10518
+ rewriter.create<AtenClampMinOp>(loc, inputTy, marginDiff, zero);
10519
+ // Compute mask: target != 1
10520
+ auto targetNotOne =
10521
+ rewriter.create<AtenNeScalarOp>(loc, boolType, target, one);
10522
+ // If target != -1 use marginClamp otherwise 0.
10523
+ auto outputMargin = rewriter.create<AtenWhereScalarOtherOp>(
10524
+ loc, inputTy, targetNotOne, marginClamp, zero);
10525
+ // Compute mask: target != 1
10526
+ auto targetNotMinusOne =
10527
+ rewriter.create<AtenNeScalarOp>(loc, boolType, target, minusOne);
10528
+ // If target != 1 use the original input. Otherwise 0.
10529
+ auto outputSelf = rewriter.create<AtenWhereScalarOtherOp>(
10530
+ loc, inputTy, targetNotMinusOne, input, zero);
10531
+ // Add : outputMargin + outputSelf
10532
+ auto output = rewriter.create<AtenAddTensorOp>(loc, inputTy, outputMargin,
10533
+ outputSelf, /*alpha=*/alpha);
10534
+ int64_t reduction;
10535
+ if (!matchPattern(op.getReduction(), m_TorchConstantInt(&reduction))) {
10536
+ return rewriter.notifyMatchFailure(op,
10537
+ "reduction should be a constant int!");
10538
+ }
10539
+ Value loss;
10540
+ Value none = rewriter.create<ConstantNoneOp>(loc);
10541
+ // reduction: mean
10542
+ if (reduction == 1) {
10543
+ loss = rewriter.create<AtenMeanOp>(loc, resultTy, output, none);
10544
+ } else if (reduction == 2) {
10545
+ // reduction: sum
10546
+ loss = rewriter.create<AtenSumOp>(loc, resultTy, output, none);
10547
+ } else {
10548
+ // reduction: none
10549
+ loss = output;
10550
+ }
10551
+ rewriter.replaceOp(op, loss);
10552
+ return success();
10553
+ }
10554
+ };
10555
+ } // namespace
10556
+
10474
10557
namespace {
10475
10558
class DecomposeAtenBinaryCrossEntropyWithLogitsOp
10476
10559
: public OpRewritePattern<AtenBinaryCrossEntropyWithLogitsOp> {
@@ -12384,6 +12467,7 @@ class DecomposeComplexOpsPass
12384
12467
addPatternIfTargetOpIsIllegal<DecomposeAtenOneHotOp>(patterns);
12385
12468
addPatternIfTargetOpIsIllegal<DecomposeAtenCrossEntropyLossOp>(patterns);
12386
12469
addPatternIfTargetOpIsIllegal<DecomposeAtenNllLossForwardOp>(patterns);
12470
+ addPatternIfTargetOpIsIllegal<DecomposeHingeEmbeddingLoss>(patterns);
12387
12471
addPatternIfTargetOpIsIllegal<DecomposeAtenBinaryCrossEntropyWithLogitsOp>(
12388
12472
patterns);
12389
12473
addPatternIfTargetOpIsIllegal<DecomposeAtenVarMeanDimOp>(patterns);
0 commit comments