
Commit 665631e

Author: pytorchbot
Commit message: 2025-11-22 nightly release (ac478ea)
1 parent: cc215c0

File tree: 11 files changed, +234 -39 lines

py/torch_tensorrt/_compile.py

Lines changed: 2 additions & 2 deletions

@@ -606,7 +606,7 @@ def save(
     inputs: Optional[Sequence[torch.Tensor]] = None,
     arg_inputs: Optional[Sequence[torch.Tensor]] = None,
     kwarg_inputs: Optional[dict[str, Any]] = None,
-    retrace: bool = False,
+    retrace: bool = True,
     pickle_protocol: int = 2,
     **kwargs: Any,
 ) -> None:

@@ -661,7 +661,7 @@ def save(
             "Input model is of type nn.Module. Saving nn.Module directly is not supported. Supported model types torch.jit.ScriptModule | torch.fx.GraphModule | torch.export.ExportedProgram."
         )
     elif module_type == _ModuleType.ts:
-        if not all([output_format == f for f in ["exported_program", "aot_inductor"]]):
+        if not all(output_format == f for f in ["exported_program", "aot_inductor"]):
            raise ValueError(
                "Provided model is a torch.jit.ScriptModule but the output_format specified is not torchscript. Other output formats are not supported"
            )
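Net effect of the first hunk: torch_tensorrt.save now re-exports the module by default (retrace=True), so callers that depended on the old default must opt out explicitly. A minimal sketch of the opt-out, assuming any exportable module (the model and shapes below are placeholders):

    import torch
    import torch_tensorrt as torchtrt

    model = torch.nn.Linear(16, 4).eval().cuda()  # placeholder model
    inputs = [torch.randn(2, 16).cuda()]

    trt_gm = torchtrt.compile(model, ir="dynamo", inputs=inputs)

    # retrace now defaults to True (re-export before saving); pass
    # retrace=False to keep the pre-change save path.
    torchtrt.save(trt_gm, "/tmp/compiled.ep", inputs=inputs, retrace=False)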

py/torch_tensorrt/dynamo/_refit.py

Lines changed: 91 additions & 16 deletions

@@ -27,6 +27,7 @@
 )
 from torch_tensorrt.dynamo.conversion.truncate_double import repair_double_inputs
 from torch_tensorrt.dynamo.lowering import (
+    clean_up_graph_after_modifications,
     get_decompositions,
     post_lowering,
     pre_export_lowering,

@@ -94,6 +95,8 @@ def construct_refit_mapping_from_weight_name_map(
     engine_weight_map = {}
     for engine_weight_name, (sd_weight_name, np_weight_type) in weight_name_map.items():
         # Add more constant folding converters here
+        trt_dtype = dtype._from(np_weight_type).to(trt.DataType)
+        torch_dtype = dtype._from(np_weight_type).to(torch.dtype)
         if engine_weight_name.split(" ")[-1] in ["SCALE", "SHIFT"]:
             # Batch Norm Layer
             params = {}

@@ -106,12 +109,12 @@ def construct_refit_mapping_from_weight_name_map(
             engine_weight_map[engine_weight_name] = eval(
                 engine_weight_name.split(" ")[-1].lower()
             )
+
         elif sd_weight_name not in state_dict:
             # If weights is not in sd, we can leave it unchanged
             continue
         else:
-            trt_dtype = dtype._from(np_weight_type).to(trt.DataType)
-            torch_dtype = dtype._from(np_weight_type).to(torch.dtype)
+
             engine_weight_map[engine_weight_name] = state_dict[sd_weight_name].to(
                 to_torch_device(settings.device)
             )
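The dtype lookups are hoisted out of the else branch so every mapped weight resolves both a TensorRT and a torch dtype up front. Illustratively (the numpy value is an assumed example of what the weight map stores):

    import numpy as np
    import tensorrt as trt
    import torch
    from torch_tensorrt import dtype

    np_weight_type = np.float32  # assumed example entry from weight_name_map
    trt_dtype = dtype._from(np_weight_type).to(trt.DataType)   # trt.DataType.FLOAT
    torch_dtype = dtype._from(np_weight_type).to(torch.dtype)  # torch.float32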

@@ -272,12 +275,66 @@ def refit_module_weights(
             compiled_submodules_map[name] = submodule

     else:
+        # Handle torch modules
+        compiled_submodules_map = {}
+        guard_fn_modules = []
         for name, submodule in compiled_module.named_children():
-            if not isinstance(
-                submodule, (PythonTorchTensorRTModule, TorchTensorRTModule)
+            if (
+                not isinstance(
+                    submodule,
+                    (
+                        PythonTorchTensorRTModule,
+                        TorchTensorRTModule,
+                        torch.nn.modules.module.Module,
+                    ),
+                )
+                or "_run_on_gpu" in name
             ):
                 continue
-            settings = submodule.settings
+
+            # When we re-export the graph module, torch.export._unlift.GuardsFn modules are being added as submodules.
+            if isinstance(submodule, torch.export._unlift.GuardsFn):
+                guard_fn_modules.append(name)
+                continue
+            # Obtain the settings
+
+            compiled_submodules = [
+                (name.replace("_engine", ""), engine)
+                for name, engine in submodule.__dict__.items()
+                if "engine" in name
+            ]
+
+            settings = None
+            try:
+                # If the gm is not inlined or transformed by retracing, the settings is stored in the submodule
+                settings = submodule.settings
+            except AttributeError:
+                encoded_metadata = [
+                    engine for name, engine in compiled_submodules if name == "engine"
+                ][0].__getstate__()[0][SERIALIZED_METADATA_IDX]
+                assert (
+                    encoded_metadata != ""
+                ), "The engine provided is either not refittable or was built with a version of Torch-TensorRT that is too old, please recompile using the latest version"
+                settings = TorchTensorRTModule.decode_metadata(encoded_metadata)[
+                    "settings"
+                ]
+
+            compiled_submodules_map[name] = submodule
+
+        # Delete the guard fn modules to avoid the guard fn modules being refitted
+        # First, remove nodes in the graph that reference the guard function modules
+        for node in list(compiled_module.graph.nodes):
+            if node.op == "call_module" and node.target in guard_fn_modules:
+                compiled_module.graph.erase_node(node)
+
+        # Now delete the submodules themselves
+        for guard_fn_module_name in guard_fn_modules:
+            # delattr(compiled_module, guard_fn_module_name)
+            compiled_module.delete_submodule(guard_fn_module_name)
+
+        # Clean up the graph
+        clean_up_graph_after_modifications(compiled_module)

     assert settings is not None

@@ -411,11 +468,29 @@ def refit_module_weights(
                 )
             else:
                 compiled_submodule = getattr(compiled_module, name)
+                if "_run_on_acc" not in name:
+                    compiled_submodule.load_state_dict(new_submodule.state_dict())
+                    continue
+
                 weight_name_map = None
                 if use_weight_map_cache:
                     try:
                         weight_name_map = compiled_submodule.weight_name_map
                     except AttributeError:
+                        if isinstance(compiled_submodule, torch.nn.Module):
+                            # Torch retrace module
+                            assert (
+                                not settings.use_python_runtime
+                            ), "Refitting a torch retraced module is only supported with use_python_runtime=False"
+                            encoded_metadata = [
+                                engine
+                                for name, engine in compiled_submodules
+                                if name == "engine"
+                            ][0].__getstate__()[0][SERIALIZED_METADATA_IDX]
+                            weight_name_map = TorchTensorRTModule.decode_metadata(
+                                encoded_metadata
+                            )["weight_name_map"]
+
                         if not isinstance(
                             compiled_submodule, torch.fx.graph_module.GraphModule
                         ):

@@ -427,21 +502,16 @@ def refit_module_weights(
                         logger.warning(
                             "This engine does not have a weight map cache. Rebuilding the weight map"
                         )
-                if isinstance(compiled_submodule, PythonTorchTensorRTModule):
+
+                # Rexporting the TRT compiled graph module and loading it back doesn't preserve the instance type and registers
+                # the compiled submodule as torch.nn.Module. So we use settings.use_python_runtime to determine the instance type.
+                if settings.use_python_runtime:
                     engine = compiled_submodule.engine
-                elif isinstance(compiled_submodule, TorchTensorRTModule):
+                else:
                     engine_info = compiled_submodule.engine.__getstate__()[0]
                     engine = get_engine_from_encoded_engine(
                         engine_info[ENGINE_IDX], runtime
                     )
-                elif isinstance(compiled_submodule, torch.fx.graph_module.GraphModule):
-                    # This is graph break resulted by unsupported ops
-                    compiled_submodule.load_state_dict(new_submodule.state_dict())
-                    continue
-                else:
-                    raise AssertionError(
-                        "The type of graph module is not supported for refitting."
-                    )
             except AttributeError:
                 raise AssertionError(
                     "The type of graph module is not supported for refitting or two compiled modules do not match."

@@ -500,7 +570,12 @@ def refit_module_weights(
                 new_engine_info[ENGINE_IDX] = bytes(serialized_engine)
                 refitted_engine = torch.classes.tensorrt.Engine(tuple(new_engine_info))
                 setattr(compiled_module, f"{name}_engine", refitted_engine)
-
+            elif isinstance(compiled_submodule, torch.nn.Module):
+                # Torch retrace module
+                new_engine_info = list(engine_info)
+                new_engine_info[ENGINE_IDX] = bytes(serialized_engine)
+                refitted_engine = torch.classes.tensorrt.Engine(tuple(new_engine_info))
+                compiled_submodule.engine = refitted_engine
             del engine
             gc.collect()
             torch.cuda.empty_cache()
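Taken together, these changes let refit_module_weights handle a compiled module that was saved with retrace=True and loaded back, which erases the TRT wrapper types and introduces guard submodules. A hedged end-to-end sketch; the model is a placeholder, and the refit-enabling compile flag has varied across releases (immutable_weights=False is assumed here):

    import torch
    import torch_tensorrt as torchtrt
    from torch_tensorrt.dynamo import refit_module_weights

    model = torch.nn.Linear(16, 4).eval().cuda()  # placeholder
    inputs = [torch.randn(2, 16).cuda()]

    exp_program = torch.export.export(model, tuple(inputs))
    trt_gm = torchtrt.dynamo.compile(
        exp_program,
        inputs=inputs,
        immutable_weights=False,   # build refittable engines (assumed flag name)
        use_python_runtime=False,  # retraced modules only refit on the C++ runtime
    )

    # Save with the new default (retrace=True) and load back: TRT submodules
    # come back as plain torch.nn.Module instances.
    torchtrt.save(trt_gm, "/tmp/compiled.ep", inputs=inputs)
    reloaded = torchtrt.load("/tmp/compiled.ep").module()

    # Refit the reloaded module with freshly exported weights.
    updated = torch.nn.Linear(16, 4).eval().cuda()  # placeholder new weights
    new_program = torch.export.export(updated, tuple(inputs))
    refitted = refit_module_weights(reloaded, new_program, use_weight_map_cache=True)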

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 1 addition & 0 deletions

@@ -2026,6 +2026,7 @@ def aten_ops_sub(
     )


+@dynamo_tensorrt_converter(operator.truediv, supports_dynamic_shapes=True)
 @dynamo_tensorrt_converter(torch.ops.aten.div.Tensor, supports_dynamic_shapes=True)
 @dynamo_tensorrt_converter(torch.ops.aten.div.Tensor_mode, supports_dynamic_shapes=True)
 @dynamo_tensorrt_converter(torch.ops.aten.div.Scalar, supports_dynamic_shapes=True)
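Registering operator.truediv matters because division between Python-level symbolic scalars is recorded in exported graphs as a call to operator.truediv rather than an aten.div overload. A hedged sketch of a module that can produce such a node under dynamic shapes:

    import torch

    class MeanByLength(torch.nn.Module):
        # With a dynamic first dimension, x.shape[0] is a SymInt; the scalar
        # division below can surface as an operator.truediv node (a SymFloat
        # computation), which this registration makes convertible instead of
        # forcing a fallback.
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            inv = 1.0 / x.shape[0]
            return x.sum(dim=0) * inv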

py/torch_tensorrt/dynamo/lowering/passes/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -1,3 +1,4 @@
 from ._aten_lowering_pass import *
+from .pass_utils import clean_up_graph_after_modifications
 from .remove_sym_nodes import remove_sym_nodes
 from .repair_input_aliasing import repair_input_aliasing
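The re-export makes the helper importable from the public lowering namespace, which is how _refit.py above consumes it. Presumed usage on any mutated GraphModule:

    import torch
    from torch_tensorrt.dynamo.lowering import clean_up_graph_after_modifications

    gm = torch.fx.symbolic_trace(torch.nn.ReLU())  # any GraphModule
    # ... mutate gm.graph here (erase nodes, delete submodules) ...
    # Then run the standard post-edit cleanup (dead-code elimination and
    # recompile); the helper returns the same GraphModule.
    gm = clean_up_graph_after_modifications(gm)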

py/torch_tensorrt/dynamo/partitioning/common.py

Lines changed: 9 additions & 0 deletions

@@ -102,6 +102,15 @@ def construct_submodule_inputs(module: torch.fx.GraphModule) -> Sequence[Input]:
                     is_shape_tensor=True,
                 )
             )
+        elif isinstance(input_meta, torch.SymFloat):
+            torchtrt_inputs.append(
+                get_input(
+                    [1],
+                    torch.float32,
+                    name=input.name,
+                    is_shape_tensor=False,  # Only SymInt inputs are treated as shape tensors
+                )
+            )
         else:
             raise ValueError(
                 f"The meta val for input node {input.target} is of type : {type(input_meta)}. Supported types: torch.Tensor|FakeTensor|torch.SymInt"

py/torch_tensorrt/dynamo/utils.py

Lines changed: 5 additions & 0 deletions

@@ -455,6 +455,9 @@ def unwrap_tensor_shape(
             tensor_shape.append(min_max_opt[mode])
         else:
             tensor_shape.append((min_max_opt["min"], min_max_opt["max"]))
+    elif isinstance(tensor, torch.SymFloat):
+        # SymFloats can be an input to graph sometimes. Although SymFloat is scalar value, we treat it as a 1D tensor throughout Torch-TRT codebase.
+        tensor_shape.append(1)
     elif isinstance(tensor, (torch.Tensor, FakeTensor)):
         for dimension in tensor.shape:
             tensor_shape.extend(unwrap_tensor_shape(dimension, mode=mode))

@@ -472,6 +475,8 @@ def unwrap_tensor_dtype(tensor: Union[torch.Tensor, FakeTensor, torch.SymInt]) -
         return torch.tensor(tensor).dtype
     elif isinstance(tensor, torch.SymInt):
         return torch.int64
+    elif isinstance(tensor, torch.SymFloat):
+        return torch.float32
     elif tensor is None:
         # Case where we explicitly pass one of the inputs to be None (eg: FLUX.1-dev)
         return None
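These helpers now follow the same convention as the partitioning change: a SymFloat is reported as shape [1] and dtype float32. Expected behavior under that convention, shown with the paths that are easy to exercise directly (illustrative, not a test from the commit):

    import torch
    from torch_tensorrt.dynamo.utils import unwrap_tensor_dtype

    # Concrete tensors report their own dtype; with this change a
    # torch.SymFloat maps to float32 (as SymInt maps to int64), and an
    # explicit None input stays None.
    assert unwrap_tensor_dtype(torch.tensor(1.5)) == torch.float32
    assert unwrap_tensor_dtype(None) is None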

tests/py/dynamo/models/test_export_kwargs_serde.py

Lines changed: 5 additions & 5 deletions

@@ -76,7 +76,7 @@ def forward(self, x, b=5, c=None, d=None):

     # Save the module
     trt_ep_path = os.path.join(tempfile.gettempdir(), "compiled.ep")
-    torchtrt.save(trt_gm, trt_ep_path)
+    torchtrt.save(trt_gm, trt_ep_path, retrace=False)
     # Clean up model env
     torch._dynamo.reset()

@@ -138,7 +138,7 @@ def forward(self, x, b=5, c=None, d=None):

     # Save the module
     trt_ep_path = os.path.join(tempfile.gettempdir(), "compiled.ep")
-    torchtrt.save(trt_gm, trt_ep_path)
+    torchtrt.save(trt_gm, trt_ep_path, retrace=False)
     # Clean up model env
     torch._dynamo.reset()

@@ -209,7 +209,7 @@ def forward(self, x, b=5, c=None, d=None):

     # Save the module
     trt_ep_path = os.path.join(tempfile.gettempdir(), "compiled.ep")
-    torchtrt.save(trt_gm, trt_ep_path)
+    torchtrt.save(trt_gm, trt_ep_path, retrace=False)
     # Clean up model env
     torch._dynamo.reset()

@@ -299,7 +299,7 @@ def forward(self, x, b=None, c=None, d=None, e=[]):
     )
     # Save the module
     trt_ep_path = os.path.join(tempfile.gettempdir(), "compiled.ep")
-    torchtrt.save(trt_gm, trt_ep_path)
+    torchtrt.save(trt_gm, trt_ep_path, retrace=False)
     # Clean up model env
     torch._dynamo.reset()

@@ -389,7 +389,7 @@ def forward(self, x, b=None, c=None, d=None, e=[]):
     )
     # Save the module
     trt_ep_path = os.path.join(tempfile.gettempdir(), "compiled.ep")
-    torchtrt.save(trt_gm, trt_ep_path)
+    torchtrt.save(trt_gm, trt_ep_path, retrace=False)
     # Clean up model env
     torch._dynamo.reset()

tests/py/dynamo/models/test_export_serde.py

Lines changed: 11 additions & 11 deletions

@@ -56,7 +56,7 @@ def forward(self, x):

     exp_program = torchtrt.dynamo.trace(model, **compile_spec)
     trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec)
-    torchtrt.save(trt_module, trt_ep_path)
+    torchtrt.save(trt_module, trt_ep_path, retrace=False)

     deser_trt_module = torchtrt.load(trt_ep_path).module()
     # Check Pyt and TRT exported program outputs

@@ -111,7 +111,7 @@ def forward(self, x):

     exp_program = torchtrt.dynamo.trace(model, **compile_spec)
     trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec)
-    torchtrt.save(trt_module, trt_ep_path)
+    torchtrt.save(trt_module, trt_ep_path, retrace=False)

     deser_trt_module = torchtrt.load(trt_ep_path).module()
     # Check Pyt and TRT exported program outputs

@@ -170,7 +170,7 @@ def forward(self, x):

     exp_program = torchtrt.dynamo.trace(model, **compile_spec)
     trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec)
-    torchtrt.save(trt_module, trt_ep_path)
+    torchtrt.save(trt_module, trt_ep_path, retrace=False)

     deser_trt_module = torchtrt.load(trt_ep_path).module()
     # Check Pyt and TRT exported program outputs

@@ -232,7 +232,7 @@ def forward(self, x):

     exp_program = torchtrt.dynamo.trace(model, **compile_spec)
     trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec)
-    torchtrt.save(trt_module, trt_ep_path)
+    torchtrt.save(trt_module, trt_ep_path, retrace=False)

     deser_trt_module = torchtrt.load(trt_ep_path).module()
     outputs_pyt = model(input)

@@ -279,7 +279,7 @@ def test_resnet18(ir):

     exp_program = torchtrt.dynamo.trace(model, **compile_spec)
     trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec)
-    torchtrt.save(trt_module, trt_ep_path)
+    torchtrt.save(trt_module, trt_ep_path, retrace=False)

     deser_trt_module = torchtrt.load(trt_ep_path).module()
     outputs_pyt = model(input)

@@ -331,7 +331,7 @@ def test_resnet18_cpu_offload(ir):
         msg="Model should be offloaded to CPU",
     )
     model.cuda()
-    torchtrt.save(trt_module, trt_ep_path)
+    torchtrt.save(trt_module, trt_ep_path, retrace=False)

     deser_trt_module = torchtrt.load(trt_ep_path).module()
     outputs_pyt = model(input)

@@ -380,7 +380,7 @@ def test_resnet18_dynamic(ir):

     exp_program = torchtrt.dynamo.trace(model, **compile_spec)
     trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec)
-    torchtrt.save(trt_module, trt_ep_path)
+    torchtrt.save(trt_module, trt_ep_path, retrace=False)
     # TODO: Enable this serialization issues are fixed
     # deser_trt_module = torchtrt.load(trt_ep_path).module()
     outputs_pyt = model(input)

@@ -413,7 +413,7 @@ def test_resnet18_torch_exec_ops_serde(ir):

     exp_program = torchtrt.dynamo.trace(model, **compile_spec)
     trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec)
-    torchtrt.save(trt_module, trt_ep_path)
+    torchtrt.save(trt_module, trt_ep_path, retrace=False)
     deser_trt_module = torchtrt.load(trt_ep_path).module()
     outputs_pyt = deser_trt_module(input)
     outputs_trt = trt_module(input)

@@ -463,7 +463,7 @@ def forward(self, x):
     exp_program = torchtrt.dynamo.trace(model, **compile_spec)
     trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec)

-    torchtrt.save(trt_module, trt_ep_path)
+    torchtrt.save(trt_module, trt_ep_path, retrace=False)

     deser_trt_module = torchtrt.load(trt_ep_path).module()
     outputs_pyt = model(input)

@@ -525,7 +525,7 @@ def forward(self, x):
     exp_program = torchtrt.dynamo.trace(model, **compile_spec)
     trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec)
     model.cuda()
-    torchtrt.save(trt_module, trt_ep_path)
+    torchtrt.save(trt_module, trt_ep_path, retrace=False)

     deser_trt_module = torchtrt.load(trt_ep_path).module()
     outputs_pyt = model(input)

@@ -584,7 +584,7 @@ def forward(self, x):
     exp_program = torchtrt.dynamo.trace(model, **compile_spec)
     trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec)

-    torchtrt.save(trt_module, trt_ep_path)
+    torchtrt.save(trt_module, trt_ep_path, retrace=False)

     deser_trt_module = torchtrt.load(trt_ep_path).module()
     outputs_pyt = model(input)