Skip to content

Commit 3c40539

Browse files
author
Gaurav Shukla
committed
[TORCH][MLIR] Add E2E support for aten.[ones_like|zeros_like]
- This commit adds E2E support for `aten.ones_like` and `aten.zeros_like` ops. - Adds support for non-None `dtype` argument of `aten.empty_like` op. - All the unit test cases related to constant tensor allocation like ops are moved to a different file named `constant_alloc.py`. Signed-Off-by: Gaurav Shukla <[email protected]>
1 parent 9afaace commit 3c40539

File tree

7 files changed

+501
-284
lines changed

7 files changed

+501
-284
lines changed

e2e_testing/torchscript/basic.py

Lines changed: 3 additions & 272 deletions
Original file line numberDiff line numberDiff line change
@@ -586,144 +586,6 @@ def ExpandModule_basic(module, tu: TestUtils):
586586

587587
# ==============================================================================
588588

589-
590-
class OnesModuleInt(torch.nn.Module):
591-
def __init__(self):
592-
super().__init__()
593-
594-
@export
595-
@annotate_args([
596-
None,
597-
])
598-
def forward(self):
599-
return torch.ones(3, 4, dtype=torch.int64)
600-
601-
@register_test_case(module_factory=lambda: OnesModuleInt())
602-
def OnesModuleInt_basic(module, tu: TestUtils):
603-
module.forward()
604-
605-
# ==============================================================================
606-
607-
class OnesModuleFloat(torch.nn.Module):
608-
def __init__(self):
609-
super().__init__()
610-
611-
@export
612-
@annotate_args([
613-
None,
614-
])
615-
def forward(self):
616-
return torch.ones(3, 4, dtype=torch.float32)
617-
618-
@register_test_case(module_factory=lambda: OnesModuleFloat())
619-
def OnesModuleFloat_basic(module, tu: TestUtils):
620-
module.forward()
621-
622-
623-
class OnesModuleFalsePinMemory(torch.nn.Module):
624-
def __init__(self):
625-
super().__init__()
626-
627-
@export
628-
@annotate_args([
629-
None,
630-
])
631-
def forward(self):
632-
return torch.ones(3, 4, dtype=torch.float32, pin_memory=False)
633-
634-
@register_test_case(module_factory=lambda: OnesModuleFalsePinMemory())
635-
def OnesModuleFalsePinMemory_basic(module, tu: TestUtils):
636-
module.forward()
637-
638-
# ==============================================================================
639-
640-
class EmptyIntModule(torch.nn.Module):
641-
def __init__(self):
642-
super().__init__()
643-
644-
@export
645-
@annotate_args([
646-
None,
647-
])
648-
def forward(self):
649-
return 0 * torch.empty((3, 4), dtype=torch.int64)
650-
651-
@register_test_case(module_factory=lambda: EmptyIntModule())
652-
def EmptyModule_int(module, tu: TestUtils):
653-
module.forward()
654-
655-
# ==============================================================================
656-
657-
class EmptyFloatModule(torch.nn.Module):
658-
def __init__(self):
659-
super().__init__()
660-
661-
@export
662-
@annotate_args([
663-
None,
664-
])
665-
def forward(self):
666-
return torch.pow(torch.empty((3, 4), dtype=torch.float32), 0)
667-
668-
@register_test_case(module_factory=lambda: EmptyFloatModule())
669-
def EmptyModule_float(module, tu: TestUtils):
670-
module.forward()
671-
672-
673-
class EmptyFalsePinMemoryModule(torch.nn.Module):
674-
def __init__(self):
675-
super().__init__()
676-
677-
@export
678-
@annotate_args([
679-
None,
680-
])
681-
def forward(self):
682-
return torch.pow(torch.empty((3, 4), dtype=torch.float32,
683-
pin_memory=False), 0)
684-
685-
@register_test_case(module_factory=lambda: EmptyFalsePinMemoryModule())
686-
def EmptyModule_falsePinMemory(module, tu: TestUtils):
687-
module.forward()
688-
689-
# ==============================================================================
690-
691-
class EmptyLikeIntModule(torch.nn.Module):
692-
def __init__(self):
693-
super().__init__()
694-
695-
@export
696-
@annotate_args([
697-
None,
698-
([-1, -1], torch.int64, True),
699-
])
700-
def forward(self, a):
701-
return 0 * torch.empty_like(a, dtype=torch.int64)
702-
703-
@register_test_case(module_factory=lambda: EmptyLikeIntModule())
704-
def EmptyLikeModule_int(module, tu: TestUtils):
705-
module.forward(torch.randint(10, (3, 5)))
706-
707-
# ==============================================================================
708-
709-
class EmptyLikeFloatModule(torch.nn.Module):
710-
def __init__(self):
711-
super().__init__()
712-
713-
@export
714-
@annotate_args([
715-
None,
716-
([-1, -1], torch.float32, True),
717-
])
718-
def forward(self, a):
719-
return torch.pow(torch.empty_like(a, dtype=torch.float32), 0)
720-
721-
@register_test_case(module_factory=lambda: EmptyLikeFloatModule())
722-
def EmptyLikeModule_float(module, tu: TestUtils):
723-
module.forward(tu.rand(4, 5))
724-
725-
# ==============================================================================
726-
727589
class ContiguousModule(torch.nn.Module):
728590
def __init__(self):
729591
super().__init__()
@@ -926,57 +788,6 @@ def DropoutModule_basic(module, tu: TestUtils):
926788
module.forward(tu.rand(3, 4))
927789

