Skip to content

Commit 60ec67b

Browse files
cehongwang and peri044 authored
Refit bug fix (#3097)
Co-authored-by: Dheeraj Peri <[email protected]>
1 parent 6180836 commit 60ec67b

14 files changed

+231
-172
lines changed

examples/dynamo/mutable_torchtrt_module_example.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
"make_refitable": True,
3535
}
3636

37-
model = models.resnet18(pretrained=False).eval().to("cuda")
37+
model = models.resnet18(pretrained=True).eval().to("cuda")
3838
mutable_module = torch_trt.MutableTorchTensorRTModule(model, **settings)
3939
# You can use the mutable module just like the original pytorch module. The compilation happens while you first call the mutable module.
4040
mutable_module(*inputs)
@@ -45,7 +45,7 @@
4545

4646
# %%
4747
# Making changes to mutable module can trigger refit or re-compilation. For example, loading a different state_dict and setting new weight values will trigger refit, and adding a module to the model will trigger re-compilation.
48-
model2 = models.resnet18(pretrained=True).eval().to("cuda")
48+
model2 = models.resnet18(pretrained=False).eval().to("cuda")
4949
mutable_module.load_state_dict(model2.state_dict())
5050

5151

examples/dynamo/refit_engine_example.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939
# Compile the module for the first time and save it.
4040
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
4141

42-
model = models.resnet18(pretrained=False).eval().to("cuda")
42+
model = models.resnet18(pretrained=True).eval().to("cuda")
4343
exp_program = torch.export.export(model, tuple(inputs))
4444
enabled_precisions = {torch.float}
4545
debug = False
@@ -68,7 +68,7 @@
6868
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
6969

7070
# Create and compile the updated model
71-
model2 = models.resnet18(pretrained=True).eval().to("cuda")
71+
model2 = models.resnet18(pretrained=False).eval().to("cuda")
7272
exp_program2 = torch.export.export(model2, tuple(inputs))
7373

7474

