@@ -11304,6 +11304,76 @@ class DecomposeAtenSgnOp : public OpRewritePattern<AtenSgnOp> {
11304
11304
};
11305
11305
} // namespace
11306
11306
11307
+ namespace {
11308
+ // Decomposed aten.heaviside op into
11309
+ // using aten.eq, aten.lt, aten.logical_or, aten.where
11310
+ // Heaviside(x, y) returns
11311
+ // 0 if x < 0
11312
+ // y if x == 0
11313
+ // 1 if x > 0
11314
+ class DecomposeAtenHeaviside : public OpRewritePattern<AtenHeavisideOp> {
11315
+ public:
11316
+ using OpRewritePattern::OpRewritePattern;
11317
+ LogicalResult matchAndRewrite(AtenHeavisideOp op,
11318
+ PatternRewriter &rewriter) const override {
11319
+ auto input = op.getSelf();
11320
+ auto value = op.getValues();
11321
+ auto loc = op.getLoc();
11322
+ auto inputTy = dyn_cast<BaseTensorType>(input.getType());
11323
+ if (!inputTy || !inputTy.hasDtype() || !inputTy.hasSizes())
11324
+ return rewriter.notifyMatchFailure(op, "input must have dtype and size.");
11325
+
11326
+ auto valueTy = dyn_cast<BaseTensorType>(value.getType());
11327
+ if (!valueTy || !valueTy.hasDtype() || !valueTy.hasSizes())
11328
+ return rewriter.notifyMatchFailure(op, "value must have dtype and size.");
11329
+ auto resultTy = dyn_cast<BaseTensorType>(op.getType());
11330
+ SmallVector<int64_t> broadcastShape;
11331
+ SmallVector<Value> broadcastShapeValue;
11332
+ computeBroadcastShape(rewriter, loc, input, value, broadcastShape,
11333
+ broadcastShapeValue);
11334
+
11335
+ auto broadcastType = ValueTensorType::get(
11336
+ op.getContext(), llvm::ArrayRef(broadcastShape), resultTy.getDtype());
11337
+ auto boolBroadcastType = ValueTensorType::get(
11338
+ op.getContext(), llvm::ArrayRef(broadcastShape), rewriter.getI1Type());
11339
+ Value indexBroadcastShapeTorchList = rewriter.create<PrimListConstructOp>(
11340
+ loc, Torch::ListType::get(Torch::IntType::get(op.getContext())),
11341
+ broadcastShapeValue);
11342
+ auto inputBroadcasted = rewriter.create<AtenBroadcastToOp>(
11343
+ loc, broadcastType, input, indexBroadcastShapeTorchList);
11344
+ auto valueBroadcasted = rewriter.create<AtenBroadcastToOp>(
11345
+ loc, broadcastType, value, indexBroadcastShapeTorchList);
11346
+
11347
+ Value zero = getConstantWithGivenDtypeAndValue(rewriter, loc, 0,
11348
+ resultTy.getDtype());
11349
+ Value one = getConstantWithGivenDtypeAndValue(rewriter, loc, 1,
11350
+ resultTy.getDtype());
11351
+ // Compute mask: input == 0
11352
+ auto inputEqZero = rewriter
11353
+ .create<AtenEqScalarOp>(loc, boolBroadcastType,
11354
+ inputBroadcasted, zero)
11355
+ ->getResult(0);
11356
+ // Compute mask: input < 0
11357
+ auto inputLtZero = rewriter.create<AtenLtScalarOp>(loc, boolBroadcastType,
11358
+ inputBroadcasted, zero);
11359
+ // Compute mask: isnan(input)
11360
+ auto isNan =
11361
+ rewriter.create<AtenIsnanOp>(loc, boolBroadcastType, inputBroadcasted);
11362
+ // Combine: input < 0 || isnan(input)
11363
+ auto inputNegativeOrNan = rewriter.create<AtenLogicalOrOp>(
11364
+ loc, boolBroadcastType, inputLtZero, isNan);
11365
+ // Select 0 if input < 0 or input is nan, else 1
11366
+ auto zerosOrOnes = rewriter.create<AtenWhereScalarOp>(
11367
+ loc, resultTy, inputNegativeOrNan, zero, one);
11368
+ // Final result: if input == 0, take from valueBroadcasted, else take from
11369
+ // zerosOrOnes
11370
+ rewriter.replaceOpWithNewOp<AtenWhereSelfOp>(op, resultTy, inputEqZero,
11371
+ valueBroadcasted, zerosOrOnes);
11372
+ return success();
11373
+ }
11374
+ };
11375
+ } // namespace
11376
+
11307
11377
namespace {
11308
11378
// Unconditionally decompose `torch.type_as` into `prim.dtype` +
11309
11379
// `torch.to.dtype`.
@@ -12528,6 +12598,7 @@ class DecomposeComplexOpsPass
12528
12598
DecomposeConstantTensorNewLikeOp<AtenNewOnesOp, AtenOnesOp>>(patterns);
12529
12599
addPatternIfTargetOpIsIllegal<DecomposeAtenHardtanhOp>(patterns);
12530
12600
addPatternIfTargetOpIsIllegal<DecomposeAtenFullOp>(patterns);
12601
+ addPatternIfTargetOpIsIllegal<DecomposeAtenHeaviside>(patterns);
12531
12602
addPatternIfTargetOpIsIllegal<DecomposeAtenLinearOp>(patterns);
12532
12603
addPatternIfTargetOpIsIllegal<DecomposeAtenMishOp>(patterns);
12533
12604
addPatternIfTargetOpIsIllegal<DecomposeAtenFullLikeOp>(patterns);
0 commit comments