Skip to content

Commit 66c5269

Browse files
committed
Fix a few test files
Signed-off-by: zjgarvey <[email protected]>
1 parent df8814c commit 66c5269

File tree

4 files changed

+137
-5060
lines changed

4 files changed

+137
-5060
lines changed

projects/e2e/torch_mlir_e2e_test/test_suite/conv.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1088,6 +1088,94 @@ def UpSampleNearest2dStaticFactor_basic(module, tu: TestUtils):
10881088
module.forward(tu.rand(2, 3, 4, 4))
10891089

10901090

1091+
class UpSampleNearest2dVecNoneShape(torch.nn.Module):
1092+
def __init__(self):
1093+
super().__init__()
1094+
1095+
@export
1096+
@annotate_args(
1097+
[
1098+
None,
1099+
([-1, -1, -1, -1], torch.float64, True),
1100+
]
1101+
)
1102+
def forward(self, input):
1103+
return torch.ops.aten.upsample_nearest2d.vec(
1104+
input, output_size=None, scale_factors=[3.66, 4.2]
1105+
)
1106+
1107+
1108+
@register_test_case(module_factory=lambda: UpSampleNearest2dVecNoneShape())
1109+
def UpSampleNearest2dVecNoneShape_basic(module, tu: TestUtils):
1110+
module.forward(tu.rand(1, 1, 6, 12).to(torch.float64))
1111+
1112+
1113+
class UpSampleNearest2dVecNoneScales(torch.nn.Module):
1114+
def __init__(self):
1115+
super().__init__()
1116+
1117+
@export
1118+
@annotate_args(
1119+
[
1120+
None,
1121+
([-1, -1, -1, -1], torch.float64, True),
1122+
]
1123+
)
1124+
def forward(self, input):
1125+
return torch.ops.aten.upsample_nearest2d.vec(
1126+
input,
1127+
output_size=[18, 48],
1128+
scale_factors=None,
1129+
)
1130+
1131+
1132+
@register_test_case(module_factory=lambda: UpSampleNearest2dVecNoneScales())
1133+
def UpSampleNearest2dVecNoneScales_basic(module, tu: TestUtils):
1134+
module.forward(tu.rand(1, 1, 6, 12).to(torch.float64))
1135+
1136+
1137+
class UpSampleNearest1dVecNoneShape(torch.nn.Module):
1138+
def __init__(self):
1139+
super().__init__()
1140+
1141+
@export
1142+
@annotate_args(
1143+
[
1144+
None,
1145+
([-1, -1, -1], torch.float64, True),
1146+
]
1147+
)
1148+
def forward(self, input):
1149+
return torch.ops.aten.upsample_nearest1d.vec(
1150+
input, output_size=None, scale_factors=[3.0]
1151+
)
1152+
1153+
1154+
@register_test_case(module_factory=lambda: UpSampleNearest1dVecNoneShape())
1155+
def UpSampleNearest1dVecNoneShape_basic(module, tu: TestUtils):
1156+
module.forward(tu.rand(1, 1, 6).to(torch.float64))
1157+
1158+
1159+
class UpSampleNearest1dVecNoneScales(torch.nn.Module):
1160+
def __init__(self):
1161+
super().__init__()
1162+
1163+
@export
1164+
@annotate_args(
1165+
[
1166+
None,
1167+
([-1, -1, -1], torch.float64, True),
1168+
]
1169+
)
1170+
def forward(self, input):
1171+
return torch.ops.aten.upsample_nearest1d.vec(input, [18], None)
1172+
1173+
1174+
@register_test_case(module_factory=lambda: UpSampleNearest1dVecNoneScales())
1175+
def UpSampleNearest1dVecNoneScales_basic(module, tu: TestUtils):
1176+
module.forward(tu.rand(1, 1, 6).to(torch.float64))
1177+
1178+
10911179
class Conv1dModule(torch.nn.Module):
10921180
def __init__(self):
10931181
super().__init__()

projects/e2e/torch_mlir_e2e_test/test_suite/pooling.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,55 @@ def AdaptiveAvgPool2dUnitOutputSizeDynamicModule_basic(module, tu: TestUtils):
180180
# ==============================================================================
181181

182182

183+
class MaxPool1dWithIndicesModule(torch.nn.Module):
184+
185+
def __init__(self):
186+
super().__init__()
187+
188+
@export
189+
@annotate_args(
190+
[
191+
None,
192+
([-1, -1, -1], torch.float32, True),
193+
]
194+
)
195+
def forward(self, x):
196+
return torch.ops.aten.max_pool1d_with_indices(
197+
x, kernel_size=[6], stride=[2], padding=[3], dilation=2, ceil_mode=False
198+
)
199+
200+
201+
@register_test_case(module_factory=lambda: MaxPool1dWithIndicesModule())
202+
def MaxPool1dWithIndicesModule_basic(module, tu: TestUtils):
203+
module.forward(tu.rand(1, 64, 112, low=-1))
204+
205+
206+
class MaxPool1dWithIndicesCeilModeModule(torch.nn.Module):
207+
208+
def __init__(self):
209+
super().__init__()
210+
211+
@export
212+
@annotate_args(
213+
[
214+
None,
215+
([-1, -1, -1], torch.float32, True),
216+
]
217+
)
218+
def forward(self, x):
219+
return torch.ops.aten.max_pool1d_with_indices(
220+
x, kernel_size=[4], stride=[2], padding=[2], dilation=2, ceil_mode=True
221+
)
222+
223+
224+
@register_test_case(module_factory=lambda: MaxPool1dWithIndicesCeilModeModule())
225+
def MaxPool1dWithIndicesCeilModeModule_basic(module, tu: TestUtils):
226+
module.forward(tu.rand(3, 25, 37, low=-1))
227+
228+
229+
# ==============================================================================
230+
231+
183232
class MaxPool1dModule(torch.nn.Module):
184233

185234
def __init__(self):

0 commit comments

Comments
 (0)