 @pytest.mark.unit
 def test_mapping():
 
-    model = models.resnet18(pretrained=True).eval().to("cuda")
-    model2 = models.resnet18(pretrained=False).eval().to("cuda")
+    model = models.resnet18(pretrained=False).eval().to("cuda")
+    model2 = models.resnet18(pretrained=True).eval().to("cuda")
     inputs = [torch.randn((1, 3, 224, 224)).to("cuda")]
     trt_input = [
         torchtrt.Input(i.shape, dtype=torch.float, format=torch.contiguous_format)
@@ -58,6 +58,7 @@ def test_mapping():
         debug=debug,
         min_block_size=min_block_size,
         make_refittable=True,
+        reuse_cached_engines=False,
     )
     settings = trt_gm._run_on_acc_0.settings
     runtime = trt.Runtime(TRT_LOGGER)
@@ -110,6 +111,7 @@ def test_refit_one_engine_with_weightmap():
         debug=debug,
         min_block_size=min_block_size,
         make_refittable=True,
+        reuse_cached_engines=False,
     )
 
     new_trt_gm = refit_module_weights(
@@ -141,8 +143,8 @@ def test_refit_one_engine_with_weightmap():
 @pytest.mark.unit
 def test_refit_one_engine_no_map_with_weightmap():
 
-    model = models.resnet18(pretrained=True).eval().to("cuda")
-    model2 = models.resnet18(pretrained=False).eval().to("cuda")
+    model = models.resnet18(pretrained=False).eval().to("cuda")
+    model2 = models.resnet18(pretrained=True).eval().to("cuda")
     inputs = [torch.randn((1, 3, 224, 224)).to("cuda")]
     enabled_precisions = {torch.float}
     debug = False
@@ -160,6 +162,7 @@ def test_refit_one_engine_no_map_with_weightmap():
         debug=debug,
         min_block_size=min_block_size,
         make_refittable=True,
+        reuse_cached_engines=False,
     )
 
     trt_gm._run_on_acc_0.weight_name_map = None
@@ -192,8 +195,8 @@ def test_refit_one_engine_no_map_with_weightmap():
 @pytest.mark.unit
 def test_refit_one_engine_with_wrong_weightmap():
 
-    model = models.resnet18(pretrained=True).eval().to("cuda")
-    model2 = models.resnet18(pretrained=False).eval().to("cuda")
+    model = models.resnet18(pretrained=False).eval().to("cuda")
+    model2 = models.resnet18(pretrained=True).eval().to("cuda")
     inputs = [torch.randn((1, 3, 224, 224)).to("cuda")]
     enabled_precisions = {torch.float}
     debug = False
@@ -211,6 +214,7 @@ def test_refit_one_engine_with_wrong_weightmap():
         debug=debug,
         min_block_size=min_block_size,
         make_refittable=True,
+        reuse_cached_engines=False,
     )
     # Manually delete all batch norm layers. This is expected to make the fast refit fail.
     trt_gm._run_on_acc_0.weight_name_map = {
@@ -268,6 +272,7 @@ def test_refit_one_engine_bert_with_weightmap():
         debug=debug,
         min_block_size=min_block_size,
         make_refittable=True,
+        reuse_cached_engines=False,
     )
 
     new_trt_gm = refit_module_weights(
@@ -302,8 +307,8 @@ def test_refit_one_engine_bert_with_weightmap():
 @pytest.mark.unit
 def test_refit_one_engine_inline_runtime__with_weightmap():
     trt_ep_path = os.path.join(tempfile.gettempdir(), "compiled.ep")
-    model = models.resnet18(pretrained=True).eval().to("cuda")
-    model2 = models.resnet18(pretrained=False).eval().to("cuda")
+    model = models.resnet18(pretrained=False).eval().to("cuda")
+    model2 = models.resnet18(pretrained=True).eval().to("cuda")
     inputs = [torch.randn((1, 3, 224, 224)).to("cuda")]
     enabled_precisions = {torch.float}
     debug = False
@@ -321,6 +326,7 @@ def test_refit_one_engine_inline_runtime__with_weightmap():
         debug=debug,
         min_block_size=min_block_size,
         make_refittable=True,
+        reuse_cached_engines=False,
     )
     torchtrt.save(trt_gm, trt_ep_path, inputs=inputs)
     trt_gm = torch.export.load(trt_ep_path)
@@ -348,8 +354,8 @@ def test_refit_one_engine_python_runtime_with_weightmap():
 @pytest.mark.unit
 def test_refit_one_engine_python_runtime_with_weightmap():
 
-    model = models.resnet18(pretrained=True).eval().to("cuda")
-    model2 = models.resnet18(pretrained=False).eval().to("cuda")
+    model = models.resnet18(pretrained=False).eval().to("cuda")
+    model2 = models.resnet18(pretrained=True).eval().to("cuda")
     inputs = [torch.randn((1, 3, 224, 224)).to("cuda")]
     enabled_precisions = {torch.float}
     debug = False
@@ -367,6 +373,7 @@ def test_refit_one_engine_python_runtime_with_weightmap():
         debug=debug,
         min_block_size=min_block_size,
         make_refittable=True,
+        reuse_cached_engines=False,
     )
 
     new_trt_gm = refit_module_weights(
@@ -438,6 +445,7 @@ def forward(self, x):
         min_block_size=min_block_size,
         make_refittable=True,
         torch_executed_ops=torch_executed_ops,
+        reuse_cached_engines=False,
     )
 
     new_trt_gm = refit_module_weights(
@@ -487,6 +495,7 @@ def test_refit_one_engine_without_weightmap():
         debug=debug,
         min_block_size=min_block_size,
         make_refittable=True,
+        reuse_cached_engines=False,
     )
 
     new_trt_gm = refit_module_weights(
@@ -538,6 +547,7 @@ def test_refit_one_engine_bert_without_weightmap():
         debug=debug,
         min_block_size=min_block_size,
         make_refittable=True,
+        reuse_cached_engines=False,
     )
 
     new_trt_gm = refit_module_weights(
@@ -591,6 +601,7 @@ def test_refit_one_engine_inline_runtime_without_weightmap():
         debug=debug,
         min_block_size=min_block_size,
         make_refittable=True,
+        reuse_cached_engines=False,
     )
     torchtrt.save(trt_gm, trt_ep_path, inputs=inputs)
     trt_gm = torch.export.load(trt_ep_path)
@@ -637,6 +648,7 @@ def test_refit_one_engine_python_runtime_without_weightmap():
         debug=debug,
         min_block_size=min_block_size,
         make_refittable=True,
+        reuse_cached_engines=False,
     )
 
     new_trt_gm = refit_module_weights(
@@ -708,6 +720,7 @@ def forward(self, x):
         min_block_size=min_block_size,
         make_refittable=True,
        torch_executed_ops=torch_executed_ops,
+        reuse_cached_engines=False,
     )
 
     new_trt_gm = refit_module_weights(
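
The change is mechanical across every refit test: compile with reuse_cached_engines=False so a fresh refittable engine is always built instead of being reloaded from the engine cache, and swap the pretrained flags so the randomly initialized ResNet-18 is compiled and the pretrained weights are refit in afterwards. A minimal sketch of that pattern follows; it assumes the torchtrt.dynamo.compile entry point, a positional refit_module_weights call, and a min_block_size value of 1, none of which are shown in the hunks above.

import torch
import torch_tensorrt as torchtrt
import torchvision.models as models
from torch_tensorrt.dynamo import refit_module_weights

# Compile the randomly initialized model into a refittable TRT engine,
# bypassing the engine cache so a fresh build is always exercised.
model = models.resnet18(pretrained=False).eval().to("cuda")
model2 = models.resnet18(pretrained=True).eval().to("cuda")
inputs = [torch.randn((1, 3, 224, 224)).to("cuda")]

exp_program = torch.export.export(model, tuple(inputs))
exp_program2 = torch.export.export(model2, tuple(inputs))

trt_gm = torchtrt.dynamo.compile(
    exp_program,
    tuple(inputs),
    enabled_precisions={torch.float},
    min_block_size=1,  # illustrative value; the tests parametrize this
    make_refittable=True,
    reuse_cached_engines=False,
)

# Swap in the pretrained weights without rebuilding the engine
# (output-verification arguments used by the tests are omitted here).
new_trt_gm = refit_module_weights(trt_gm, exp_program2)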