Skip to content

Commit 3ee2d81

Browse files
authored
Added refitting acceleration (#2983)
1 parent e90576a commit 3ee2d81

File tree

7 files changed

+732
-58
lines changed

7 files changed

+732
-58
lines changed

py/torch_tensorrt/dynamo/_refit.py

Lines changed: 134 additions & 29 deletions
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,7 @@
33
import collections.abc
44
import copy
55
import logging
6-
from typing import Any, Optional, Sequence, Tuple
6+
from typing import Any, List, Optional, Sequence, Tuple
77

88
import numpy as np
99
import tensorrt as trt
@@ -13,7 +13,7 @@
1313
from torch_tensorrt._Input import Input
1414
from torch_tensorrt.dynamo import partitioning
1515
from torch_tensorrt.dynamo._exporter import inline_torch_modules
16-
from torch_tensorrt.dynamo.conversion import CompilationSettings
16+
from torch_tensorrt.dynamo._settings import CompilationSettings
1717
from torch_tensorrt.dynamo.conversion._conversion import infer_module_output_dtypes
1818
from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
1919
DYNAMO_CONVERTERS as CONVERTERS,
@@ -108,38 +108,97 @@ def construct_refit_mapping(
108108
return weight_map
109109

110110

111+
def construct_refit_mapping_from_weight_name_map(
    weight_name_map: dict[Any, Any], state_dict: dict[Any, Any]
) -> dict[Any, Any]:
    """
    Build the engine-weight mapping used for refitting from a cached weight name map.

    Args:
        weight_name_map: Maps each engine weight name to a
            ``(sd_weight_name, np_weight_type)`` pair. For engine weights whose
            name ends in ``" SCALE"`` or ``" SHIFT"`` (batch norm layers),
            ``sd_weight_name`` is an iterable of state-dict keys whose last
            dotted components are ``weight``, ``bias``, ``running_mean`` and
            ``running_var``; otherwise it is a single state-dict key.
        state_dict: State dict holding the new weights.

    Returns:
        A dict mapping each engine weight name to a tuple of
        ``(flattened contiguous tensor cast to the target torch dtype, TensorRT dtype)``.
        Engine weights whose state-dict key is absent from ``state_dict`` are
        omitted so that they are left unchanged.

    Raises:
        KeyError: If a batch norm entry references a key missing from
            ``state_dict`` or lacks one of the four required parameters.
    """
    engine_weight_map = {}
    for engine_weight_name, (sd_weight_name, np_weight_type) in weight_name_map.items():
        # Resolve the dtype once and derive both target representations from it.
        weight_dtype = dtype.try_from(np_weight_type)
        trt_dtype = weight_dtype.to(trt.DataType)
        torch_dtype = weight_dtype.to(torch.dtype)

        # The last space-separated token of the engine weight name tells us
        # whether this entry is a folded batch norm SCALE/SHIFT weight.
        weight_kind = engine_weight_name.split(" ")[-1]
        if weight_kind in ("SCALE", "SHIFT"):
            # Batch Norm Layer: fold the four batch norm parameters into the
            # scale/shift pair stored in the engine.
            params = {w.split(".")[-1]: state_dict[w] for w in sd_weight_name}
            scale = params["weight"] / torch.sqrt(params["running_var"] + 1e-7)
            shift = params["bias"] - params["running_mean"] * scale
            # Explicit lookup instead of eval() on a string taken from the
            # engine weight name.
            new_weight = {"SCALE": scale, "SHIFT": shift}[weight_kind]
        elif sd_weight_name not in state_dict:
            # If weights is not in sd, we can leave it unchanged
            continue
        else:
            new_weight = state_dict[sd_weight_name]

        engine_weight_map[engine_weight_name] = (
            new_weight.clone().reshape(-1).contiguous().to(torch_dtype),
            trt_dtype,
        )

    return engine_weight_map
146+
147+
111148
def _refit_single_trt_engine_with_gm(
    new_gm: torch.fx.GraphModule,
    old_engine: trt.ICudaEngine,
    input_list: Sequence[Any],
    settings: CompilationSettings = CompilationSettings(),
    weight_name_map: Optional[dict[str, List[str]]] = None,
) -> None:
    """
    Refit a TensorRT Engine in place

    Two strategies are used, selected by ``weight_name_map``:

    * Fast path (``weight_name_map`` is non-empty): the new weights are read
      directly from ``new_gm.state_dict()`` via
      ``construct_refit_mapping_from_weight_name_map`` and handed to the
      refitter with ``trt.TensorLocation.DEVICE``. Engine weights absent from
      the mapping are logged and skipped.
    * Regular path (``weight_name_map`` is ``None`` or empty): the mapping is
      rebuilt with ``construct_refit_mapping`` and handed to the refitter with
      ``trt.TensorLocation.HOST``.

    Args:
        new_gm: Graph module carrying the new weights.
        old_engine: Engine to be refitted in place.
        input_list: Inputs forwarded to ``construct_refit_mapping`` on the
            regular path.
        settings: Compilation settings forwarded to ``construct_refit_mapping``
            on the regular path.
        weight_name_map: Cached map from engine weight names to state-dict
            entries; enables the fast path when non-empty.
            NOTE(review): ``construct_refit_mapping_from_weight_name_map``
            unpacks each value as a ``(state-dict name(s), numpy dtype)`` pair,
            which does not match ``List[str]`` - confirm the annotation.

    Raises:
        AssertionError: If the fast path leaves weights missing, if the regular
            path cannot find an engine weight in the mapping, or if the
            refitter fails to refit the engine.
    """
    refitter = trt.Refitter(old_engine, TRT_LOGGER)
    weight_list = refitter.get_all_weights()

    if weight_name_map:
        # Get the refitting mapping
        trt_wt_location = trt.TensorLocation.DEVICE
        mapping = construct_refit_mapping_from_weight_name_map(
            weight_name_map, new_gm.state_dict()
        )
        for layer_name in weight_list:
            if layer_name not in mapping:
                logger.warning(f"{layer_name} is not found in weight mapping.")
                continue
            # Wrap the torch tensor's memory directly; `mapping` stays in scope
            # until the refit below, so the tensors remain referenced.
            weight, weight_dtype = mapping[layer_name]
            trt_wt_tensor = trt.Weights(
                weight_dtype, weight.data_ptr(), torch.numel(weight)
            )
            refitter.set_named_weights(layer_name, trt_wt_tensor, trt_wt_location)
        # Raise explicitly rather than using an `assert` statement: asserts are
        # stripped under `python -O`, which would silently disable this check
        # for callers that catch AssertionError to fall back to a regular refit.
        if len(refitter.get_missing_weights()) != 0:
            raise AssertionError("Fast refitting failed due to incomplete mapping")

    else:
        mapping = construct_refit_mapping(new_gm, input_list, settings)
        trt_wt_location = trt.TensorLocation.HOST
        refitted = set()
        for layer_name in weight_list:
            if layer_name not in mapping:
                raise AssertionError(f"{layer_name} is not found in weight mapping")
            # Use Numpy to create weights
            weight, datatype = mapping[layer_name]
            trt_wt_tensor = trt.Weights(datatype, weight.ctypes.data, weight.size)
            refitter.set_named_weights(layer_name, trt_wt_tensor, trt_wt_location)
            refitted.add(layer_name)

        if len(refitted) != len(weight_list):
            logger.warning("Not all weights have been refitted!!!")

    if not refitter.refit_cuda_engine():
        logger.error("Error: failed to refit new weights.")
        raise AssertionError("Refitting failed.")
143202

144203

145204
def refit_module_weights(
@@ -148,6 +207,8 @@ def refit_module_weights(
148207
arg_inputs: Optional[Tuple[Any, ...]] = None,
149208
kwarg_inputs: Optional[dict[str, Any]] = None,
150209
verify_output: bool = False,
210+
use_weight_map_cache: bool = True,
211+
in_place: bool = False,
151212
) -> torch.fx.GraphModule:
152213
"""
153214
Refit a compiled graph module with ExportedProgram. This performs weight updates in compiled_module without recompiling the engine.
@@ -170,7 +231,12 @@ def refit_module_weights(
170231
if len(list(compiled_module.named_children())) == 0:
171232
inline_module = True
172233

173-
compiled_module = copy.deepcopy(compiled_module)
234+
if not in_place:
235+
compiled_module = copy.deepcopy(compiled_module)
236+
elif inline_module:
237+
raise AssertionError(
238+
"Exported program does not support modifying in place. Please set inplace to false and use the returned graph module."
239+
)
174240

175241
# Get the settings and check the setting to be uniform
176242
settings: CompilationSettings = None
@@ -182,13 +248,14 @@ def refit_module_weights(
182248
for name, engine in compiled_module.__dict__.items()
183249
if "engine" in name
184250
]
185-
encoded_settings = compiled_submodules[0][1].__getstate__()[0][
251+
# [('_run_on_acc_0', inline_module)]
252+
encoded_metadata = compiled_submodules[0][1].__getstate__()[0][
186253
SERIALIZED_METADATA_IDX
187254
]
188255
assert (
189-
encoded_settings != ""
190-
), "Settings are not saved in the engine. Please recompile the engine with make_refitable=True."
191-
settings = TorchTensorRTModule.decode_metadata(encoded_settings)
256+
encoded_metadata != ""
257+
), "The engine provided is either not refittable or was built with a version of Torch-TensorRT that is too old, please recompile using the latest version with make_refitable=True"
258+
settings = TorchTensorRTModule.decode_metadata(encoded_metadata)["settings"]
192259
# Handle torch modules
193260
compiled_submodules_map = dict(compiled_submodules)
194261
for name, submodule in compiled_module.named_children():
@@ -287,6 +354,7 @@ def refit_module_weights(
287354
# Extract engine from the submodule
288355
try:
289356
if inline_module:
357+
weight_name_map = None
290358
compiled_submodule = compiled_submodules_map[name]
291359
# If this is a torch module, load the old state_dict
292360
if "_run_on_acc" not in name:
@@ -297,8 +365,33 @@ def refit_module_weights(
297365
engine = get_engine_from_encoded_engine(
298366
engine_info[ENGINE_IDX], runtime
299367
)
368+
if use_weight_map_cache:
369+
encoded_metadata = compiled_submodule.__getstate__()[0][
370+
SERIALIZED_METADATA_IDX
371+
]
372+
weight_name_map = TorchTensorRTModule.decode_metadata(
373+
encoded_metadata
374+
)["weight_name_map"]
375+
if not weight_name_map:
376+
use_weight_map_cache = False
377+
logger.warning(
378+
"This engine does not have a weight map cache. Rebuilding the weight map"
379+
)
300380
else:
301381
compiled_submodule = getattr(compiled_module, name)
382+
weight_name_map = None
383+
if use_weight_map_cache:
384+
try:
385+
weight_name_map = compiled_submodule.weight_name_map
386+
except AttributeError:
387+
logger.warning(
388+
"The module was compiled with an old version of Torch-TensorRT. Rebuilding the weight map."
389+
)
390+
if not weight_name_map:
391+
use_weight_map_cache = False
392+
logger.warning(
393+
"This engine does not have a weight map cache. Rebuilding the weight map"
394+
)
302395
if isinstance(compiled_submodule, PythonTorchTensorRTModule):
303396
engine = compiled_submodule.engine
304397
elif isinstance(compiled_submodule, TorchTensorRTModule):
@@ -335,13 +428,25 @@ def refit_module_weights(
335428
to_torch_device(settings.device),
336429
name,
337430
)
338-
339-
_refit_single_trt_engine_with_gm(
340-
new_gm=new_submodule,
341-
old_engine=engine,
342-
input_list=submodule_inputs,
343-
settings=settings,
344-
)
431+
try:
432+
_refit_single_trt_engine_with_gm(
433+
new_gm=new_submodule,
434+
old_engine=engine,
435+
input_list=submodule_inputs,
436+
settings=settings,
437+
weight_name_map=weight_name_map,
438+
)
439+
except AssertionError as e:
440+
# If fast_refit is used and failed, we fall back to regular refit
441+
logger.warning(e)
442+
if use_weight_map_cache and weight_name_map:
443+
_refit_single_trt_engine_with_gm(
444+
new_gm=new_submodule,
445+
old_engine=engine,
446+
input_list=submodule_inputs,
447+
settings=settings,
448+
weight_name_map=None,
449+
)
345450

346451
if isinstance(compiled_submodule, TorchTensorRTModule):
347452
serialized_engine = bytes(engine.serialize())

0 commit comments

Comments
 (0)