Commit 6b1950c

committed
Revised according to comments
1 parent 503f320 commit 6b1950c

4 files changed: +61 -42 lines changed


py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py

Lines changed: 8 additions & 24 deletions
@@ -65,7 +65,7 @@ class UnsupportedOperatorException(RuntimeError):
 
 
 class TRTInterpreterResult(NamedTuple):
-    engine: trt.ICudaEngine | bytes
+    engine: trt.ICudaEngine
     input_names: Sequence[str]
     output_names: Sequence[str]
     weight_name_map: Optional[dict[Any, Any]]
@@ -770,29 +770,13 @@ def run(
         ):
             self._insert_engine_to_cache(hash_val, cuda_engine)
 
-        if self.compilation_settings.use_python_runtime:
-            return TRTInterpreterResult(
-                cuda_engine,
-                self._input_names,
-                self._output_names,
-                self.weight_name_map,
-                self.ctx.requires_output_allocator,
-            )
-        else:
-            serialized_engine = cuda_engine.serialize()
-            _LOGGER.info(f"TRT Engine uses: {serialized_engine.nbytes} bytes of Memory")
-
-            with io.BytesIO() as engine_bytes:
-                engine_bytes.write(serialized_engine)
-                engine_str = engine_bytes.getvalue()
-
-            return TRTInterpreterResult(
-                engine_str,
-                self._input_names,
-                self._output_names,
-                self.weight_name_map,
-                self.ctx.requires_output_allocator,
-            )
+        return TRTInterpreterResult(
+            cuda_engine,
+            self._input_names,
+            self._output_names,
+            self.weight_name_map,
+            self.ctx.requires_output_allocator,
+        )
 
     def run_node(self, n: torch.fx.Node) -> torch.fx.Node:
         self._cur_node_name = get_node_name(n)
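
With this change run() always returns the live trt.ICudaEngine in TRTInterpreterResult.engine; the serialize-to-bytes branch moves into the runtime modules changed below. A minimal sketch of what a consumer sees, assuming a hypothetical interpreter_result obtained from TRTInterpreter.run():

    import tensorrt as trt

    # `interpreter_result` is a hypothetical TRTInterpreterResult returned by
    # TRTInterpreter.run(); it now always carries the live engine object.
    assert isinstance(interpreter_result.engine, trt.ICudaEngine)  # never bytes anymore
    print(interpreter_result.input_names, interpreter_result.output_names)
    # Serialization now happens later, inside TorchTensorRTModule.__init__ (see that diff below).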

py/torch_tensorrt/dynamo/conversion/_conversion.py

Lines changed: 4 additions & 10 deletions
@@ -89,18 +89,12 @@ def convert_module(
         module, inputs, settings, engine_cache=engine_cache
     )
 
+    rt_cls = PythonTorchTensorRTModule
+
     if ENABLED_FEATURES.torch_tensorrt_runtime and not settings.use_python_runtime:
         from torch_tensorrt.dynamo.runtime import TorchTensorRTModule
 
-        return TorchTensorRTModule(
-            serialized_engine=interpreter_result.engine,
-            input_binding_names=list(interpreter_result.input_names),
-            output_binding_names=list(interpreter_result.output_names),
-            name=name,
-            settings=settings,
-            weight_name_map=interpreter_result.weight_name_map,
-            requires_output_allocator=interpreter_result.requires_output_allocator,
-        )
+        rt_cls = TorchTensorRTModule
 
     elif (
         not ENABLED_FEATURES.torch_tensorrt_runtime and not settings.use_python_runtime
@@ -109,7 +103,7 @@ def convert_module(
             "Since Torch-TensorRT runtime is not available, using Python Runtime, some features may not be available"
         )
 
-    return PythonTorchTensorRTModule(
+    return rt_cls(
         cuda_engine=interpreter_result.engine,
         input_binding_names=list(interpreter_result.input_names),
         output_binding_names=list(interpreter_result.output_names),
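
The duplicated return blocks collapse into a single call site: pick the runtime class first, then construct it once with the shared keyword arguments. A minimal sketch of the selection pattern, using hypothetical flags in place of ENABLED_FEATURES and settings (the PythonTorchTensorRTModule import path is an assumption):

    # Hypothetical stand-ins for ENABLED_FEATURES.torch_tensorrt_runtime and
    # settings.use_python_runtime in the real convert_module().
    torch_tensorrt_runtime_available = True
    use_python_runtime = False

    from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule

    rt_cls = PythonTorchTensorRTModule  # default: Python runtime
    if torch_tensorrt_runtime_available and not use_python_runtime:
        from torch_tensorrt.dynamo.runtime import TorchTensorRTModule

        rt_cls = TorchTensorRTModule  # C++ runtime when it is available

    # Both classes now take the engine via the same cuda_engine keyword, so one call suffices:
    # module = rt_cls(cuda_engine=interpreter_result.engine, ...)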

py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py

Lines changed: 29 additions & 4 deletions
@@ -15,6 +15,7 @@
 from torch_tensorrt.dynamo.debug._DebuggerConfig import DebuggerConfig
 from torch_tensorrt.dynamo.debug._supports_debugger import cls_supports_debugger
 from torch_tensorrt.dynamo.utils import DYNAMIC_DIM
+from torch_tensorrt.logging import TRT_LOGGER
 from torch_tensorrt.runtime._utils import (
     _is_switch_required,
     _select_rt_device,
@@ -123,6 +124,7 @@ class PythonTorchTensorRTModule(Module):  # type: ignore[misc]
     def __init__(
         self,
         cuda_engine: trt.ICudaEngine = None,
+        serialized_engine: Optional[bytes] = None,
         input_binding_names: Optional[List[str]] = None,
         output_binding_names: Optional[List[str]] = None,
         *,
@@ -181,7 +183,19 @@ def __init__(
         # Unused currently - to be used by Dynamic Shape support implementation
         self.memory_pool = None
 
-        self.engine = cuda_engine
+        if cuda_engine:
+            assert isinstance(
+                cuda_engine, trt.ICudaEngine
+            ), "Cuda engine must be a trt.ICudaEngine object"
+            self.engine = cuda_engine
+        elif serialized_engine:
+            assert isinstance(
+                serialized_engine, bytes
+            ), "Serialized engine must be a bytes object"
+            self.engine = serialized_engine
+        else:
+            raise ValueError("Serialized engine or cuda engine must be provided")
+
         self.input_names = (
             input_binding_names if input_binding_names is not None else []
         )
@@ -217,7 +231,7 @@ def __init__(
         self.output_allocator: Optional[DynamicOutputAllocator] = None
         self.use_output_allocator_outputs = False
 
-        if self.engine is not None and not self.settings.lazy_engine_init:
+        if self.engine and not self.settings.lazy_engine_init:
             self.setup_engine()
 
     def get_streamable_device_memory_budget(self) -> Any:
@@ -258,6 +272,17 @@ def set_default_device_memory_budget(self) -> int:
         return self._set_device_memory_budget(budget_bytes)
 
     def setup_engine(self) -> None:
+
+        if isinstance(self.engine, trt.ICudaEngine):
+            pass
+        elif isinstance(self.engine, bytes):
+            runtime = trt.Runtime(TRT_LOGGER)
+            self.engine = runtime.deserialize_cuda_engine(self.engine)
+        else:
+            raise ValueError(
+                "Expected engine as trt.ICudaEngine or serialized engine as bytes"
+            )
+
         assert (
             self.target_platform == Platform.current_platform()
         ), f"TensorRT engine was not built to target current platform (target: {self.target_platform}, current: {Platform.current_platform()})"
@@ -298,7 +323,7 @@ def _check_initialized(self) -> None:
             raise RuntimeError("PythonTorchTensorRTModule is not initialized.")
 
     def _on_state_dict(self, state_dict: Dict[str, Any], prefix: str, _: Any) -> None:
-        state_dict[prefix + "engine"] = self.serialized_engine
+        state_dict[prefix + "engine"] = self.engine
         state_dict[prefix + "input_names"] = self.input_names
         state_dict[prefix + "output_names"] = self.output_names
         state_dict[prefix + "platform"] = self.target_platform
@@ -313,7 +338,7 @@ def _load_from_state_dict(
         unexpected_keys: Any,
         error_msgs: Any,
     ) -> None:
-        self.serialized_engine = state_dict[prefix + "engine"]
+        self.engine = state_dict[prefix + "engine"]
        self.input_names = state_dict[prefix + "input_names"]
        self.output_names = state_dict[prefix + "output_names"]
        self.target_platform = state_dict[prefix + "platform"]
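
setup_engine() now accepts either form of self.engine: a live engine passes through untouched, bytes are deserialized with a trt.Runtime built on TRT_LOGGER. A minimal sketch of that branch pulled out into a standalone (hypothetical) helper:

    import tensorrt as trt
    from torch_tensorrt.logging import TRT_LOGGER

    def to_cuda_engine(maybe_engine):
        # Hypothetical helper mirroring the branch added to setup_engine():
        # pass a live engine through, deserialize bytes, reject anything else.
        if isinstance(maybe_engine, trt.ICudaEngine):
            return maybe_engine
        elif isinstance(maybe_engine, bytes):
            runtime = trt.Runtime(TRT_LOGGER)
            return runtime.deserialize_cuda_engine(maybe_engine)
        else:
            raise ValueError(
                "Expected engine as trt.ICudaEngine or serialized engine as bytes"
            )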

py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py

Lines changed: 20 additions & 4 deletions
@@ -2,10 +2,12 @@
 
 import base64
 import copy
+import io
 import logging
 import pickle
 from typing import Any, List, Optional, Tuple, Union
 
+import tensorrt as trt
 import torch
 from torch_tensorrt._Device import Device
 from torch_tensorrt._enums import Platform
@@ -76,6 +78,7 @@ class TorchTensorRTModule(torch.nn.Module):  # type: ignore[misc]
 
     def __init__(
         self,
+        cuda_engine: Optional[trt.ICudaEngine | bytes] = None,
         serialized_engine: Optional[bytes] = None,
         input_binding_names: Optional[List[str]] = None,
         output_binding_names: Optional[List[str]] = None,
@@ -123,8 +126,22 @@ def __init__(
         """
         super(TorchTensorRTModule, self).__init__()
 
-        if not isinstance(serialized_engine, bytearray):
-            ValueError("Expected serialized engine as bytearray")
+        if serialized_engine:
+            assert isinstance(
+                serialized_engine, bytes
+            ), "Serialized engine must be a bytes object"
+            self.serialized_engine = serialized_engine
+
+        elif cuda_engine:
+            assert isinstance(
+                cuda_engine, trt.ICudaEngine
+            ), "Cuda engine must be a trt.ICudaEngine object"
+            serialized_engine = cuda_engine.serialize()
+            with io.BytesIO() as engine_bytes:
+                engine_bytes.write(serialized_engine)  # type: ignore
+                self.serialized_engine = engine_bytes.getvalue()
+        else:
+            raise ValueError("Serialized engine or cuda engine must be provided")
 
         self.input_binding_names = (
             input_binding_names if input_binding_names is not None else []
@@ -136,12 +153,11 @@ def __init__(
         self.hardware_compatible = settings.hardware_compatible
         self.settings = copy.deepcopy(settings)
         self.weight_name_map = weight_name_map
-        self.serialized_engine = serialized_engine
         self.engine = None
         self.requires_output_allocator = requires_output_allocator
 
         if (
-            serialized_engine
+            self.serialized_engine
            and not self.settings.lazy_engine_init
            and not self.settings.enable_cross_compile_for_windows
        ):
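
Both construction paths end in the same place: self.serialized_engine always holds bytes by the time __init__ finishes. A minimal sketch of the two equivalent calls (hypothetical placeholder values; the import path matches the one used in _conversion.py above, and other keyword arguments are left at their defaults):

    from torch_tensorrt.dynamo.runtime import TorchTensorRTModule

    # `built_engine` (a trt.ICudaEngine) and `plan_bytes` (bytes) are hypothetical
    # placeholders for engines produced elsewhere in the compile flow.
    module_from_engine = TorchTensorRTModule(
        cuda_engine=built_engine,  # serialized internally via engine.serialize() + io.BytesIO
        input_binding_names=["input_0"],
        output_binding_names=["output_0"],
    )
    module_from_bytes = TorchTensorRTModule(
        serialized_engine=plan_bytes,  # checked first, so it wins if both are supplied
        input_binding_names=["input_0"],
        output_binding_names=["output_0"],
    )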
