
Commit 4a48032

Re-enable Thor tests

1 parent ac478ea commit 4a48032

File tree

11 files changed: +103 additions, -53 deletions

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 2 additions & 10 deletions

@@ -10,7 +10,7 @@
 from torch.fx.node import Argument, Node, Target
 from torch_tensorrt import ENABLED_FEATURES
 from torch_tensorrt._features import needs_not_tensorrt_rtx
-from torch_tensorrt._utils import is_tensorrt_version_supported, is_thor
+from torch_tensorrt._utils import is_tensorrt_version_supported
 from torch_tensorrt.dynamo._settings import CompilationSettings
 from torch_tensorrt.dynamo._SourceIR import SourceIR
 from torch_tensorrt.dynamo.conversion import impl
@@ -429,7 +429,7 @@ def index_nonbool_validator(
     node: Node, settings: Optional[CompilationSettings] = None
 ) -> bool:
     # for thor and tensorrt_rtx, we don't support boolean indices, due to nonzero op not supported
-    if is_thor() or ENABLED_FEATURES.tensorrt_rtx:
+    if ENABLED_FEATURES.tensorrt_rtx:
         index = node.args[1]
         for ind in index:
             if ind is not None:
@@ -3621,18 +3621,10 @@ def aten_ops_full(
     )


-def nonzero_validator(
-    node: Node, settings: Optional[CompilationSettings] = None
-) -> bool:
-    return not is_thor()
-
-
 # currently nonzero is not supported for tensorrt_rtx
 # TODO: lan to add the nonzero support once tensorrt_rtx team has added the support
-# TODO: apbose to remove the capability validator once thor bug resolve in NGC
 @dynamo_tensorrt_converter(
     torch.ops.aten.nonzero.default,
-    capability_validator=nonzero_validator,
     supports_dynamic_shapes=True,
     requires_output_allocator=True,
 )
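Note: the deleted nonzero_validator followed the converter capability-validator pattern: a per-node predicate evaluated at partitioning time, where returning False leaves the node in PyTorch instead of converting it to TensorRT. The converter keeps requires_output_allocator=True, consistent with nonzero producing a data-dependent output shape. A minimal sketch of the validator pattern, with a hypothetical platform predicate standing in for is_thor():

from typing import Optional

from torch.fx.node import Node


def platform_supports_nonzero() -> bool:
    # Hypothetical stand-in for a platform check such as is_thor().
    return True


def nonzero_validator(node: Node, settings: Optional[object] = None) -> bool:
    # Returning False makes the partitioner fall back to PyTorch for this node.
    return platform_supports_nonzero()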

tests/py/dynamo/conversion/test_arange_aten.py

Lines changed: 3 additions & 3 deletions

@@ -5,14 +5,14 @@
 import torch_tensorrt
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
-from torch_tensorrt._utils import is_tegra_platform, is_thor
+from torch_tensorrt._utils import is_tegra_platform

 from .harness import DispatchTestCase


 @unittest.skipIf(
-    is_thor() or is_tegra_platform(),
-    "Skipped on Thor and Tegra platforms",
+    is_tegra_platform(),
+    "Skipped on Tegra platforms",
 )
 class TestArangeConverter(DispatchTestCase):
     @parameterized.expand(
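The same gating change repeats in the conversion tests below (cumsum, index, nonzero, sym_size): is_thor() is dropped from the skip condition so the tests run on Thor again, leaving only the Tegra and tensorrt_rtx gates where applicable. A minimal, self-contained sketch of the pattern, with a stand-in predicate for the real torch_tensorrt._utils helper:

import unittest


def is_tegra_platform() -> bool:
    # Stand-in for torch_tensorrt._utils.is_tegra_platform().
    return False


# The condition is evaluated once, at class-decoration time.
@unittest.skipIf(
    is_tegra_platform(),
    "Skipped on Tegra platforms",
)
class TestExample(unittest.TestCase):
    def test_runs_off_tegra(self):
        self.assertTrue(True)


if __name__ == "__main__":
    unittest.main()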

tests/py/dynamo/conversion/test_cumsum_aten.py

Lines changed: 3 additions & 3 deletions

@@ -5,14 +5,14 @@
 import torch_tensorrt
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
-from torch_tensorrt._utils import is_tegra_platform, is_thor
+from torch_tensorrt._utils import is_tegra_platform

 from .harness import DispatchTestCase


 @unittest.skipIf(
-    is_thor() or is_tegra_platform(),
-    "Skipped on Thor and Tegra platforms",
+    is_tegra_platform(),
+    "Skipped on Tegra platforms",
 )
 class TestCumsumConverter(DispatchTestCase):
     @parameterized.expand(

tests/py/dynamo/conversion/test_index_aten.py

Lines changed: 9 additions & 13 deletions

@@ -6,7 +6,7 @@
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
 from torch_tensorrt import ENABLED_FEATURES, Input
-from torch_tensorrt._utils import is_tegra_platform, is_thor
+from torch_tensorrt._utils import is_tegra_platform

 from .harness import DispatchTestCase

@@ -114,8 +114,8 @@ def forward(self, input):
         ]
     )
     @unittest.skipIf(
-        is_thor() or ENABLED_FEATURES.tensorrt_rtx,
-        "Skipped on Thor or tensorrt_rtx due to nonzero not supported",
+        ENABLED_FEATURES.tensorrt_rtx,
+        "Skipped on tensorrt_rtx due to nonzero not supported",
     )
     def test_index_constant_bool_mask(self, _, index, input):
         class TestModule(torch.nn.Module):
@@ -149,8 +149,8 @@ def forward(self, x, index0):
     )

     @unittest.skipIf(
-        is_thor() or ENABLED_FEATURES.tensorrt_rtx,
-        "Skipped on Thor or tensorrt_rtx due to nonzero not supported",
+        ENABLED_FEATURES.tensorrt_rtx,
+        "Skipped on tensorrt_rtx due to nonzero not supported",
     )
     def test_index_zero_two_dim_ITensor_mask(self):
         class TestModule(nn.Module):
@@ -163,10 +163,6 @@ def forward(self, x, index0):
         index0 = torch.tensor([True, False])
         self.run_test(TestModule(), [input, index0], enable_passes=True)

-    @unittest.skipIf(
-        is_thor(),
-        "Skipped on Thor due to nonzero not supported",
-    )
     def test_index_zero_index_three_dim_ITensor(self):
         class TestModule(nn.Module):
             def forward(self, x, index0):
@@ -180,8 +176,8 @@ def forward(self, x, index0):
         self.run_test(TestModule(), [input, index0])

     @unittest.skipIf(
-        is_thor() or ENABLED_FEATURES.tensorrt_rtx,
-        "Skipped on Thor or tensorrt_rtx due to nonzero not supported",
+        ENABLED_FEATURES.tensorrt_rtx,
+        "Skipped on tensorrt_rtx due to nonzero not supported",
     )
     def test_index_zero_index_three_dim_mask_ITensor(self):
         class TestModule(nn.Module):
@@ -252,8 +248,8 @@ def forward(self, input):


 @unittest.skipIf(
-    torch_tensorrt.ENABLED_FEATURES.tensorrt_rtx or is_thor() or is_tegra_platform(),
-    "nonzero is not supported for tensorrt_rtx",
+    torch_tensorrt.ENABLED_FEATURES.tensorrt_rtx or is_tegra_platform(),
+    "nonzero is not supported for tensorrt_rtx or Tegra platforms",
 )
 class TestIndexDynamicInputNonDynamicIndexConverter(DispatchTestCase):
     def test_index_input_non_dynamic_index_dynamic(self):
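The skip reasons above tie boolean-mask indexing to nonzero support: when aten.index receives a boolean mask, lowering turns the mask into integer indices via a nonzero op, so a backend without nonzero cannot execute it. A small plain-PyTorch illustration of that equivalence:

import torch

x = torch.tensor([[1, 2], [3, 4], [5, 6]])
mask = torch.tensor([True, False, True])

# Boolean-mask indexing is equivalent to gathering at the mask's
# nonzero positions, which is why these tests need the nonzero op.
indices = mask.nonzero().squeeze(1)
assert torch.equal(x[mask], x[indices])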

tests/py/dynamo/conversion/test_nonzero_aten.py

Lines changed: 3 additions & 3 deletions

@@ -6,14 +6,14 @@
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
 from torch_tensorrt import Input
-from torch_tensorrt._utils import is_tegra_platform, is_thor
+from torch_tensorrt._utils import is_tegra_platform

 from .harness import DispatchTestCase


 @unittest.skipIf(
-    torch_tensorrt.ENABLED_FEATURES.tensorrt_rtx or is_thor() or is_tegra_platform(),
-    "nonzero is not supported for tensorrt_rtx",
+    torch_tensorrt.ENABLED_FEATURES.tensorrt_rtx or is_tegra_platform(),
+    "nonzero is not supported for tensorrt_rtx or Tegra platforms",
 )
 class TestNonZeroConverter(DispatchTestCase):
     @parameterized.expand(

tests/py/dynamo/conversion/test_sym_size.py

Lines changed: 0 additions & 5 deletions

@@ -4,15 +4,10 @@
 import torch.nn as nn
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
-from torch_tensorrt._utils import is_thor

 from .harness import DispatchTestCase


-@unittest.skipIf(
-    is_thor(),
-    "Skipped on Thor",
-)
 class TestSymSizeConverter(DispatchTestCase):
     @parameterized.expand(
         [

tests/py/dynamo/models/test_export_kwargs_serde.py

Lines changed: 14 additions & 5 deletions

@@ -75,7 +75,8 @@ def forward(self, x, b=5, c=None, d=None):
     )

     # Save the module
-    trt_ep_path = os.path.join(tempfile.gettempdir(), "compiled.ep")
+    tmp_dir = tempfile.mkdtemp(prefix="test_custom_model")
+    trt_ep_path = os.path.join(tmp_dir, "compiled.ep")
     torchtrt.save(trt_gm, trt_ep_path, retrace=False)
     # Clean up model env
     torch._dynamo.reset()
@@ -137,7 +138,8 @@ def forward(self, x, b=5, c=None, d=None):
     )

     # Save the module
-    trt_ep_path = os.path.join(tempfile.gettempdir(), "compiled.ep")
+    tmp_dir = tempfile.mkdtemp(prefix="test_custom_model_with_dynamo_trace")
+    trt_ep_path = os.path.join(tmp_dir, "compiled.ep")
     torchtrt.save(trt_gm, trt_ep_path, retrace=False)
     # Clean up model env
     torch._dynamo.reset()
@@ -208,7 +210,8 @@ def forward(self, x, b=5, c=None, d=None):
     )

     # Save the module
-    trt_ep_path = os.path.join(tempfile.gettempdir(), "compiled.ep")
+    tmp_dir = tempfile.mkdtemp(prefix="test_custom_model_with_dynamo_trace_dynamic")
+    trt_ep_path = os.path.join(tmp_dir, "compiled.ep")
     torchtrt.save(trt_gm, trt_ep_path, retrace=False)
     # Clean up model env
     torch._dynamo.reset()
@@ -298,7 +301,10 @@ def forward(self, x, b=None, c=None, d=None, e=[]):
         msg=f"CustomKwargs Module TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
     )
     # Save the module
-    trt_ep_path = os.path.join(tempfile.gettempdir(), "compiled.ep")
+    tmp_dir = tempfile.mkdtemp(
+        prefix="test_custom_model_with_dynamo_trace_kwarg_dynamic"
+    )
+    trt_ep_path = os.path.join(tmp_dir, "compiled.ep")
     torchtrt.save(trt_gm, trt_ep_path, retrace=False)
     # Clean up model env
     torch._dynamo.reset()
@@ -388,7 +394,10 @@ def forward(self, x, b=None, c=None, d=None, e=[]):
         msg=f"CustomKwargs Module TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
     )
     # Save the module
-    trt_ep_path = os.path.join(tempfile.gettempdir(), "compiled.ep")
+    tmp_dir = tempfile.mkdtemp(
+        prefix="test_custom_model_with_dynamo_trace_kwarg_dynamic"
+    )
+    trt_ep_path = os.path.join(tmp_dir, "compiled.ep")
     torchtrt.save(trt_gm, trt_ep_path, retrace=False)
     # Clean up model env
     torch._dynamo.reset()
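The switch from tempfile.gettempdir() to tempfile.mkdtemp() here (and in the two files below) looks like a test-isolation fix: every test previously wrote compiled.ep or trt.ep to the one shared temp directory, so concurrent or back-to-back runs could overwrite each other's serialized engines. mkdtemp creates a fresh, uniquely named directory per call; a minimal sketch:

import os
import tempfile

# gettempdir() returns one shared directory; mkdtemp() creates a new,
# uniquely suffixed directory (mode 0o700) and returns its path.
tmp_dir = tempfile.mkdtemp(prefix="my_test_case")
trt_ep_path = os.path.join(tmp_dir, "compiled.ep")
print(trt_ep_path)  # e.g. /tmp/my_test_caseab12cd34/compiled.ep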

tests/py/dynamo/models/test_export_serde.py

Lines changed: 30 additions & 2 deletions

@@ -17,8 +17,6 @@
 if importlib.util.find_spec("torchvision"):
     import torchvision.models as models

-trt_ep_path = os.path.join(tempfile.gettempdir(), "trt.ep")
-

 @pytest.mark.unit
 @pytest.mark.critical
@@ -27,6 +25,8 @@ def test_base_full_compile(ir):
     This tests export serde functionality on a base model
     which is fully TRT convertible
     """
+    tmp_dir = tempfile.mkdtemp(prefix="test_base_full_compile")
+    trt_ep_path = os.path.join(tmp_dir, "trt.ep")

     class MyModule(torch.nn.Module):
         def __init__(self):
@@ -82,6 +82,9 @@ def test_base_full_compile_multiple_outputs(ir):
     with multiple outputs which is fully TRT convertible
     """

+    tmp_dir = tempfile.mkdtemp(prefix="test_base_full_compile_multiple_outputs")
+    trt_ep_path = os.path.join(tmp_dir, "trt.ep")
+
     class MyModule(torch.nn.Module):
         def __init__(self):
             super().__init__()
@@ -141,6 +144,8 @@ def test_no_compile(ir):
     This tests export serde functionality on a model
     which won't convert to TRT because of min_block_size=5 constraint
     """
+    tmp_dir = tempfile.mkdtemp(prefix="test_no_compile")
+    trt_ep_path = os.path.join(tmp_dir, "trt.ep")

     class MyModule(torch.nn.Module):
         def __init__(self):
@@ -202,6 +207,9 @@ def test_hybrid_relu_fallback(ir):
     fallback
     """

+    tmp_dir = tempfile.mkdtemp(prefix="test_hybrid_relu_fallback")
+    trt_ep_path = os.path.join(tmp_dir, "trt.ep")
+
     class MyModule(torch.nn.Module):
         def __init__(self):
             super().__init__()
@@ -262,6 +270,9 @@ def test_resnet18(ir):
     """
     This tests export save and load functionality on Resnet18 model
     """
+    tmp_dir = tempfile.mkdtemp(prefix="test_resnet18")
+    trt_ep_path = os.path.join(tmp_dir, "trt.ep")
+
     model = models.resnet18().eval().cuda()
     input = torch.randn((1, 3, 224, 224)).to("cuda")

@@ -307,6 +318,9 @@ def test_resnet18_cpu_offload(ir):
     """
     This tests export save and load functionality on Resnet18 model
     """
+    tmp_dir = tempfile.mkdtemp(prefix="test_resnet18_cpu_offload")
+    trt_ep_path = os.path.join(tmp_dir, "trt.ep")
+
     model = models.resnet18().eval().cuda()
     input = torch.randn((1, 3, 224, 224)).to("cuda")

@@ -359,6 +373,9 @@ def test_resnet18_dynamic(ir):
     """
     This tests export save and load functionality on Resnet18 model
     """
+    tmp_dir = tempfile.mkdtemp(prefix="test_resnet18_dynamic")
+    trt_ep_path = os.path.join(tmp_dir, "trt.ep")
+
     model = models.resnet18().eval().cuda()
     input = torch.randn((1, 3, 224, 224)).to("cuda")

@@ -399,6 +416,9 @@ def test_resnet18_torch_exec_ops_serde(ir):
     """
     This tests export save and load functionality on Resnet18 model
     """
+    tmp_dir = tempfile.mkdtemp(prefix="test_resnet18_torch_exec_ops_serde")
+    trt_ep_path = os.path.join(tmp_dir, "trt.ep")
+
     model = models.resnet18().eval().cuda()
     input = torch.randn((1, 3, 224, 224)).to("cuda")

@@ -432,6 +452,9 @@ def test_hybrid_conv_fallback(ir):
     model where a conv (a weighted layer) has been forced to fallback to Pytorch.
     """

+    tmp_dir = tempfile.mkdtemp(prefix="test_hybrid_conv_fallback")
+    trt_ep_path = os.path.join(tmp_dir, "trt.ep")
+
     class MyModule(torch.nn.Module):
         def __init__(self):
             super().__init__()
@@ -493,6 +516,9 @@ def test_hybrid_conv_fallback_cpu_offload(ir):
     model where a conv (a weighted layer) has been forced to fallback to Pytorch.
     """

+    tmp_dir = tempfile.mkdtemp(prefix="test_hybrid_conv_fallback_cpu_offload")
+    trt_ep_path = os.path.join(tmp_dir, "trt.ep")
+
     class MyModule(torch.nn.Module):
         def __init__(self):
             super().__init__()
@@ -555,6 +581,8 @@ def test_arange_export(ir):
     Here the arange output is a static constant (which is registered as input to the graph)
     in the exporter.
     """
+    tmp_dir = tempfile.mkdtemp(prefix="test_arange_export")
+    trt_ep_path = os.path.join(tmp_dir, "trt.ep")

     class MyModule(torch.nn.Module):
         def __init__(self):

tests/py/dynamo/models/test_model_refit.py

Lines changed: 8 additions & 2 deletions

@@ -532,7 +532,10 @@ def test_refit_one_engine_bert_with_weightmap():
 )
 @pytest.mark.unit
 def test_refit_one_engine_inline_runtime_with_weightmap():
-    trt_ep_path = os.path.join(tempfile.gettempdir(), "compiled.ep")
+    tmp_dir = tempfile.mkdtemp(
+        prefix="test_refit_one_engine_inline_runtime_with_weightmap"
+    )
+    trt_ep_path = os.path.join(tmp_dir, "compiled.ep")
     model = models.resnet18(pretrained=False).eval().to("cuda")
     model2 = models.resnet18(pretrained=True).eval().to("cuda")
     inputs = [torch.randn((1, 3, 224, 224)).to("cuda")]
@@ -889,7 +892,10 @@ def test_refit_one_engine_bert_without_weightmap():
 )
 @pytest.mark.unit
 def test_refit_one_engine_inline_runtime_without_weightmap():
-    trt_ep_path = os.path.join(tempfile.gettempdir(), "compiled.ep")
+    tmp_dir = tempfile.mkdtemp(
+        prefix="test_refit_one_engine_inline_runtime_without_weightmap"
+    )
+    trt_ep_path = os.path.join(tmp_dir, "compiled.ep")
     model = models.resnet18(pretrained=True).eval().to("cuda")
     model2 = models.resnet18(pretrained=False).eval().to("cuda")
     inputs = [torch.randn((1, 3, 224, 224)).to("cuda")]
