Skip to content

Commit 9e7ca5d

Browse files
authored
Cpu memory optimization (#3602)
1 parent c286767 commit 9e7ca5d

File tree

6 files changed

+62
-30
lines changed

6 files changed

+62
-30
lines changed

py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -693,6 +693,7 @@ def compile(
693693

694694
# Move the weights in the state_dict to CPU
695695
if offload_module_to_cpu:
696+
deallocate_module(gm, delete_module=False)
696697
deallocate_module(exported_program.module(), delete_module=False)
697698
logger.info(
698699
"The PyTorch model was moved to the CPU to allocate all GPU memory to TensorRT. To retain the model on the GPU, set offload_module_to_cpu=False"

py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py

Lines changed: 9 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -65,7 +65,7 @@ class UnsupportedOperatorException(RuntimeError):
6565

6666

6767
class TRTInterpreterResult(NamedTuple):
68-
serialized_engine: bytes
68+
engine: trt.ICudaEngine
6969
input_names: Sequence[str]
7070
output_names: Sequence[str]
7171
weight_name_map: Optional[dict[Any, Any]]
@@ -512,8 +512,7 @@ def _save_weight_mapping(self) -> None:
512512
_LOGGER.info("Building weight name mapping...")
513513
# Stage 1: Name mapping
514514
torch_device = to_torch_device(self.compilation_settings.device)
515-
self.module.to(torch_device)
516-
sd = self.module.state_dict()
515+
sd = {k: v.to(torch_device) for k, v in self.module.state_dict().items()}
517516
weight_name_map: dict[str, Any] = {}
518517
weight_refit_map = self.ctx.weight_refit_map
519518
constant_mapping = {k: v for k, v in weight_refit_map.items() if v.size == 1}
@@ -592,13 +591,11 @@ def _save_weight_mapping(self) -> None:
592591
torch.cuda.empty_cache()
593592

594593
@needs_refit # type: ignore[misc]
595-
def _insert_engine_to_cache(self, hash_val: str, serialized_engine: bytes) -> None:
594+
def _insert_engine_to_cache(self, hash_val: str, engine: bytes) -> None:
595+
serialized_engine = engine.serialize()
596596
# TODO: @Evan is waiting for TRT's feature to cache the weight-stripped engine
597597
# if not self.compilation_settings.strip_engine_weights:
598598
# # set EXCLUDE_WEIGHTS flag to strip weights
599-
# runtime = trt.Runtime(TRT_LOGGER)
600-
# engine = runtime.deserialize_cuda_engine(serialized_engine)
601-
602599
# serialization_config = engine.create_serialization_config()
603600
# serialization_config.set_flag(trt.SerializationFlag.EXCLUDE_WEIGHTS)
604601
# serialized_engine = engine.serialize_with_config(
@@ -750,16 +747,15 @@ def run(
750747
self._create_timing_cache(
751748
builder_config, self.compilation_settings.timing_cache_path
752749
)
753-
serialized_engine = self.builder.build_serialized_network(
750+
751+
cuda_engine = self.builder.build_engine_with_config(
754752
self.ctx.net, builder_config
755753
)
756-
assert serialized_engine
754+
assert cuda_engine
757755

758756
_LOGGER.info(
759757
f"Build TRT engine elapsed time: {datetime.now() - build_engine_start_time}"
760758
)
761-
_LOGGER.info(f"TRT Engine uses: {serialized_engine.nbytes} bytes of Memory")
762-
763759
self.ctx.clear_cpu_weights_reference_holder()
764760

765761
self._save_timing_cache(
@@ -772,14 +768,10 @@ def run(
772768
and self.compilation_settings.cache_built_engines
773769
and self.engine_cache is not None
774770
):
775-
self._insert_engine_to_cache(hash_val, serialized_engine)
776-
777-
with io.BytesIO() as engine_bytes:
778-
engine_bytes.write(serialized_engine)
779-
engine_str = engine_bytes.getvalue()
771+
self._insert_engine_to_cache(hash_val, cuda_engine)
780772

781773
return TRTInterpreterResult(
782-
engine_str,
774+
cuda_engine,
783775
self._input_names,
784776
self._output_names,
785777
self.weight_name_map,

py/torch_tensorrt/dynamo/conversion/_conversion.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -104,7 +104,7 @@ def convert_module(
104104
)
105105

106106
return rt_cls(
107-
serialized_engine=interpreter_result.serialized_engine,
107+
cuda_engine=interpreter_result.engine,
108108
input_binding_names=list(interpreter_result.input_names),
109109
output_binding_names=list(interpreter_result.output_names),
110110
name=name,

py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -37,7 +37,9 @@ def constant_fold(
3737
# For TRT INetwork construction the constants are moved to CPU in get_attr call.
3838
for node, constant in cf.node_replacements.items():
3939
replace_node_with_constant(
40-
gm, node, torch.nn.Parameter(constant, requires_grad=False)
40+
gm,
41+
node,
42+
torch.nn.Parameter(constant.cpu().contiguous(), requires_grad=False),
4143
)
4244

4345
erased_params = []

py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py

Lines changed: 28 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -123,6 +123,7 @@ class PythonTorchTensorRTModule(Module): # type: ignore[misc]
123123

124124
def __init__(
125125
self,
126+
cuda_engine: trt.ICudaEngine = None,
126127
serialized_engine: Optional[bytes] = None,
127128
input_binding_names: Optional[List[str]] = None,
128129
output_binding_names: Optional[List[str]] = None,
@@ -182,7 +183,19 @@ def __init__(
182183
# Unused currently - to be used by Dynamic Shape support implementation
183184
self.memory_pool = None
184185

185-
self.serialized_engine = serialized_engine
186+
if cuda_engine:
187+
assert isinstance(
188+
cuda_engine, trt.ICudaEngine
189+
), "Cuda engine must be a trt.ICudaEngine object"
190+
self.engine = cuda_engine
191+
elif serialized_engine:
192+
assert isinstance(
193+
serialized_engine, bytes
194+
), "Serialized engine must be a bytes object"
195+
self.engine = serialized_engine
196+
else:
197+
raise ValueError("Serialized engine or cuda engine must be provided")
198+
186199
self.input_names = (
187200
input_binding_names if input_binding_names is not None else []
188201
)
@@ -204,7 +217,6 @@ def __init__(
204217
else False
205218
)
206219
self.settings = settings
207-
self.engine = None
208220
self.weight_name_map = weight_name_map
209221
self.target_platform = Platform.current_platform()
210222
self.runtime_states = TorchTRTRuntimeStates(
@@ -219,7 +231,7 @@ def __init__(
219231
self.output_allocator: Optional[DynamicOutputAllocator] = None
220232
self.use_output_allocator_outputs = False
221233

222-
if self.serialized_engine is not None and not self.settings.lazy_engine_init:
234+
if self.engine and not self.settings.lazy_engine_init:
223235
self.setup_engine()
224236

225237
def get_streamable_device_memory_budget(self) -> Any:
@@ -260,13 +272,22 @@ def set_default_device_memory_budget(self) -> int:
260272
return self._set_device_memory_budget(budget_bytes)
261273

262274
def setup_engine(self) -> None:
275+
276+
if isinstance(self.engine, trt.ICudaEngine):
277+
pass
278+
elif isinstance(self.engine, bytes):
279+
runtime = trt.Runtime(TRT_LOGGER)
280+
self.engine = runtime.deserialize_cuda_engine(self.engine)
281+
else:
282+
raise ValueError(
283+
"Expected engine as trt.ICudaEngine or serialized engine as bytes"
284+
)
285+
263286
assert (
264287
self.target_platform == Platform.current_platform()
265288
), f"TensorRT engine was not built to target current platform (target: {self.target_platform}, current: {Platform.current_platform()})"
266289

267290
self.initialized = True
268-
runtime = trt.Runtime(TRT_LOGGER)
269-
self.engine = runtime.deserialize_cuda_engine(self.serialized_engine)
270291
if self.settings.enable_weight_streaming:
271292
self.set_default_device_memory_budget()
272293
self.context = self.engine.create_execution_context()
@@ -302,7 +323,7 @@ def _check_initialized(self) -> None:
302323
raise RuntimeError("PythonTorchTensorRTModule is not initialized.")
303324

304325
def _on_state_dict(self, state_dict: Dict[str, Any], prefix: str, _: Any) -> None:
305-
state_dict[prefix + "engine"] = self.serialized_engine
326+
state_dict[prefix + "engine"] = self.engine
306327
state_dict[prefix + "input_names"] = self.input_names
307328
state_dict[prefix + "output_names"] = self.output_names
308329
state_dict[prefix + "platform"] = self.target_platform
@@ -317,7 +338,7 @@ def _load_from_state_dict(
317338
unexpected_keys: Any,
318339
error_msgs: Any,
319340
) -> None:
320-
self.serialized_engine = state_dict[prefix + "engine"]
341+
self.engine = state_dict[prefix + "engine"]
321342
self.input_names = state_dict[prefix + "input_names"]
322343
self.output_names = state_dict[prefix + "output_names"]
323344
self.target_platform = state_dict[prefix + "platform"]

py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py

Lines changed: 20 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -2,10 +2,12 @@
22

33
import base64
44
import copy
5+
import io
56
import logging
67
import pickle
78
from typing import Any, List, Optional, Tuple, Union
89

10+
import tensorrt as trt
911
import torch
1012
from torch_tensorrt._Device import Device
1113
from torch_tensorrt._enums import Platform
@@ -76,6 +78,7 @@ class TorchTensorRTModule(torch.nn.Module): # type: ignore[misc]
7678

7779
def __init__(
7880
self,
81+
cuda_engine: Optional[trt.ICudaEngine | bytes] = None,
7982
serialized_engine: Optional[bytes] = None,
8083
input_binding_names: Optional[List[str]] = None,
8184
output_binding_names: Optional[List[str]] = None,
@@ -123,8 +126,22 @@ def __init__(
123126
"""
124127
super(TorchTensorRTModule, self).__init__()
125128

126-
if not isinstance(serialized_engine, bytearray):
127-
ValueError("Expected serialized engine as bytearray")
129+
if serialized_engine:
130+
assert isinstance(
131+
serialized_engine, bytes
132+
), "Serialized engine must be a bytes object"
133+
self.serialized_engine = serialized_engine
134+
135+
elif cuda_engine:
136+
assert isinstance(
137+
cuda_engine, trt.ICudaEngine
138+
), "Cuda engine must be a trt.ICudaEngine object"
139+
serialized_engine = cuda_engine.serialize()
140+
with io.BytesIO() as engine_bytes:
141+
engine_bytes.write(serialized_engine) # type: ignore
142+
self.serialized_engine = engine_bytes.getvalue()
143+
else:
144+
raise ValueError("Serialized engine or cuda engine must be provided")
128145

129146
self.input_binding_names = (
130147
input_binding_names if input_binding_names is not None else []
@@ -136,12 +153,11 @@ def __init__(
136153
self.hardware_compatible = settings.hardware_compatible
137154
self.settings = copy.deepcopy(settings)
138155
self.weight_name_map = weight_name_map
139-
self.serialized_engine = serialized_engine
140156
self.engine = None
141157
self.requires_output_allocator = requires_output_allocator
142158

143159
if (
144-
serialized_engine
160+
self.serialized_engine
145161
and not self.settings.lazy_engine_init
146162
and not self.settings.enable_cross_compile_for_windows
147163
):

0 commit comments

Comments
 (0)