Commit bc9abbc

Gaurav Shukla authored and committed
[TORCH][MLIR] Add E2E support for aten.empty_like op
This commit adds a decomposition of the `aten.empty_like` op into the `aten.empty` op.

Signed-Off-by: Gaurav Shukla <[email protected]>
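At the PyTorch level, the equivalence the decomposition relies on can be sketched as follows; the `empty_like_decomposed` helper is purely illustrative (the actual rewrite operates on torch-mlir's `AtenEmptyLikeOp` in C++, not in Python):

    import torch

    def empty_like_decomposed(a, dtype):
        # Illustrative stand-in for the rewrite: query the input's size
        # (aten.size) and forward it, together with the dtype, to aten.empty.
        return torch.empty(a.size(), dtype=dtype)

    a = torch.ones(3, 5, dtype=torch.int64)
    b = empty_like_decomposed(a, dtype=torch.int64)
    assert b.shape == a.shape and b.dtype == torch.int64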
1 parent eddc09a · commit bc9abbc

File tree: 3 files changed (+69, −4 lines)


e2e_testing/torchscript/basic.py

Lines changed: 39 additions & 3 deletions
@@ -663,7 +663,7 @@ def __init__(self):
         None,
     ])
     def forward(self):
-        return torch.abs(torch.empty((3, 4), dtype=torch.float32)) > -1.0
+        return torch.pow(torch.empty((3, 4), dtype=torch.float32), 0)
 
 @register_test_case(module_factory=lambda: EmptyFloatModule())
 def EmptyModule_float(module, tu: TestUtils):
@@ -679,15 +679,51 @@ def __init__(self):
         None,
     ])
     def forward(self):
-        return torch.abs(torch.empty((3, 4), dtype=torch.float32,
-                                      pin_memory=False)) > -1.0
+        return torch.pow(torch.empty((3, 4), dtype=torch.float32,
+                                      pin_memory=False), 0)
 
 @register_test_case(module_factory=lambda: EmptyFalsePinMemoryModule())
 def EmptyModule_falsePinMemory(module, tu: TestUtils):
     module.forward()
 
 # ==============================================================================
 
+class EmptyLikeIntModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.int64, True),
+    ])
+    def forward(self, a):
+        return 0 * torch.empty_like(a, dtype=torch.int64)
+
+@register_test_case(module_factory=lambda: EmptyLikeIntModule())
+def EmptyLikeModule_int(module, tu: TestUtils):
+    module.forward(torch.randint(10, (3, 5)))
+
+# ==============================================================================
+
+class EmptyLikeFloatModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float32, True),
+    ])
+    def forward(self, a):
+        return torch.pow(torch.empty_like(a, dtype=torch.float32), 0)
+
+@register_test_case(module_factory=lambda: EmptyLikeFloatModule())
+def EmptyLikeModule_float(module, tu: TestUtils):
+    module.forward(tu.rand(4, 5))
+
+# ==============================================================================
+
 class ContiguousModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
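A note on the test idiom above: `torch.empty` and `torch.empty_like` return uninitialized memory, so the tests cannot compare raw outputs against the reference backend. They instead map the result to a deterministic tensor: `torch.pow(x, 0)` is elementwise 1 (even for NaN entries), and `0 * x` is elementwise 0 in the integer case. A minimal sketch of the idea:

    import torch

    x = torch.empty((3, 4), dtype=torch.float32)  # contents are arbitrary
    ones = torch.pow(x, 0)                        # elementwise 1.0, whatever x held
    assert torch.equal(ones, torch.ones(3, 4))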

lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp

Lines changed: 29 additions & 0 deletions
@@ -506,6 +506,33 @@ class DecomposeAtenLayerNormOp : public OpRewritePattern<AtenLayerNormOp> {
 };
 } // namespace
 
+// Decompose `aten.empty_like` op into `aten.size` and `aten.empty` ops.
+namespace {
+class DecomposeAtenEmptyLikeOp : public OpRewritePattern<AtenEmptyLikeOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenEmptyLikeOp op,
+                                PatternRewriter &rewriter) const override {
+    auto sizeListType =
+        Torch::ListType::get(Torch::IntType::get(op.getContext()));
+    Value sizeList =
+        rewriter.create<AtenSizeOp>(op.getLoc(), sizeListType, op.self());
+
+    // TODO: Handle the case when `dtype` is NoneType.
+    Type dtype = op.dtype().getType();
+    if (dtype.isa<OptionalType>() || dtype.isa<Torch::NoneType>() ||
+        dtype.isa<mlir::NoneType>())
+      return rewriter.notifyMatchFailure(
+          op, "unimplemented: None dtype is not supported");
+
+    rewriter.replaceOpWithNewOp<AtenEmptyMemoryFormatOp>(
+        op, op.getType(), sizeList, op.dtype(), op.layout(), op.device(),
+        op.pin_memory(), op.memory_format());
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 class DecomposeComplexOpsPass
     : public DecomposeComplexOpsBase<DecomposeComplexOpsPass> {
@@ -521,6 +548,8 @@ class DecomposeComplexOpsPass
     target.addIllegalOp<Aten_SoftmaxOp>();
     patterns.add<DecomposeAtenLogSoftmaxIntOp>(context);
     target.addIllegalOp<AtenLogSoftmaxIntOp>();
+    patterns.add<DecomposeAtenEmptyLikeOp>(context);
+    target.addIllegalOp<AtenEmptyLikeOp>();
     patterns.add<DecomposeAtenExpandOp>(context);
     target.addIllegalOp<AtenExpandOp>();
     patterns.add<DecomposeAtenSizeOp>(context);
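Note that the pattern bails out via `notifyMatchFailure` when `dtype` is `None`: in that case `empty_like` must inherit the input tensor's dtype, which the rewrite does not yet materialize (hence the TODO). At the PyTorch level, the unimplemented case amounts to the following sketch; the helper name is hypothetical:

    import torch

    def empty_like_with_default(a, dtype=None):
        # The case the TODO covers: a None dtype means "inherit a.dtype"
        # before forwarding to torch.empty.
        if dtype is None:
            dtype = a.dtype
        return torch.empty(a.size(), dtype=dtype)

    a = torch.ones(2, 3, dtype=torch.float64)
    assert empty_like_with_default(a).dtype == torch.float64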

lib/Dialect/Torch/Transforms/RefineTypes.cpp

Lines changed: 1 addition & 1 deletion
@@ -237,7 +237,7 @@ class TypeAnalyzer : public ForwardDataFlowAnalysis<ValueKnowledge> {
       AtenTanhOp, AtenBatchNormOp, AtenReluOp, AtenGeluOp,
       AtenGeluBackwardOp, AtenBitwiseNotOp, AtenExpOp, AtenSinOp,
       AtenCosOp, AtenSigmoidOp, DerefineOp, AtenToPrimDeviceOp, AtenCpuOp,
-      AtenContiguousOp, AtenFill_ScalarOp, AtenDetachOp,
+      AtenContiguousOp, AtenFill_ScalarOp, AtenDetachOp, AtenEmptyLikeOp,
       AtenMaskedFill_ScalarOp, AtenCopy_Op, AtenCumsumOp, AtenLayerNormOp,
       AtenClampOp, AtenLogOp, AtenNegOp, AtenSqrtOp, AtenFloorOp,
       AtenLog2Op, Aten_SoftmaxBackwardDataOp, AtenRsqrtOp, AtenDropoutOp,
