48
48
logger = logging .getLogger (__name__ )
49
49
50
50
51
- @needs_refit
51
+ @needs_refit # type: ignore
52
52
def construct_refit_mapping (
53
53
module : torch .fx .GraphModule ,
54
54
inputs : Sequence [Input ],
@@ -110,7 +110,7 @@ def construct_refit_mapping(
110
110
return weight_map
111
111
112
112
113
- @needs_refit
113
+ @needs_refit # type: ignore
114
114
def construct_refit_mapping_from_weight_name_map (
115
115
weight_name_map : dict [Any , Any ],
116
116
state_dict : dict [Any , Any ],
@@ -141,7 +141,7 @@ def construct_refit_mapping_from_weight_name_map(
141
141
return engine_weight_map
142
142
143
143
144
- @needs_refit
144
+ @needs_refit # type: ignore
145
145
def _refit_single_trt_engine_with_gm (
146
146
new_gm : torch .fx .GraphModule ,
147
147
old_engine : trt .ICudaEngine ,
@@ -153,12 +153,12 @@ def _refit_single_trt_engine_with_gm(
153
153
Refit a TensorRT Engine in place
154
154
"""
155
155
156
- with unset_fake_temporarily ():
157
- refitted = set ()
158
- torch_device = get_model_device (new_gm )
159
- refitter = trt .Refitter (old_engine , TRT_LOGGER )
160
- weight_list = refitter .get_all_weights ()
156
+ refitted = set ()
157
+ torch_device = get_model_device (new_gm )
158
+ refitter = trt .Refitter (old_engine , TRT_LOGGER )
159
+ weight_list = refitter .get_all_weights ()
161
160
161
+ with unset_fake_temporarily ():
162
162
if weight_name_map :
163
163
# Get the refitting mapping
164
164
trt_wt_location = (
@@ -185,41 +185,21 @@ def _refit_single_trt_engine_with_gm(
185
185
trt_dtype ,
186
186
)
187
187
188
- constant_mapping : dict [str , Any ] = weight_name_map .pop (
189
- "constant_mapping" , {}
190
- ) # type: ignore
191
- mapping = construct_refit_mapping_from_weight_name_map (
192
- weight_name_map , new_gm .state_dict ()
193
- )
194
- constant_mapping_with_type = {}
195
-
196
- for constant_name , val in constant_mapping .items ():
197
- np_weight_type = val .dtype
198
- val_tensor = torch .from_numpy (val ).cuda ()
199
- trt_dtype = dtype .try_from (np_weight_type ).to (trt .DataType )
200
- torch_dtype = dtype .try_from (np_weight_type ).to (torch .dtype )
201
- constant_mapping_with_type [constant_name ] = (
202
- val_tensor .clone ().reshape (- 1 ).contiguous ().to (torch_dtype ),
203
- trt_dtype ,
204
- )
188
+ mapping .update (constant_mapping_with_type )
205
189
206
- mapping .update (constant_mapping_with_type )
207
-
208
- for layer_name in weight_list :
209
- if layer_name not in mapping :
210
- logger .warning (f"{ layer_name } is not found in weight mapping." )
211
- continue
212
- # Use Numpy to create weights
213
- weight , weight_dtype = mapping [layer_name ]
214
- trt_wt_tensor = trt .Weights (
215
- weight_dtype , weight .data_ptr (), torch .numel (weight )
216
- )
217
- refitter .set_named_weights (
218
- layer_name , trt_wt_tensor , trt_wt_location
219
- )
220
- assert (
221
- len (refitter .get_missing_weights ()) == 0
222
- ), "Fast refitting failed due to incomplete mapping"
190
+ for layer_name in weight_list :
191
+ if layer_name not in mapping :
192
+ logger .warning (f"{ layer_name } is not found in weight mapping." )
193
+ continue
194
+ # Use Numpy to create weights
195
+ weight , weight_dtype = mapping [layer_name ]
196
+ trt_wt_tensor = trt .Weights (
197
+ weight_dtype , weight .data_ptr (), torch .numel (weight )
198
+ )
199
+ refitter .set_named_weights (layer_name , trt_wt_tensor , trt_wt_location )
200
+ assert (
201
+ len (refitter .get_missing_weights ()) == 0
202
+ ), "Fast refitting failed due to incomplete mapping"
223
203
224
204
else :
225
205
mapping = construct_refit_mapping (new_gm , input_list , settings )
@@ -241,7 +221,7 @@ def _refit_single_trt_engine_with_gm(
241
221
raise AssertionError ("Refitting failed." )
242
222
243
223
244
- @needs_refit
224
+ @needs_refit # type: ignore
245
225
def refit_module_weights (
246
226
compiled_module : torch .fx .GraphModule | ExportedProgram ,
247
227
new_weight_module : ExportedProgram ,
0 commit comments