Commit c4a5c3d

fix: store the traced symbolic shape expressions from compile time in the engine metadata so they can be reused during re-export. This also removes the need to access the real TensorRT engine during re-export.
1 parent 99d29d0 · commit c4a5c3d

File tree: 12 files changed, +294 / -299 lines

core/runtime/TRTEngine.cpp

Lines changed: 4 additions & 0 deletions
@@ -325,6 +325,10 @@ std::string TRTEngine::get_engine_layer_info() {
   return inspector->getEngineInformation(nvinfer1::LayerInformationFormat::kJSON);
 }
 
+std::string TRTEngine::get_serialized_metadata() {
+  return this->serialized_metadata;
+}
+
 std::vector<at::Tensor> TRTEngine::infer_outputs(std::vector<std::vector<int64_t>> input_shapes) {
   std::vector<at::Tensor> outputs;
   TORCHTRT_CHECK(

core/runtime/TRTEngine.h

Lines changed: 1 addition & 0 deletions
@@ -158,6 +158,7 @@ struct TRTEngine : torch::CustomClassHolder {
   void set_profile_format(std::string profile_format);
   void disable_profiling();
   std::string get_engine_layer_info();
+  std::string get_serialized_metadata();
 
   void dump_engine_layer_info_to_file(const std::string& path);
   void dump_engine_layer_info();

core/runtime/register_jit_hooks.cpp

Lines changed: 1 addition & 0 deletions
@@ -88,6 +88,7 @@ static auto TORCHTRT_UNUSED TRTEngineTSRegistrtion =
         .def("dump_engine_layer_info_to_file", &TRTEngine::dump_engine_layer_info_to_file)
         .def("dump_engine_layer_info", &TRTEngine::dump_engine_layer_info)
         .def("get_engine_layer_info", &TRTEngine::get_engine_layer_info)
+        .def("get_serialized_metadata", &TRTEngine::get_serialized_metadata)
         .def("infer_outputs", &TRTEngine::infer_outputs)
         .def("reset_captured_graph", &TRTEngine::reset_captured_graph)
         .def("set_output_tensors_as_unowned", &TRTEngine::set_output_tensors_as_unowned)

py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 2 additions & 2 deletions
@@ -10,6 +10,7 @@
 import torch
 from torch.export import ExportedProgram
 from torch.fx.node import Target
+
 from torch_tensorrt._Device import Device
 from torch_tensorrt._enums import EngineCapability, dtype
 from torch_tensorrt._features import needs_cross_compile
@@ -564,7 +565,7 @@ def compile(
 
     if not kwargs.get("use_explicit_typing", False):
         warnings.warn(
-            "`use_explicit_typing` is deprecated. This setting will be removed and you should enable autocast instead.",
+            "`use_explicit_typing` is deprecated. use_explicit_types is now on by default, this setting will be removed and you should enable autocast to recover weak typing behavior.",
             DeprecationWarning,
             stacklevel=2,
         )
@@ -1042,7 +1043,6 @@ def preserve_module_specs(
         trt_modules[name] = trt_module
 
     if _debugger_config:
-
         if _debugger_config.save_engine_profile:
             if settings.use_python_runtime:
                 if _debugger_config.profile_format != "cudagraph":

py/torch_tensorrt/dynamo/_exporter.py

Lines changed: 21 additions & 1 deletion
@@ -17,6 +17,7 @@
     OutputSpec,
     TensorArgument,
 )
+
 from torch_tensorrt.dynamo.runtime._TorchTensorRTModule import ENGINE_IDX, NAME_IDX
 
 
@@ -270,7 +271,26 @@ def inline_torch_modules(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
                 gm.graph.erase_node(gm_added_placeholder_inputs[idx])
 
             # Replace the pytorch submodule node (call_module) with the inlined subgraph output
-            gm_node.replace_all_uses_with(submodule_output)
+            # Special handling when submodule returns multiple outputs (tuple)
+            if isinstance(submodule_output, tuple):
+                # The fallback module has multiple outputs
+                # Find getitem nodes that extract from this module call and replace them directly
+                getitem_users = [
+                    user
+                    for user in list(gm_node.users.keys())
+                    if user.op == "call_function"
+                    and user.target is operator.getitem
+                ]
+                for user in getitem_users:
+                    # getitem extracts element idx from the tuple
+                    _, idx = user.args
+                    # Replace this getitem with the actual node from the tuple
+                    user.replace_all_uses_with(submodule_output[idx])
+                    # Erase the getitem node since it's no longer needed
+                    gm.graph.erase_node(user)
+            else:
+                # Single output - normal replacement
+                gm_node.replace_all_uses_with(submodule_output)
 
             # copy the attributes of the submodule into gm (graph_copy doesn't do this)
             copy_submodule_attributes(gm, submodule, gm_node.name)
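For context on the new branch above, a standalone sketch of the FX idiom it relies on: a node that produces a tuple is only consumed through operator.getitem nodes, so replacing that node means rerouting each getitem user to the matching replacement and erasing the getitem. The toy graph below is illustrative only and is not torch_tensorrt code:

import operator

import torch
from torch import fx


class Toy(torch.nn.Module):
    def forward(self, x):
        pair = torch.chunk(x, 2)  # one node whose consumers are getitem nodes
        return pair[0] + pair[1]


gm = fx.symbolic_trace(Toy())
chunk_node = next(n for n in gm.graph.nodes if n.target is torch.chunk)

# Build a replacement producer (torch.tensor_split, purely for illustration).
with gm.graph.inserting_after(chunk_node):
    new_node = gm.graph.call_function(torch.tensor_split, args=(chunk_node.args[0], 2))

# Reroute every getitem user of the old node to the replacement, then erase it.
for user in [u for u in list(chunk_node.users) if u.op == "call_function" and u.target is operator.getitem]:
    _, idx = user.args
    with gm.graph.inserting_after(new_node):
        replacement = gm.graph.call_function(operator.getitem, args=(new_node, idx))
    user.replace_all_uses_with(replacement)
    gm.graph.erase_node(user)

gm.graph.erase_node(chunk_node)
gm.recompile()

In the diff itself the replacement for each getitem is simply submodule_output[idx], the already-inlined output node of the fallback submodule, so no new getitem nodes need to be created.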

py/torch_tensorrt/dynamo/conversion/_conversion.py

Lines changed: 25 additions & 17 deletions
@@ -2,11 +2,9 @@
 
 import io
 import logging
-from typing import Any, List, NamedTuple, Optional, Sequence
+from typing import Any, Dict, List, NamedTuple, Optional, Sequence
 
-import tensorrt as trt
 import torch
-
 from torch_tensorrt._enums import dtype
 from torch_tensorrt._features import ENABLED_FEATURES
 from torch_tensorrt._Input import Input
@@ -27,6 +25,8 @@
 )
 from torch_tensorrt.logging import TRT_LOGGER
 
+import tensorrt as trt
+
 logger = logging.getLogger(__name__)
 
 
@@ -36,9 +36,7 @@ class SerializedInterpreterResult(NamedTuple):
     output_names: Sequence[str]
     weight_name_map: Optional[dict[Any, Any]]
     requires_output_allocator: bool
-    symbolic_shape_expressions: Optional[
-        str
-    ]  # Base64-encoded serialized symbolic shape mapping
+    symbolic_shape_expressions: List[Dict[str, Any]]
 
 
 def infer_module_output_dtypes(
@@ -108,6 +106,7 @@ def pull_cached_engine(
     engine_cache: BaseEngineCache,
     settings: CompilationSettings,
     inputs: Sequence[Input],
+    symbolic_shape_expressions: List[Dict[str, Any]],
 ) -> Optional[SerializedInterpreterResult]:
     if hash_val is None:
         logger.warning(
@@ -137,16 +136,16 @@ def pull_cached_engine(
         setting_compatiblity, incompattible_settings = settings_are_compatible(
             settings, cached_engine_compilation_settings
         )
-        assert setting_compatiblity, (
-            f"Attempted to refit a cached engine with incompatible settings: {incompattible_settings}, (old_settings: {cached_engine_compilation_settings}, new_settings: {settings})"
-        )
+        assert (
+            setting_compatiblity
+        ), f"Attempted to refit a cached engine with incompatible settings: {incompattible_settings}, (old_settings: {cached_engine_compilation_settings}, new_settings: {settings})"
 
         for i, e in enumerate(
             [Input.equivalent_spec(c, i) for c, i in zip(cached_engine_inputs, inputs)]
         ):
-            assert e, (
-                f"Attempted to refit a cached engine built for a different input size (input: {i}, cached size: {cached_engine_inputs[i]}, new size: {inputs[i]}"
-            )
+            assert (
+                e
+            ), f"Attempted to refit a cached engine built for a different input size (input: {i}, cached size: {cached_engine_inputs[i]}, new size: {inputs[i]}"
 
         logger.info(
             f"Found the cached engine with hash {hash_val} that corresponds to this graph. It is directly loaded."
@@ -190,6 +189,7 @@ def pull_cached_engine(
                 output_names=output_names,
                 weight_name_map=weight_name_map,
                 requires_output_allocator=requires_output_allocator,
+                symbolic_shape_expressions=symbolic_shape_expressions,
             )
     return None
 
@@ -210,6 +210,12 @@ def interpret_module_to_result(
        SerializedInterpreterResult
    """
 
+    symbolic_shape_expressions = extract_symbolic_shape_expressions(module)
+    if symbolic_shape_expressions is None:
+        raise RuntimeError(
+            "Failed to extract symbolic shape expressions from source FX graph partition"
+        )
+
    # engine_cache could be None if:
    # 1) engine_cache is not passed in when calling this function like convert_exported_program_to_serialized_trt_engine etc., or
    # 2) both cache_built_engines and reuse_cached_engines are False
@@ -242,7 +248,12 @@
            )
        else:
            serialized_interpreter_result = pull_cached_engine(
-                hash_val, module, engine_cache, settings, inputs
+                hash_val,
+                module,
+                engine_cache,
+                settings,
+                inputs,
+                symbolic_shape_expressions,
            )
        if serialized_interpreter_result is not None:  # hit the cache
            return serialized_interpreter_result
@@ -251,11 +262,8 @@
        module, truncate_double=settings.truncate_double
    )
 
-    # Extract symbolic shape expressions before interpretation
-    # This captures the symbolic relationship between input and output shapes
-    symbolic_shape_expressions = extract_symbolic_shape_expressions(module)
    logger.debug(
-        f"Extracted symbolic shape expressions: {len(symbolic_shape_expressions) if symbolic_shape_expressions else 0} bytes"
+        f"Extracted symbolic shape expressions: {len(symbolic_shape_expressions) if symbolic_shape_expressions else 0} outputs"
    )
 
    interpreter = TRTInterpreter(

py/torch_tensorrt/dynamo/conversion/_symbolic_shape_capture.py

Lines changed: 40 additions & 76 deletions
@@ -6,43 +6,29 @@
 output shapes without pattern matching.
 """
 
-import base64
-import pickle
-from typing import Any, Dict, List, Optional, Tuple
+import logging
+from typing import Any, Dict, List, Optional
 
 import torch
 
+logger = logging.getLogger(__name__)
+
 
 def extract_symbolic_shape_expressions(
     module: torch.fx.GraphModule,
-) -> Optional[str]:
+) -> Optional[List[Dict[str, Any]]]:
     """
-    Extract symbolic shape expressions from an FX graph and serialize them.
+    Extract symbolic shape expressions from an FX graph.
 
-    This captures the relationship between input placeholder shapes and output shapes,
-    storing the symbolic expressions (as sympy expressions) that can be deserialized
-    and evaluated in the meta kernel.
+    This captures the symbolic expressions (as sympy expressions) for output shapes
+    that can be applied to input fake tensors at runtime.
 
     Args:
         module: FX GraphModule with symbolic shapes in node metadata
 
     Returns:
-        Base64-encoded serialized mapping of {output_idx: {dim_idx: sympy_expr}} or None if no symbolic shapes
+        List of dicts containing shape_exprs and dtype for each output, or None if extraction fails
     """
-    # Find input placeholders (excluding parameters/buffers)
-    input_placeholders = []
-    for node in module.graph.nodes:
-        if node.op == "placeholder" and "val" in node.meta:
-            val = node.meta["val"]
-            # Skip parameters and buffers (they're also placeholders but not inputs)
-            if isinstance(val, torch.Tensor) and not node.name.startswith(
-                "_frozen_param"
-            ):
-                # Check if this is an actual input (not a parameter)
-                # Parameters typically have names like "p_weight", "p_bias"
-                if not node.name.startswith("p_"):
-                    input_placeholders.append(node)
-
     # Find output node
     output_nodes = [node for node in module.graph.nodes if node.op == "output"]
     if not output_nodes:
@@ -55,59 +41,37 @@ def extract_symbolic_shape_expressions(
     if not isinstance(output_args, (tuple, list)):
         output_args = (output_args,)
 
-    # Build mapping of output shapes
-    # Format: {output_idx: {dim_idx: sympy_expr}}
-    output_shape_mapping: Dict[int, Dict[int, Any]] = {}
-
-    has_symbolic_shapes = False
-
-    for out_idx, out_arg in enumerate(output_args):
+    # Collect shape expressions and dtypes for each output
+    output_info = []
+    for out_arg in output_args:
         if not hasattr(out_arg, "meta") or "val" not in out_arg.meta:
-            continue
+            logger.warning(
+                "When processing symbolic shapes for TensorRT engine, found no metadata in FX Graph"
+            )
+            return None
 
         out_val = out_arg.meta["val"]
-        if not hasattr(out_val, "shape"):
-            continue
-
-        dim_mapping = {}
-        for dim_idx, dim_size in enumerate(out_val.shape):
-            dim_mapping[dim_idx] = dim_size.node.expr
-
-        output_shape_mapping[out_idx] = dim_mapping
-
-    # Serialize the mapping and base64 encode it
-    # Note: We can pickle sympy expressions but not SymInt objects directly
-    # Base64 encoding is needed because the pickled data is binary and needs to be stored as a C++ std::string
-    try:
-        pickled = pickle.dumps(output_shape_mapping)
-        encoded = base64.b64encode(pickled).decode("utf-8")
-        return encoded
-    except Exception as e:
-        import logging
-
-        logger = logging.getLogger(__name__)
-        logger.warning(f"Failed to serialize symbolic shape expressions: {e}")
-        return None
-
-
-def deserialize_symbolic_shape_expressions(
-    serialized: str,
-) -> Optional[Dict[int, Dict[int, Any]]]:
-    """
-    Deserialize symbolic shape expressions.
-
-    Args:
-        serialized: Base64-encoded pickled mapping from extract_symbolic_shape_expressions
-
-    Returns:
-        Dictionary mapping {output_idx: {dim_idx: sympy_expr}}
-    """
-    try:
-        decoded = base64.b64decode(serialized.encode("utf-8"))
-        return pickle.loads(decoded)
-    except Exception as e:
-        import logging
-
-        logger = logging.getLogger(__name__)
-        logger.warning(f"Failed to deserialize symbolic shape expressions: {e}")
-        return None
+        if not isinstance(out_val, torch.Tensor):
+            logger.warning(
+                "When processing symbolic shapes for TensorRT engine, output is not a tensor"
+            )
+            return None
+
+        # Extract shape as sympy expressions (can be pickled)
+        shape_exprs = []
+        for dim_size in out_val.shape:
+            if isinstance(dim_size, torch.SymInt):
+                # Store the sympy expression, which can be pickled
+                shape_exprs.append(dim_size.node.expr)
+            else:
+                # Store concrete integer
+                shape_exprs.append(int(dim_size))
+
+        output_info.append(
+            {
+                "shape_exprs": shape_exprs,
+                "dtype": out_val.dtype,
+            }
+        )
+
+    return output_info if output_info else None
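The function now returns one entry per engine output, each carrying per-dimension sympy expressions (or plain ints) plus the output dtype. A hedged sketch of how a consumer, for example a meta kernel during re-export, might evaluate those expressions once concrete input sizes are known; the helper name and the symbol-to-value mapping are illustrative assumptions, not part of this commit:

import sympy
import torch


def resolve_output_shapes(output_info, symbol_values):
    """output_info: the list returned by extract_symbolic_shape_expressions.
    symbol_values: mapping of sympy.Symbol (e.g. s0) to a concrete int."""
    resolved = []
    for info in output_info:
        dims = []
        for expr in info["shape_exprs"]:
            if isinstance(expr, sympy.Expr):
                dims.append(int(expr.subs(symbol_values)))  # substitute symbols, then evaluate
            else:
                dims.append(int(expr))  # dimension was already concrete
        resolved.append((tuple(dims), info["dtype"]))
    return resolved


# Example: an output whose leading dimension is the symbolic batch size s0.
s0 = sympy.Symbol("s0")
example = [{"shape_exprs": [s0, 2 * s0, 16], "dtype": torch.float32}]
print(resolve_output_shapes(example, {s0: 4}))  # [((4, 8, 16), torch.float32)]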

py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py

Lines changed: 5 additions & 2 deletions
@@ -4,7 +4,6 @@
 from contextlib import nullcontext
 from typing import Any, Dict, List, Optional, Sequence, Tuple
 
-import tensorrt as trt
 import torch
 import torch_tensorrt
 from torch.nn import Module
@@ -22,6 +21,8 @@
     multi_gpu_device_check,
 )
 
+import tensorrt as trt
+
 logger = logging.getLogger(__name__)
 
 
@@ -131,6 +132,7 @@ def __init__(
         settings: CompilationSettings = CompilationSettings(),
         weight_name_map: Optional[dict[Any, Any]] = None,
         requires_output_allocator: bool = False,
+        symbolic_shape_expressions: Optional[List[Dict[str, Any]]] = None,
         _debugger_config: Optional[DebuggerConfig] = None,
     ):
         """Takes a name, target device, serialized TensorRT engine, and binding names / order and constructs
@@ -146,6 +148,7 @@ def __init__(
             settings (torch_tensorrt.dynamo.CompilationSettings): Settings used to compile engine, assumes engine was built with default compilation settings if object not passed
             weight_name_map (dict): Mapping of engine weight name to state_dict weight name
             requires_output_allocator (bool): Boolean flag indicating if the converter creates operators which require an Output Allocator to run (e.g. data dependent operators)
+            symbolic_shape_expressions (List[str]): List of symbolic shape expressions for each output binding
 
         Example:
 
@@ -222,6 +225,7 @@ def __init__(
         self.cudagraphs_enabled = torch_tensorrt.runtime.get_cudagraphs_mode()
         # If the output tensor is not owned by the engine (output_tensors_are_unowned=True), we need to create a new output tensor in each forward pass
         self.output_tensors_are_unowned = False
+        self.symbolic_shape_expressions = symbolic_shape_expressions
         if self.serialized_engine is not None and not self.settings.lazy_engine_init:
             self.setup_engine()
 
@@ -462,7 +466,6 @@ def create_output_allocator(self) -> None:
         self.output_allocator = DynamicOutputAllocator(output_dtypes_dict)
 
     def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
-
        def run_standard_execution() -> torch.Tensor | Tuple[torch.Tensor, ...]:
            shape_changed = self.validate_input_shapes(contiguous_inputs)
            (
