Skip to content

Commit 1ecd7ef

Browse files
cherry pick 3680: fix refit test bug (#3687)
Co-authored-by: cehongwang <[email protected]>
1 parent 7cbf745 commit 1ecd7ef

File tree

3 files changed

+66
-4
lines changed

3 files changed

+66
-4
lines changed

py/torch_tensorrt/dynamo/_refit.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ def construct_refit_mapping_from_weight_name_map(
101101
params[w.split(".")[-1]] = state_dict[w].cuda()
102102
# Batch norm constant folding
103103

104-
scale, shift = batch_norm_constant_folding(**params, eps=1e-7)
104+
scale, shift = batch_norm_constant_folding(**params, eps=1e-5)
105105
# Set scale to scale or shift to shift
106106
engine_weight_map[engine_weight_name] = eval(
107107
engine_weight_name.split(" ")[-1].lower()

tests/py/dynamo/models/test_model_refit.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,70 @@ def test_mapping():
8989
torch._dynamo.reset()
9090

9191

92+
@unittest.skipIf(
93+
not torch_trt.ENABLED_FEATURES.torch_tensorrt_runtime,
94+
"TorchScript Frontend is not available",
95+
)
96+
@unittest.skipIf(
97+
not torch_trt.ENABLED_FEATURES.refit,
98+
"Refit feature is not supported in Python 3.13 or higher",
99+
)
100+
@unittest.skipIf(
101+
not importlib.util.find_spec("torchvision"),
102+
"torchvision is not installed",
103+
)
104+
@pytest.mark.unit
105+
def test_conv_refit_with_weightmap():
106+
class net(nn.Module):
107+
def __init__(self):
108+
super().__init__()
109+
self.conv = nn.Conv2d(3, 3, 1)
110+
111+
def forward(self, x):
112+
return self.conv(x)
113+
114+
model = net().eval().to("cuda")
115+
model2 = net().eval().to("cuda")
116+
inputs = [torch.randn((1, 3, 224, 224)).to("cuda")]
117+
enabled_precisions = {torch.float}
118+
min_block_size = 1
119+
use_python_runtime = True
120+
121+
exp_program = torch.export.export(model, tuple(inputs))
122+
exp_program2 = torch.export.export(model2, tuple(inputs))
123+
124+
trt_gm = torchtrt.dynamo.compile(
125+
exp_program,
126+
tuple(inputs),
127+
use_python_runtime=use_python_runtime,
128+
enabled_precisions=enabled_precisions,
129+
min_block_size=min_block_size,
130+
immutable_weights=False,
131+
)
132+
133+
new_trt_gm = refit_module_weights(
134+
compiled_module=trt_gm,
135+
new_weight_module=exp_program2,
136+
arg_inputs=inputs,
137+
use_weight_map_cache=True,
138+
verify_output=True,
139+
)
140+
141+
# Check the output
142+
model2.to("cuda")
143+
expected_outputs, refitted_outputs = exp_program2.module()(*inputs), new_trt_gm(
144+
*inputs
145+
)
146+
for expected_output, refitted_output in zip(expected_outputs, refitted_outputs):
147+
assertions.assertTrue(
148+
torch.allclose(expected_output, refitted_output, 1e-2, 1e-2),
149+
"Refit Result is not correct. Refit failed",
150+
)
151+
# Clean up model env
152+
153+
torch._dynamo.reset()
154+
155+
92156
@unittest.skipIf(
93157
not torch_trt.ENABLED_FEATURES.torch_tensorrt_runtime,
94158
"TorchScript Frontend is not available",

tests/py/dynamo/runtime/test_mutable_torchtrt_module.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -317,9 +317,7 @@ def test_resnet18_modify_attribute():
317317
mutable_module = torch_trt.MutableTorchTensorRTModule(model, **compile_spec)
318318
mutable_module(*inputs)
319319

320-
mutable_module.conv1.weight = nn.Parameter(
321-
torch.rand_like(mutable_module.conv1.weight)
322-
)
320+
mutable_module.fc.weight = nn.Parameter(torch.rand_like(mutable_module.fc.weight))
323321
assertions.assertEqual(
324322
mutable_module.refit_state.get_state(),
325323
RefitFlag.UNKNOWN,

0 commit comments

Comments
 (0)