2
2
import torch_tensorrt as torchtrt
3
3
import torch
4
4
import torchvision .models as models
5
+ import os
5
6
6
7
def find_repo_root (max_depth = 10 ):
7
8
dir_path = os .path .dirname (os .path .realpath (__file__ ))
@@ -22,7 +23,7 @@ class TestStandardTensorInput(unittest.TestCase):
22
23
def test_compile (self ):
23
24
24
25
self .input = torch .randn ((1 , 3 , 224 , 224 )).to ("cuda" )
25
- self .model = torch .jit .load (MODULE_DIR + "/standard_tensor_input .jit.pt" ).eval ().to ("cuda" )
26
+ self .model = torch .jit .load (MODULE_DIR + "/standard_tensor_input_scripted .jit.pt" ).eval ().to ("cuda" )
26
27
27
28
compile_spec = {
28
29
"inputs" : [torchtrt .Input (self .input .shape ),
@@ -41,7 +42,7 @@ class TestTupleInput(unittest.TestCase):
41
42
def test_compile (self ):
42
43
43
44
self .input = torch .randn ((1 , 3 , 224 , 224 )).to ("cuda" )
44
- self .model = torch .jit .load (MODULE_DIR + "/tuple_input .jit.pt" ).eval ().to ("cuda" )
45
+ self .model = torch .jit .load (MODULE_DIR + "/tuple_input_scripted .jit.pt" ).eval ().to ("cuda" )
45
46
46
47
compile_spec = {
47
48
"input_signature" : ((torchtrt .Input (self .input .shape ), torchtrt .Input (self .input .shape )),),
@@ -61,7 +62,7 @@ class TestListInput(unittest.TestCase):
61
62
def test_compile (self ):
62
63
63
64
self .input = torch .randn ((1 , 3 , 224 , 224 )).to ("cuda" )
64
- self .model = torch .jit .load (MODULE_DIR + "/list_input .jit.pt" ).eval ().to ("cuda" )
65
+ self .model = torch .jit .load (MODULE_DIR + "/list_input_scripted .jit.pt" ).eval ().to ("cuda" )
65
66
66
67
67
68
compile_spec = {
@@ -81,7 +82,7 @@ class TestTupleInputOutput(unittest.TestCase):
81
82
def test_compile (self ):
82
83
83
84
self .input = torch .randn ((1 , 3 , 224 , 224 )).to ("cuda" )
84
- self .model = torch .jit .load (MODULE_DIR + "/tuple_input_output .jit.pt" ).eval ().to ("cuda" )
85
+ self .model = torch .jit .load (MODULE_DIR + "/tuple_input_output_scripted .jit.pt" ).eval ().to ("cuda" )
85
86
86
87
87
88
compile_spec = {
@@ -103,7 +104,7 @@ class TestListInputOutput(unittest.TestCase):
103
104
def test_compile (self ):
104
105
105
106
self .input = torch .randn ((1 , 3 , 224 , 224 )).to ("cuda" )
106
- self .model = torch .jit .load (MODULE_DIR + "/list_input_output .jit.pt" ).eval ().to ("cuda" )
107
+ self .model = torch .jit .load (MODULE_DIR + "/list_input_output_scripted .jit.pt" ).eval ().to ("cuda" )
107
108
108
109
109
110
compile_spec = {
@@ -126,7 +127,7 @@ class TestListInputTupleOutput(unittest.TestCase):
126
127
def test_compile (self ):
127
128
128
129
self .input = torch .randn ((1 , 3 , 224 , 224 )).to ("cuda" )
129
- self .model = torch .jit .load (MODULE_DIR + "/list_input_tuple_output .jit.pt" ).eval ().to ("cuda" )
130
+ self .model = torch .jit .load (MODULE_DIR + "/list_input_tuple_output_scripted .jit.pt" ).eval ().to ("cuda" )
130
131
131
132
132
133
compile_spec = {
0 commit comments