928790

929-
class Fill_TensorFloat64WithFloat32(torch.nn.Module):
930-
def __init__(self):
931-
super().__init__()
932-
933-
@export
934-
@annotate_args([
935-
None,
936-
([-1, -1, -1], torch.float32, True),
937-
])
938-
def forward(self, tensor):
939-
return torch.ops.aten.fill_(tensor, 3.0)
940-
941-
@register_test_case(module_factory=lambda: Fill_TensorFloat64WithFloat32())
942-
def Fill_TensorFloat64WithFloat32_basic(module, tu: TestUtils):
943-
module.forward(torch.randn(3, 2, 4))
944-
945-
946-
class Fill_TensorFloat64WithFloat64(torch.nn.Module):
947-
def __init__(self):
948-
super().__init__()
949-
950-
@export
951-
@annotate_args([
952-
None,
953-
([-1, -1, -1], torch.float64, True),
954-
])
955-
def forward(self, tensor):
956-
return torch.ops.aten.fill_(tensor, 3.0)
957-
958-
@register_test_case(module_factory=lambda: Fill_TensorFloat64WithFloat64())
959-
def Fill_TensorFloat64WithFloat64_basic(module, tu: TestUtils):
960-
module.forward(torch.randn(3, 2, 4).to(torch.float64))
961-
962-
963-
class Fill_TensorFloat64WithInt64(torch.nn.Module):
964-
def __init__(self):
965-
super().__init__()
966-
967-
@export
968-
@annotate_args([
969-
None,
970-
([-1, -1, -1], torch.float64, True),
971-
])
972-
def forward(self, tensor):
973-
return torch.ops.aten.fill_(tensor, 3)
974-
975-
@register_test_case(module_factory=lambda: Fill_TensorFloat64WithInt64())
976-
def Fill_TensorFloat64WithInt64_basic(module, tu: TestUtils):
977-
module.forward(torch.randn(3, 2, 4).to(torch.float64))
978-
979-
980791
class MeanModule(torch.nn.Module):
981792
def __init__(self):
982793
super().__init__()
@@ -1047,86 +858,6 @@ def NumelZeroRankModule_basic(module, tu: TestUtils):
1047858
module.forward(torch.randint(10,[]))
1048859

1049860

