Skip to content

Commit 35d5861

Browse files
committed
Reduced memory usage when use_python_runtime=True by passing the built ICudaEngine directly via the new API instead of a serialized copy
1 parent 711446c commit 35d5861

File tree

3 files changed

+62
-35
lines changed

3 files changed

+62
-35
lines changed

py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py

Lines changed: 48 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ class UnsupportedOperatorException(RuntimeError):
6565

6666

6767
class TRTInterpreterResult(NamedTuple):
68-
serialized_engine: bytes
68+
engine: trt.ICudaEngine | bytes
6969
input_names: Sequence[str]
7070
output_names: Sequence[str]
7171
weight_name_map: Optional[dict[Any, Any]]
@@ -731,6 +731,10 @@ def run(
731731
if interpreter_result is not None: # hit the cache
732732
return interpreter_result # type: ignore[no-any-return]
733733

734+
import psutil
735+
736+
print(psutil.Process().memory_info().rss / 1024 / 1024, "MB")
737+
# breakpoint()
734738
self._construct_trt_network_def()
735739

736740
if not self.compilation_settings.immutable_weights:
@@ -749,41 +753,62 @@ def run(
749753
self._create_timing_cache(
750754
builder_config, self.compilation_settings.timing_cache_path
751755
)
752-
serialized_engine = self.builder.build_serialized_network(
756+
import psutil
757+
758+
print(psutil.Process().memory_info().rss / 1024 / 1024, "MB")
759+
# breakpoint()
760+
761+
cuda_engine = self.builder.build_engine_with_config(
753762
self.ctx.net, builder_config
754763
)
755-
assert serialized_engine
756764

757765
_LOGGER.info(
758766
f"Build TRT engine elapsed time: {datetime.now() - build_engine_start_time}"
759767
)
760-
_LOGGER.info(f"TRT Engine uses: {serialized_engine.nbytes} bytes of Memory")
761-
762768
self.ctx.clear_cpu_weights_reference_holder()
763769

764770
self._save_timing_cache(
765771
builder_config, self.compilation_settings.timing_cache_path
766772
)
767773

768774
# Engine caching only for refittable engines
769-
if (
770-
not self.compilation_settings.immutable_weights
771-
and self.compilation_settings.cache_built_engines
772-
and self.engine_cache is not None
773-
):
774-
self._insert_engine_to_cache(hash_val, serialized_engine)
775-
776-
with io.BytesIO() as engine_bytes:
777-
engine_bytes.write(serialized_engine)
778-
engine_str = engine_bytes.getvalue()
779-
780-
return TRTInterpreterResult(
781-
engine_str,
782-
self._input_names,
783-
self._output_names,
784-
self.weight_name_map,
785-
self.ctx.requires_output_allocator,
786-
)
775+
# if (
776+
# not self.compilation_settings.immutable_weights
777+
# and self.compilation_settings.cache_built_engines
778+
# and self.engine_cache is not None
779+
# ):
780+
# self._insert_engine_to_cache(hash_val, serialized_engine)
781+
782+
print("After build_engine_with_config")
783+
print(psutil.Process().memory_info().rss / 1024 / 1024, "MB")
784+
# breakpoint()
785+
assert cuda_engine
786+
if self.compilation_settings.use_python_runtime:
787+
return TRTInterpreterResult(
788+
cuda_engine,
789+
self._input_names,
790+
self._output_names,
791+
self.weight_name_map,
792+
self.ctx.requires_output_allocator,
793+
)
794+
else:
795+
print(psutil.Process().memory_info().rss / 1024 / 1024, "MB")
796+
# breakpoint()
797+
serialized_engine = cuda_engine.serialize()
798+
_LOGGER.info(f"TRT Engine uses: {serialized_engine.nbytes} bytes of Memory")
799+
800+
with io.BytesIO() as engine_bytes:
801+
engine_bytes.write(serialized_engine)
802+
engine_str = engine_bytes.getvalue()
803+
print(psutil.Process().memory_info().rss / 1024 / 1024, "MB")
804+
# breakpoint()
805+
return TRTInterpreterResult(
806+
engine_str,
807+
self._input_names,
808+
self._output_names,
809+
self.weight_name_map,
810+
self.ctx.requires_output_allocator,
811+
)
787812

788813
def run_node(self, n: torch.fx.Node) -> torch.fx.Node:
789814
self._cur_node_name = get_node_name(n)

py/torch_tensorrt/dynamo/conversion/_conversion.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -89,12 +89,18 @@ def convert_module(
8989
module, inputs, settings, engine_cache=engine_cache
9090
)
9191

92-
rt_cls = PythonTorchTensorRTModule
93-
9492
if ENABLED_FEATURES.torch_tensorrt_runtime and not settings.use_python_runtime:
9593
from torch_tensorrt.dynamo.runtime import TorchTensorRTModule
9694

97-
rt_cls = TorchTensorRTModule
95+
return TorchTensorRTModule(
96+
serialized_engine=interpreter_result.engine,
97+
input_binding_names=list(interpreter_result.input_names),
98+
output_binding_names=list(interpreter_result.output_names),
99+
name=name,
100+
settings=settings,
101+
weight_name_map=interpreter_result.weight_name_map,
102+
requires_output_allocator=interpreter_result.requires_output_allocator,
103+
)
98104

99105
elif (
100106
not ENABLED_FEATURES.torch_tensorrt_runtime and not settings.use_python_runtime
@@ -103,8 +109,8 @@ def convert_module(
103109
"Since Torch-TensorRT runtime is not available, using Python Runtime, some features may not be available"
104110
)
105111

106-
return rt_cls(
107-
serialized_engine=interpreter_result.serialized_engine,
112+
return PythonTorchTensorRTModule(
113+
cuda_engine=interpreter_result.engine,
108114
input_binding_names=list(interpreter_result.input_names),
109115
output_binding_names=list(interpreter_result.output_names),
110116
name=name,

py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
from torch_tensorrt.dynamo.debug._DebuggerConfig import DebuggerConfig
1616
from torch_tensorrt.dynamo.debug._supports_debugger import cls_supports_debugger
1717
from torch_tensorrt.dynamo.utils import DYNAMIC_DIM
18-
from torch_tensorrt.logging import TRT_LOGGER
1918
from torch_tensorrt.runtime._utils import (
2019
_is_switch_required,
2120
_select_rt_device,
@@ -123,7 +122,7 @@ class PythonTorchTensorRTModule(Module): # type: ignore[misc]
123122

124123
def __init__(
125124
self,
126-
serialized_engine: Optional[bytes] = None,
125+
cuda_engine: trt.ICudaEngine = None,
127126
input_binding_names: Optional[List[str]] = None,
128127
output_binding_names: Optional[List[str]] = None,
129128
*,
@@ -182,7 +181,7 @@ def __init__(
182181
# Unused currently - to be used by Dynamic Shape support implementation
183182
self.memory_pool = None
184183

185-
self.serialized_engine = serialized_engine
184+
self.engine = cuda_engine
186185
self.input_names = (
187186
input_binding_names if input_binding_names is not None else []
188187
)
@@ -204,7 +203,6 @@ def __init__(
204203
else False
205204
)
206205
self.settings = settings
207-
self.engine = None
208206
self.weight_name_map = weight_name_map
209207
self.target_platform = Platform.current_platform()
210208
self.runtime_states = TorchTRTRuntimeStates(
@@ -219,7 +217,7 @@ def __init__(
219217
self.output_allocator: Optional[DynamicOutputAllocator] = None
220218
self.use_output_allocator_outputs = False
221219

222-
if self.serialized_engine is not None and not self.settings.lazy_engine_init:
220+
if self.engine is not None and not self.settings.lazy_engine_init:
223221
self.setup_engine()
224222

225223
def get_streamable_device_memory_budget(self) -> Any:
@@ -265,8 +263,6 @@ def setup_engine(self) -> None:
265263
), f"TensorRT engine was not built to target current platform (target: {self.target_platform}, current: {Platform.current_platform()})"
266264

267265
self.initialized = True
268-
runtime = trt.Runtime(TRT_LOGGER)
269-
self.engine = runtime.deserialize_cuda_engine(self.serialized_engine)
270266
if self.settings.enable_weight_streaming:
271267
self.set_default_device_memory_budget()
272268
self.context = self.engine.create_execution_context()

0 commit comments

Comments (0)