py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,9 +56,9 @@ def compile(
5656
disable_tf32: bool = _defaults.DISABLE_TF32,
5757
assume_dynamic_shape_support: bool = _defaults.ASSUME_DYNAMIC_SHAPE_SUPPORT,
5858
sparse_weights: bool = _defaults.SPARSE_WEIGHTS,
59-
enabled_precisions: (
60-
Set[torch.dtype | dtype] | Tuple[torch.dtype | dtype]
61-
) = _defaults.ENABLED_PRECISIONS,
59+
enabled_precisions: Union[
60+
Set[Union[torch.dtype, dtype]], Tuple[Union[torch.dtype, dtype]]
61+
] = _defaults.ENABLED_PRECISIONS,
6262
engine_capability: EngineCapability = _defaults.ENGINE_CAPABILITY,
6363
make_refitable: bool = _defaults.MAKE_REFITABLE,
6464
debug: bool = _defaults.DEBUG,

py/torch_tensorrt/dynamo/_refit.py

Lines changed: 30 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
TorchTensorRTModule,
3535
)
3636
from torch_tensorrt.dynamo.utils import (
37-
check_output,
37+
check_module_output,
3838
get_torch_inputs,
3939
set_log_level,
4040
to_torch_device,
@@ -115,19 +115,8 @@ def construct_refit_mapping_from_weight_name_map(
115115
for engine_weight_name, (sd_weight_name, np_weight_type) in weight_name_map.items():
116116
trt_dtype = dtype.try_from(np_weight_type).to(trt.DataType)
117117
torch_dtype = dtype.try_from(np_weight_type).to(torch.dtype)
118-
if engine_weight_name.split(" ")[-1] in ["SCALE", "SHIFT"]:
119-
# Batch Norm Layer
120-
params = {}
121-
for w in sd_weight_name:
122-
params[w.split(".")[-1]] = state_dict[w]
123-
scale = params["weight"] / torch.sqrt(params["running_var"] + 1e-7)
124-
shift = params["bias"] - params["running_mean"] * scale
125-
# Set scale to scale or shift to shift
126-
engine_weight_map[engine_weight_name] = eval(
127-
engine_weight_name.split(" ")[-1].lower()
128-
)
129118

130-
elif sd_weight_name not in state_dict:
119+
if sd_weight_name not in state_dict:
131120
# If weights is not in sd, we can leave it unchanged
132121
continue
133122
else:
@@ -157,16 +146,25 @@ def _refit_single_trt_engine_with_gm(
157146
"""
158147

159148
refitted = set()
160-
149+
torch_device = list(new_gm.state_dict().values())[0].device.type
161150
refitter = trt.Refitter(old_engine, TRT_LOGGER)
162151
weight_list = refitter.get_all_weights()
163152

164153
if weight_name_map:
165154
# Get the refitting mapping
166-
trt_wt_location = trt.TensorLocation.DEVICE
155+
trt_wt_location = (
156+
trt.TensorLocation.DEVICE
157+
if torch_device == "cuda"
158+
else trt.TensorLocation.HOST
159+
)
167160
mapping = construct_refit_mapping_from_weight_name_map(
168161
weight_name_map, new_gm.state_dict()
169162
)
163+
164+
# Debug Use
165+
# correct = construct_refit_mapping(new_gm, input_list, settings)
166+
# comparison = {k: (np.allclose(correct[k][0], mapping[k][0].cpu().numpy(), 1e-2, 1e-2), correct[k][0], mapping[k][0]) for k in mapping if k in correct}
167+
170168
for layer_name in weight_list:
171169
if layer_name not in mapping:
172170
logger.warning(f"{layer_name} is not found in weight mapping.")
@@ -235,7 +233,7 @@ def refit_module_weights(
235233
compiled_module = copy.deepcopy(compiled_module)
236234
elif inline_module:
237235
raise AssertionError(
238-
"Exported program does not support modifying in place. Please set inplace to false and use the returned graph module."
236+
"Exported program does not support modifying in place. Please set in_place to false and use the returned graph module."
239237
)
240238

241239
# Get the settings and check the setting to be uniform
@@ -283,6 +281,7 @@ def refit_module_weights(
283281
arg_inputs = [arg_inputs]
284282
torch_inputs = get_torch_inputs(arg_inputs, device)
285283

284+
torch_kwarg_inputs: Any = {}
286285
if kwarg_inputs:
287286
torch_kwarg_inputs = get_torch_inputs(kwarg_inputs, device)
288287
runtime = trt.Runtime(TRT_LOGGER)
@@ -436,6 +435,7 @@ def refit_module_weights(
436435
settings=settings,
437436
weight_name_map=weight_name_map,
438437
)
438+
439439
except AssertionError as e:
440440
# If fast_refit is used and failed, we fall back to regular refit
441441
logger.warning(e)
@@ -463,14 +463,27 @@ def refit_module_weights(
463463
setattr(compiled_module, f"{name}_engine", refitted_engine)
464464

465465
if verify_output and arg_inputs is not None:
466-
if check_output(
466+
if check_module_output(
467467
new_module=new_gm,
468468
refitted_module=compiled_module,
469469
arg_inputs=torch_inputs,
470470
kwarg_inputs=torch_kwarg_inputs,
471471
):
472472
logger.info("Refitting Succeed!")
473473
else:
474+
if weight_name_map:
475+
logger.warning(
476+
"Refitting with weight_name_map yielded incorrect result! The outputs do not match."
477+
)
478+
return refit_module_weights(
479+
compiled_module,
480+
new_weight_module,
481+
arg_inputs,
482+
kwarg_inputs,
483+
verify_output,
484+
use_weight_map_cache=False,
485+
in_place=in_place,
486+
)
474487
logger.error("Refitting Failed! The outputs do not match.")
475488
else:
476489
logger.info("Refitting Completed! Output verification skipped.")

py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py

Lines changed: 64 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,21 @@
1+
import gc
12
import io
23
import logging
34
import os
45
import warnings
56
from datetime import datetime
6-
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Set, Tuple
7+
from typing import (
8+
Any,
9+
Callable,
10+
Dict,
11+
List,
12+
NamedTuple,
13+
Optional,
14+
Sequence,
15+
Set,
16+
Tuple,
17+
Union,
18+
)
719

820
import numpy as np
921
import torch
@@ -26,7 +38,7 @@
2638
get_node_name,
2739
get_trt_tensor,
2840
)
29-
from torch_tensorrt.dynamo.utils import DYNAMIC_DIM
41+
from torch_tensorrt.dynamo.utils import DYNAMIC_DIM, to_torch_device
3042
from torch_tensorrt.fx.observer import Observer
3143
from torch_tensorrt.logging import TRT_LOGGER
3244

@@ -327,6 +339,39 @@ def _construct_trt_network_def(self) -> None:
327339
f"TRT INetwork construction elapsed time: {datetime.now() - run_module_start_time}"
328340
)
329341

342+
@staticmethod
343+
def find_weight(
344+
weight_name: str, np_map: dict[str, Any], state_dict: dict[str, Any]
345+
) -> str:
346+
"""
347+
We need to build a map from each engine weight name to its state_dict weight name.
348+
The purpose of this function is to find the corresponding weight name in module state_dict.
349+
350+
weight_name: the target weight name we want to search for
351+
np_map: the map from weight name to np values in INetworkDefinition
352+
state_dict: state of the graph module
353+
"""
354+
network_weight = np_map[weight_name]
355+
network_weight = torch.from_numpy(np_map[weight_name]).cuda()
356+
for sd_w_name, sd_weight in state_dict.items():
357+
if TRTInterpreter.check_weight_equal(sd_weight, network_weight):
358+
del state_dict[sd_w_name]
359+
return sd_w_name
360+
return ""
361+
362+
@staticmethod
363+
def check_weight_equal(
364+
sd_weight: torch.tensor, network_weight: Union[torch.Tensor, np.ndarray]
365+
) -> Any:
366+
if not isinstance(network_weight, torch.Tensor):
367+
network_weight = torch.from_numpy(network_weight).cuda()
368+
try:
369+
return sd_weight.shape == network_weight.shape and torch.all(
370+
torch.abs(sd_weight - network_weight) < 0.01
371+
)
372+
except Exception:
373+
return torch.all(sd_weight == network_weight)
374+
330375
def _save_weight_mapping(self) -> None:
331376
"""
332377
Construct the weight name mapping from engine weight name to state_dict weight name.
@@ -336,23 +381,6 @@ def _save_weight_mapping(self) -> None:
336381
2. Value mapping that, for each weight in INetworkDefinition search for identical weight in state_dict
337382
"""
338383

339-
def find_weight(
340-
weight_name: str, np_map: dict[str, Any], sd: dict[str, Any]
341-
) -> str:
342-
network_weight = np_map[weight_name]
343-
for sd_w_name, sd_weight in sd.items():
344-
if check_weight_equal(sd_weight, network_weight):
345-
return sd_w_name
346-
return ""
347-
348-
def check_weight_equal(
349-
sd_weight: torch.tensor, network_weight: np.ndarray
350-
) -> Any:
351-
sd_weight = sd_weight.reshape(-1).cpu().numpy()
352-
return sd_weight.size == network_weight.size and np.allclose(
353-
sd_weight, network_weight, 1e-1, 1e-1
354-
)
355-
356384
MODULE_MAP = {
357385
"SCALE": (
358386
trt.IScaleLayer,
@@ -398,8 +426,19 @@ def check_weight_equal(
398426
)
399427
}
400428
"""
429+
_LOGGER.info("Building weight name mapping...")
401430
# Stage 1: Name mapping
402431
sd = self.module.state_dict()
432+
torch_device = to_torch_device(self.compilation_settings.device)
433+
gm_is_on_cuda = list(sd.values())[0].device.type == "cuda"
434+
if not gm_is_on_cuda:
435+
# If the model was originally on the CPU, move its weights to the GPU
436+
sd = {
437+
k: v.reshape(-1).to(torch_device)
438+
for k, v in self.module.state_dict().items()
439+
}
440+
else:
441+
sd = {k: v.reshape(-1) for k, v in self.module.state_dict().items()}
403442
weight_name_map: dict[str, Any] = {}
404443
np_map = {}
405444
net = self.ctx.net
@@ -448,10 +487,10 @@ def check_weight_equal(
448487
if "SCALE" in engine_weight_name:
449488
# There is no direct connection in batch_norm layer. So skip it
450489
pass
451-
elif sd_weight_name not in sd or not check_weight_equal(
490+
elif sd_weight_name not in sd or not TRTInterpreter.check_weight_equal(
452491
sd[sd_weight_name], np_map[engine_weight_name]
453492
):
454-
weight_name_map[engine_weight_name] = find_weight(
493+
weight_name_map[engine_weight_name] = TRTInterpreter.find_weight(
455494
engine_weight_name, np_map, sd
456495
)
457496

@@ -462,6 +501,10 @@ def check_weight_equal(
462501

463502
self.weight_name_map = weight_name_map
464503

504+
del np_map, sd
505+
gc.collect()
506+
torch.cuda.empty_cache()
507+
465508
def run(
466509
self,
467510
strict_type_constraints: bool = False,

py/torch_tensorrt/dynamo/conversion/_conversion.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -130,14 +130,13 @@ def convert_module(
130130
from torch_tensorrt.dynamo._refit import _refit_single_trt_engine_with_gm
131131
from torch_tensorrt.logging import TRT_LOGGER
132132

133-
runtime = trt.Runtime(TRT_LOGGER)
134-
refit_test_engine = runtime.deserialize_cuda_engine(
135-
interpreter_result.serialized_engine
136-
)
137133
weight_name_map: Any = None
138134
# Do the test refit with cached map if make_refitable is enabled
139135
if settings.make_refitable:
140-
weight_name_map = interpreter_result.weight_name_map
136+
runtime = trt.Runtime(TRT_LOGGER)
137+
refit_test_engine = runtime.deserialize_cuda_engine(
138+
interpreter_result.serialized_engine
139+
)
141140
try:
142141
_refit_single_trt_engine_with_gm(
143142
new_gm=module,
@@ -146,9 +145,13 @@ def convert_module(
146145
settings=settings,
147146
weight_name_map=interpreter_result.weight_name_map,
148147
)
148+
weight_name_map = interpreter_result.weight_name_map
149149
except AssertionError:
150150
logger.warning("Fast refit test failed. Removing the weight map caching.")
151151

152+
del refit_test_engine
153+
torch.cuda.empty_cache()
154+
152155
rt_cls = PythonTorchTensorRTModule
153156

154157
if ENABLED_FEATURES.torch_tensorrt_runtime and not settings.use_python_runtime:

0 commit comments

Comments
 (0)