1050-
class ZerosModuleInt2D(torch.nn.Module):
1051-
def __init__(self):
1052-
super().__init__()
1053-
1054-
@export
1055-
@annotate_args([
1056-
None,
1057-
])
1058-
def forward(self):
1059-
return torch.zeros(3, 4, dtype=torch.int64)
1060-
1061-
@register_test_case(module_factory=lambda: ZerosModuleInt2D())
1062-
def ZerosModuleInt2D_basic(module, tu: TestUtils):
1063-
module.forward()
1064-
1065-
1066-
class ZerosModuleInt3D(torch.nn.Module):
1067-
def __init__(self):
1068-
super().__init__()
1069-
1070-
@export
1071-
@annotate_args([
1072-
None,
1073-
])
1074-
def forward(self):
1075-
return torch.zeros(3, 4, 5, dtype=torch.int64)
1076-
1077-
@register_test_case(module_factory=lambda: ZerosModuleInt3D())
1078-
def ZerosModuleInt3D_basic(module, tu: TestUtils):
1079-
module.forward()
1080-
1081-
1082-
class ZerosModuleFloat2D(torch.nn.Module):
1083-
def __init__(self):
1084-
super().__init__()
1085-
1086-
@export
1087-
@annotate_args([
1088-
None,
1089-
])
1090-
def forward(self):
1091-
return torch.zeros(3, 4, dtype=torch.float32)
1092-
1093-
@register_test_case(module_factory=lambda: ZerosModuleFloat2D())
1094-
def ZerosModuleFloat2D_basic(module, tu: TestUtils):
1095-
module.forward()
1096-
1097-
1098-
class ZerosModuleFloat3D(torch.nn.Module):
1099-
def __init__(self):
1100-
super().__init__()
1101-
1102-
@export
1103-
@annotate_args([
1104-
None,
1105-
])
1106-
def forward(self):
1107-
return torch.zeros(3, 4, 5, dtype=torch.float32)
1108-
1109-
@register_test_case(module_factory=lambda: ZerosModuleFloat3D())
1110-
def ZerosModuleFloat3D_basic(module, tu: TestUtils):
1111-
module.forward()
1112-
1113-
1114-
class ZerosModuleFalsePinMemory(torch.nn.Module):
1115-
def __init__(self):
1116-
super().__init__()
1117-
1118-
@export
1119-
@annotate_args([
1120-
None,
1121-
])
1122-
def forward(self):
1123-
return torch.zeros(3, 4, dtype=torch.float32, pin_memory=False)
1124-
1125-
@register_test_case(module_factory=lambda: ZerosModuleFalsePinMemory())
1126-
def ZerosModuleFalsePinMemory_basic(module, tu: TestUtils):
1127-
module.forward()
1128-
1129-
1130861
class BoolTensorReturnFalseModule(torch.nn.Module):
1131862
def __init__(self):
1132863
super().__init__()
@@ -1181,6 +912,7 @@ def BoolTensorReturnMixedModule_basic(module, tu: TestUtils):
1181912
module.forward(torch.tensor([[1, 0], [0,1]], dtype=torch.bool))
1182913

1183914
# ==============================================================================
915+
1184916
class TModuleRank2(torch.nn.Module):
1185917
def __init__(self):
1186918
super().__init__()
@@ -1193,11 +925,11 @@ def __init__(self):
1193925
def forward(self, lhs):
1194926
return torch.t(lhs)
1195927

1196-
1197928
@register_test_case(module_factory=lambda: TModuleRank2())
1198929
def TModuleRank2_basic(module, tu: TestUtils):
1199930
module.forward(tu.rand(3, 4))
1200931

932+
1201933
class TModuleRank1(torch.nn.Module):
1202934
def __init__(self):
1203935
super().__init__()
@@ -1210,11 +942,11 @@ def __init__(self):
1210942
def forward(self, lhs):
1211943
return torch.t(lhs)
1212944

1213-
1214945
@register_test_case(module_factory=lambda: TModuleRank1())
1215946
def TModuleRank1_basic(module, tu: TestUtils):
1216947
module.forward(tu.rand(3))
1217948

949+
1218950
class TModuleRank0(torch.nn.Module):
1219951
def __init__(self):
1220952
super().__init__()
@@ -1227,7 +959,6 @@ def __init__(self):
1227959
def forward(self, lhs):
1228960
return torch.t(lhs)
1229961

1230-
1231962
@register_test_case(module_factory=lambda: TModuleRank0())
1232963
def TModuleRank0_basic(module, tu: TestUtils):
1233964
module.forward(torch.tensor(7, dtype=torch.float32))

0 commit comments

Comments
 (0)