@@ -89,6 +89,70 @@ def test_mapping():
89
89
torch ._dynamo .reset ()
90
90
91
91
92
+ @unittest .skipIf (
93
+ not torch_trt .ENABLED_FEATURES .torch_tensorrt_runtime ,
94
+ "TorchScript Frontend is not available" ,
95
+ )
96
+ @unittest .skipIf (
97
+ not torch_trt .ENABLED_FEATURES .refit ,
98
+ "Refit feature is not supported in Python 3.13 or higher" ,
99
+ )
100
+ @unittest .skipIf (
101
+ not importlib .util .find_spec ("torchvision" ),
102
+ "torchvision is not installed" ,
103
+ )
104
+ @pytest .mark .unit
105
+ def test_conv_refit_with_weightmap ():
106
+ class net (nn .Module ):
107
+ def __init__ (self ):
108
+ super ().__init__ ()
109
+ self .conv = nn .Conv2d (3 , 3 , 1 )
110
+
111
+ def forward (self , x ):
112
+ return self .conv (x )
113
+
114
+ model = net ().eval ().to ("cuda" )
115
+ model2 = net ().eval ().to ("cuda" )
116
+ inputs = [torch .randn ((1 , 3 , 224 , 224 )).to ("cuda" )]
117
+ enabled_precisions = {torch .float }
118
+ min_block_size = 1
119
+ use_python_runtime = True
120
+
121
+ exp_program = torch .export .export (model , tuple (inputs ))
122
+ exp_program2 = torch .export .export (model2 , tuple (inputs ))
123
+
124
+ trt_gm = torchtrt .dynamo .compile (
125
+ exp_program ,
126
+ tuple (inputs ),
127
+ use_python_runtime = use_python_runtime ,
128
+ enabled_precisions = enabled_precisions ,
129
+ min_block_size = min_block_size ,
130
+ immutable_weights = False ,
131
+ )
132
+
133
+ new_trt_gm = refit_module_weights (
134
+ compiled_module = trt_gm ,
135
+ new_weight_module = exp_program2 ,
136
+ arg_inputs = inputs ,
137
+ use_weight_map_cache = True ,
138
+ verify_output = True ,
139
+ )
140
+
141
+ # Check the output
142
+ model2 .to ("cuda" )
143
+ expected_outputs , refitted_outputs = exp_program2 .module ()(* inputs ), new_trt_gm (
144
+ * inputs
145
+ )
146
+ for expected_output , refitted_output in zip (expected_outputs , refitted_outputs ):
147
+ assertions .assertTrue (
148
+ torch .allclose (expected_output , refitted_output , 1e-2 , 1e-2 ),
149
+ "Refit Result is not correct. Refit failed" ,
150
+ )
151
+ # Clean up model env
152
+
153
+ torch ._dynamo .reset ()
154
+
155
+
92
156
@unittest .skipIf (
93
157
not torch_trt .ENABLED_FEATURES .torch_tensorrt_runtime ,
94
158
"TorchScript Frontend is not available" ,
0 commit comments