@@ -3,7 +3,7 @@
 import collections.abc
 import copy
 import logging
-from typing import Any, Optional, Sequence, Tuple
+from typing import Any, List, Optional, Sequence, Tuple
 
 import numpy as np
 import tensorrt as trt
@@ -13,7 +13,7 @@
 from torch_tensorrt._Input import Input
 from torch_tensorrt.dynamo import partitioning
 from torch_tensorrt.dynamo._exporter import inline_torch_modules
-from torch_tensorrt.dynamo.conversion import CompilationSettings
+from torch_tensorrt.dynamo._settings import CompilationSettings
 from torch_tensorrt.dynamo.conversion._conversion import infer_module_output_dtypes
 from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
     DYNAMO_CONVERTERS as CONVERTERS,
@@ -108,38 +108,97 @@ def construct_refit_mapping(
     return weight_map
 
 
+def construct_refit_mapping_from_weight_name_map(
+    weight_name_map: dict[Any, Any], state_dict: dict[Any, Any]
+) -> dict[Any, Any]:
+    engine_weight_map = {}
+    for engine_weight_name, (sd_weight_name, np_weight_type) in weight_name_map.items():
+        trt_dtype = dtype.try_from(np_weight_type).to(trt.DataType)
+        torch_dtype = dtype.try_from(np_weight_type).to(torch.dtype)
+        if engine_weight_name.split(" ")[-1] in ["SCALE", "SHIFT"]:
+            # Batch Norm Layer
+            params = {}
+            for w in sd_weight_name:
+                params[w.split(".")[-1]] = state_dict[w]
+            scale = params["weight"] / torch.sqrt(params["running_var"] + 1e-7)
+            shift = params["bias"] - params["running_mean"] * scale
+            # Set scale to scale or shift to shift
+            engine_weight_map[engine_weight_name] = eval(
+                engine_weight_name.split(" ")[-1].lower()
+            )
+
+        elif sd_weight_name not in state_dict:
+            # If weights is not in sd, we can leave it unchanged
+            continue
+        else:
+            engine_weight_map[engine_weight_name] = state_dict[sd_weight_name]
+
+        engine_weight_map[engine_weight_name] = (
+            engine_weight_map[engine_weight_name]
+            .clone()
+            .reshape(-1)
+            .contiguous()
+            .to(torch_dtype),
+            trt_dtype,
+        )
+
+    return engine_weight_map
+
+
 def _refit_single_trt_engine_with_gm(
     new_gm: torch.fx.GraphModule,
     old_engine: trt.ICudaEngine,
-    input_list: Tuple[Any, ...],
+    input_list: Sequence[Any],
     settings: CompilationSettings = CompilationSettings(),
+    weight_name_map: Optional[dict[str, List[str]]] = None,
 ) -> None:
     """
     Refit a TensorRT Engine in place
     """
-    # Get the refitting mapping
-    mapping = construct_refit_mapping(new_gm, input_list, settings)
+
     refitted = set()
 
-    trt_wt_location = trt.TensorLocation.HOST
     refitter = trt.Refitter(old_engine, TRT_LOGGER)
     weight_list = refitter.get_all_weights()
 
-    for layer_name in weight_list:
-        if layer_name not in mapping:
-            raise AssertionError(f"{layer_name} is not found in weight mapping")
-        # Use Numpy to create weights
-        weight, datatype = mapping[layer_name]
-        trt_wt_tensor = trt.Weights(datatype, weight.ctypes.data, weight.size)
-        refitter.set_named_weights(layer_name, trt_wt_tensor, trt_wt_location)
-        refitted.add(layer_name)
+    if weight_name_map:
+        # Get the refitting mapping
+        trt_wt_location = trt.TensorLocation.DEVICE
+        mapping = construct_refit_mapping_from_weight_name_map(
+            weight_name_map, new_gm.state_dict()
+        )
+        for layer_name in weight_list:
+            if layer_name not in mapping:
+                logger.warning(f"{layer_name} is not found in weight mapping.")
+                continue
+            # Use Torch to create weights
+            weight, weight_dtype = mapping[layer_name]
+            trt_wt_tensor = trt.Weights(
+                weight_dtype, weight.data_ptr(), torch.numel(weight)
+            )
+            refitter.set_named_weights(layer_name, trt_wt_tensor, trt_wt_location)
+        assert (
+            len(refitter.get_missing_weights()) == 0
+        ), "Fast refitting failed due to incomplete mapping"
 
-    if len(refitted) != len(weight_list):
-        logger.warning("Not all weights have been refitted!!!")
+    else:
+        mapping = construct_refit_mapping(new_gm, input_list, settings)
+        trt_wt_location = trt.TensorLocation.HOST
+        for layer_name in weight_list:
+            if layer_name not in mapping:
+                raise AssertionError(f"{layer_name} is not found in weight mapping")
+            # Use Numpy to create weights
+            weight, datatype = mapping[layer_name]
+            trt_wt_tensor = trt.Weights(datatype, weight.ctypes.data, weight.size)
+            refitter.set_named_weights(layer_name, trt_wt_tensor, trt_wt_location)
+            refitted.add(layer_name)
+
+        if len(refitted) != len(weight_list):
+            logger.warning("Not all weights have been refitted!!!")
 
     if not refitter.refit_cuda_engine():
         logger.error("Error: failed to refit new weights.")
-        exit(0)
+        raise AssertionError("Refitting failed.")
 
 
 def refit_module_weights(
@@ -148,6 +207,8 @@ def refit_module_weights(
     arg_inputs: Optional[Tuple[Any, ...]] = None,
     kwarg_inputs: Optional[dict[str, Any]] = None,
     verify_output: bool = False,
+    use_weight_map_cache: bool = True,
+    in_place: bool = False,
 ) -> torch.fx.GraphModule:
     """
     Refit a compiled graph module with ExportedProgram. This performs weight updates in compiled_module without recompiling the engine.
@@ -170,7 +231,12 @@ def refit_module_weights(
     if len(list(compiled_module.named_children())) == 0:
         inline_module = True
 
-    compiled_module = copy.deepcopy(compiled_module)
+    if not in_place:
+        compiled_module = copy.deepcopy(compiled_module)
+    elif inline_module:
+        raise AssertionError(
+            "Exported program does not support modifying in place. Please set inplace to false and use the returned graph module."
+        )
 
     # Get the settings and check the setting to be uniform
     settings: CompilationSettings = None
@@ -182,13 +248,14 @@ def refit_module_weights(
             for name, engine in compiled_module.__dict__.items()
             if "engine" in name
         ]
-        encoded_settings = compiled_submodules[0][1].__getstate__()[0][
+        # [('_run_on_acc_0', inline_module)]
+        encoded_metadata = compiled_submodules[0][1].__getstate__()[0][
             SERIALIZED_METADATA_IDX
         ]
         assert (
-            encoded_settings != ""
-        ), "Settings are not saved in the engine. Please recompile the engine with make_refitable=True. "
-        settings = TorchTensorRTModule.decode_metadata(encoded_settings)
+            encoded_metadata != ""
+        ), "The engine provided is either not refittable or was built with a version of Torch-TensorRT that is too old, please recompile using the latest version with make_refitable=True"
+        settings = TorchTensorRTModule.decode_metadata(encoded_metadata)["settings"]
         # Handle torch modules
         compiled_submodules_map = dict(compiled_submodules)
         for name, submodule in compiled_module.named_children():
@@ -287,6 +354,7 @@ def refit_module_weights(
         # Extract engine from the submodule
         try:
             if inline_module:
+                weight_name_map = None
                 compiled_submodule = compiled_submodules_map[name]
                 # If this is a torch module, load the old state_dict
                 if "_run_on_acc" not in name:
@@ -297,8 +365,33 @@ def refit_module_weights(
                     engine = get_engine_from_encoded_engine(
                         engine_info[ENGINE_IDX], runtime
                     )
+                    if use_weight_map_cache:
+                        encoded_metadata = compiled_submodule.__getstate__()[0][
+                            SERIALIZED_METADATA_IDX
+                        ]
+                        weight_name_map = TorchTensorRTModule.decode_metadata(
+                            encoded_metadata
+                        )["weight_name_map"]
+                        if not weight_name_map:
+                            use_weight_map_cache = False
+                            logger.warning(
+                                "This engine does not have a weight map cache. Rebuilding the weight map"
+                            )
             else:
                 compiled_submodule = getattr(compiled_module, name)
+                weight_name_map = None
+                if use_weight_map_cache:
+                    try:
+                        weight_name_map = compiled_submodule.weight_name_map
+                    except AttributeError:
+                        logger.warning(
+                            "The module was compiled with an old version of Torch-TensorRT. Rebuilding the weight map."
+                        )
+                    if not weight_name_map:
+                        use_weight_map_cache = False
+                        logger.warning(
+                            "This engine does not have a weight map cache. Rebuilding the weight map"
+                        )
                 if isinstance(compiled_submodule, PythonTorchTensorRTModule):
                     engine = compiled_submodule.engine
                 elif isinstance(compiled_submodule, TorchTensorRTModule):
@@ -335,13 +428,25 @@ def refit_module_weights(
                 to_torch_device(settings.device),
                 name,
             )
-
-        _refit_single_trt_engine_with_gm(
-            new_gm=new_submodule,
-            old_engine=engine,
-            input_list=submodule_inputs,
-            settings=settings,
-        )
+        try:
+            _refit_single_trt_engine_with_gm(
+                new_gm=new_submodule,
+                old_engine=engine,
+                input_list=submodule_inputs,
+                settings=settings,
+                weight_name_map=weight_name_map,
+            )
+        except AssertionError as e:
+            # If fast_refit is used and failed, we fall back to regular refit
+            logger.warning(e)
+            if use_weight_map_cache and weight_name_map:
+                _refit_single_trt_engine_with_gm(
+                    new_gm=new_submodule,
+                    old_engine=engine,
+                    input_list=submodule_inputs,
+                    settings=settings,
+                    weight_name_map=None,
+                )
 
         if isinstance(compiled_submodule, TorchTensorRTModule):
            serialized_engine = bytes(engine.serialize())
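
For reference, a hedged sketch of the `weight_name_map` structure that `construct_refit_mapping_from_weight_name_map` appears to expect, inferred from the loop in the diff above: a plain weight maps to a `(state_dict key, numpy dtype)` pair, while batch-norm `SCALE`/`SHIFT` entries map to the list of state_dict keys needed to recompute them. The engine weight names and dtypes below are made up for illustration.

```python
import numpy as np

# Illustrative only: real engine weight names come from the TensorRT refitter.
weight_name_map = {
    # Plain weight: "<engine name> <role>" -> (state_dict key, numpy dtype)
    "conv1.weight CONSTANT": ("conv1.weight", np.float32),
    # Batch norm: SCALE/SHIFT entries carry every state_dict key needed to
    # recompute scale = weight / sqrt(running_var + eps) and
    # shift = bias - running_mean * scale
    "bn1 SCALE": (
        ["bn1.weight", "bn1.bias", "bn1.running_mean", "bn1.running_var"],
        np.float32,
    ),
    "bn1 SHIFT": (
        ["bn1.weight", "bn1.bias", "bn1.running_mean", "bn1.running_var"],
        np.float32,
    ),
}
```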
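A minimal usage sketch of the new options follows. The model, input shapes, import path, and the positional order of the first two arguments to `refit_module_weights` are assumptions for illustration; only `use_weight_map_cache`, `in_place`, and `make_refitable` come from the diff above.

```python
import torch
import torch_tensorrt
from torch_tensorrt.dynamo._refit import refit_module_weights  # assumed location

# Hypothetical module; any refittable model works.
model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU()).cuda().eval()
inputs = (torch.randn(2, 8).cuda(),)

# Compile once with refitting enabled so refit metadata is serialized with the engine.
exp_program = torch.export.export(model, inputs)
compiled_module = torch_tensorrt.dynamo.compile(
    exp_program, inputs, make_refitable=True
)

# ... update the model's weights (e.g. after fine-tuning), then re-export ...
new_exp_program = torch.export.export(model, inputs)

refitted_module = refit_module_weights(
    compiled_module,
    new_exp_program,
    arg_inputs=inputs,
    use_weight_map_cache=True,  # try the cached fast path; rebuild the map on failure
    in_place=False,             # refit a deep copy and return it
)
```

With `use_weight_map_cache=True` the refitter uses the serialized weight name map and falls back to a full remap if the fast path raises; `in_place=True` modifies the compiled module directly, which the diff disallows for inlined exported programs.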