Skip to content

Commit b26d768

Browse files
committed
tests: fix test model paths
Signed-off-by: Naren Dasan <[email protected]>
1 parent 7393fa8 commit b26d768

File tree

1 file changed

+7
-6
lines changed

1 file changed

+7
-6
lines changed

tests/py/api/test_collections.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import torch_tensorrt as torchtrt
33
import torch
44
import torchvision.models as models
5+
import os
56

67
def find_repo_root(max_depth=10):
78
dir_path = os.path.dirname(os.path.realpath(__file__))
@@ -22,7 +23,7 @@ class TestStandardTensorInput(unittest.TestCase):
2223
def test_compile(self):
2324

2425
self.input = torch.randn((1, 3, 224, 224)).to("cuda")
25-
self.model = torch.jit.load(MODULE_DIR + "/standard_tensor_input.jit.pt").eval().to("cuda")
26+
self.model = torch.jit.load(MODULE_DIR + "/standard_tensor_input_scripted.jit.pt").eval().to("cuda")
2627

2728
compile_spec = {
2829
"inputs": [torchtrt.Input(self.input.shape),
@@ -41,7 +42,7 @@ class TestTupleInput(unittest.TestCase):
4142
def test_compile(self):
4243

4344
self.input = torch.randn((1, 3, 224, 224)).to("cuda")
44-
self.model = torch.jit.load(MODULE_DIR + "/tuple_input.jit.pt").eval().to("cuda")
45+
self.model = torch.jit.load(MODULE_DIR + "/tuple_input_scripted.jit.pt").eval().to("cuda")
4546

4647
compile_spec = {
4748
"input_signature": ((torchtrt.Input(self.input.shape), torchtrt.Input(self.input.shape)),),
@@ -61,7 +62,7 @@ class TestListInput(unittest.TestCase):
6162
def test_compile(self):
6263

6364
self.input = torch.randn((1, 3, 224, 224)).to("cuda")
64-
self.model = torch.jit.load(MODULE_DIR + "/list_input.jit.pt").eval().to("cuda")
65+
self.model = torch.jit.load(MODULE_DIR + "/list_input_scripted.jit.pt").eval().to("cuda")
6566

6667

6768
compile_spec = {
@@ -81,7 +82,7 @@ class TestTupleInputOutput(unittest.TestCase):
8182
def test_compile(self):
8283

8384
self.input = torch.randn((1, 3, 224, 224)).to("cuda")
84-
self.model = torch.jit.load(MODULE_DIR + "/tuple_input_output.jit.pt").eval().to("cuda")
85+
self.model = torch.jit.load(MODULE_DIR + "/tuple_input_output_scripted.jit.pt").eval().to("cuda")
8586

8687

8788
compile_spec = {
@@ -103,7 +104,7 @@ class TestListInputOutput(unittest.TestCase):
103104
def test_compile(self):
104105

105106
self.input = torch.randn((1, 3, 224, 224)).to("cuda")
106-
self.model = torch.jit.load(MODULE_DIR + "/list_input_output.jit.pt").eval().to("cuda")
107+
self.model = torch.jit.load(MODULE_DIR + "/list_input_output_scripted.jit.pt").eval().to("cuda")
107108

108109

109110
compile_spec = {
@@ -126,7 +127,7 @@ class TestListInputTupleOutput(unittest.TestCase):
126127
def test_compile(self):
127128

128129
self.input = torch.randn((1, 3, 224, 224)).to("cuda")
129-
self.model = torch.jit.load(MODULE_DIR + "/list_input_tuple_output.jit.pt").eval().to("cuda")
130+
self.model = torch.jit.load(MODULE_DIR + "/list_input_tuple_output_scripted.jit.pt").eval().to("cuda")
130131

131132

132133
compile_spec = {

0 commit comments

Comments
 (0)