
Commit 1e2e669

Cleared 2x+ dangling memory after compilation

1 parent 6b1950c commit 1e2e669

File tree

5 files changed: +50 −6 lines changed

py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 7 additions & 1 deletion

@@ -42,6 +42,7 @@
 )
 from torch_tensorrt.dynamo.utils import (
     deallocate_module,
+    get_cpu_memory_usage,
     get_flat_args_with_check,
     get_output_metadata,
     parse_graph_io,
@@ -675,7 +676,7 @@ def compile(
        "l2_limit_for_tiling": l2_limit_for_tiling,
        "offload_module_to_cpu": offload_module_to_cpu,
    }
-
+    logger.debug(f"CPU memory usage before lowering: {get_cpu_memory_usage()} MB")
    settings = CompilationSettings(**compilation_options)
    logger.info("Compilation Settings: %s\n", settings)
    exported_program = pre_export_lowering(exported_program, settings)
@@ -689,6 +690,7 @@

    # Apply lowering on the graph module
    gm = post_lowering(gm, settings)
+    logger.debug(f"CPU memory usage after post_lowering: {get_cpu_memory_usage()} MB")
    logger.debug("Lowered Input graph: " + str(gm.graph))

    # Move the weights in the state_dict to CPU
@@ -698,6 +700,7 @@
        logger.info(
            "The PyTorch model was moved to the CPU to allocate all GPU memory to TensorRT. To retain the model on the GPU, set offload_module_to_cpu=False"
        )
+        logger.debug(f"CPU memory usage after CPU offload: {get_cpu_memory_usage()} MB")
    else:
        remaining_memory, total_memory = torch.cuda.mem_get_info()
        if remaining_memory < total_memory // 2:
@@ -859,6 +862,9 @@ def preserve_module_specs(
    # Iterate over all components that can be accelerated
    # Generate the corresponding TRT Module for those

+    # Here we delete the frozen parameters from the graph module. Note this does not affect the submodules. We are going to delete the frozen parameters from the submodules in the convert_module function.
+    # This is done to release CPU memory.
+    [delattr(gm, attr) for attr in dir(gm) if attr.startswith("_frozen_param")]
    for name, _ in partitioned_module.named_children():
        submodule = getattr(partitioned_module, name)
        # filter on the GraphModule
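
Background on the deletion above: constant folding during Dynamo lowering leaves folded constant tensors behind as `_frozen_param*` attributes, and after partitioning the submodules carry their own copies, so the top-level references only pin duplicate CPU memory. A minimal standalone sketch of the pattern (TinyModule and its sizes are illustrative, not Torch-TensorRT code):

import torch

class TinyModule(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        # Stand-ins for constants left behind by constant folding.
        self._frozen_param0 = torch.randn(1024, 1024)  # ~4 MB
        self._frozen_param1 = torch.randn(1024, 1024)  # ~4 MB

gm = TinyModule()
for attr in [a for a in dir(gm) if a.startswith("_frozen_param")]:
    delattr(gm, attr)  # drop the reference so each ~4 MB tensor can be reclaimed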

py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py

Lines changed: 18 additions & 3 deletions

@@ -50,7 +50,12 @@
 from torch_tensorrt.dynamo.debug._DebuggerConfig import DebuggerConfig
 from torch_tensorrt.dynamo.debug._supports_debugger import cls_supports_debugger
 from torch_tensorrt.dynamo.observer import Observer
-from torch_tensorrt.dynamo.utils import DYNAMIC_DIM, deallocate_module, to_torch_device
+from torch_tensorrt.dynamo.utils import (
+    DYNAMIC_DIM,
+    deallocate_module,
+    get_cpu_memory_usage,
+    to_torch_device,
+)
 from torch_tensorrt.logging import TRT_LOGGER

 _LOGGER: logging.Logger = logging.getLogger(__name__)
@@ -729,7 +734,13 @@ def run(
        if interpreter_result is not None:  # hit the cache
            return interpreter_result  # type: ignore[no-any-return]

+        _LOGGER.debug(
+            f"CPU memory usage before network construction: {get_cpu_memory_usage()} MB"
+        )
        self._construct_trt_network_def()
+        _LOGGER.debug(
+            f"CPU memory usage after network construction: {get_cpu_memory_usage()} MB"
+        )

        if not self.compilation_settings.immutable_weights:
            self._save_weight_mapping()
@@ -747,12 +758,16 @@
        self._create_timing_cache(
            builder_config, self.compilation_settings.timing_cache_path
        )
-
+        _LOGGER.debug(
+            f"CPU memory usage before engine building: {get_cpu_memory_usage()} MB"
+        )
        cuda_engine = self.builder.build_engine_with_config(
            self.ctx.net, builder_config
        )
        assert cuda_engine
-
+        _LOGGER.debug(
+            f"CPU memory usage after engine building: {get_cpu_memory_usage()} MB"
+        )
        _LOGGER.info(
            f"Build TRT engine elapsed time: {datetime.now() - build_engine_start_time}"
        )
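
The before/after pairs above bracket each compilation phase with an RSS reading. The same instrumentation could be expressed once as a context manager; a sketch under the same psutil assumption (log_rss is hypothetical, not part of Torch-TensorRT):

import logging
from contextlib import contextmanager
from typing import Iterator

import psutil

_LOGGER = logging.getLogger(__name__)

@contextmanager
def log_rss(phase: str) -> Iterator[None]:
    proc = psutil.Process()
    before = proc.memory_info().rss / 1024 / 1024
    _LOGGER.debug(f"CPU memory usage before {phase}: {before} MB")
    try:
        yield
    finally:
        after = proc.memory_info().rss / 1024 / 1024
        _LOGGER.debug(
            f"CPU memory usage after {phase}: {after} MB (delta {after - before:+.1f} MB)"
        )

# Usage mirroring the diff:
# with log_rss("network construction"):
#     self._construct_trt_network_def()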

py/torch_tensorrt/dynamo/conversion/_conversion.py

Lines changed: 13 additions & 2 deletions

@@ -14,7 +14,11 @@
     TRTInterpreterResult,
 )
 from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule, TorchTensorRTModule
-from torch_tensorrt.dynamo.utils import get_output_dtypes
+from torch_tensorrt.dynamo.utils import (
+    get_cpu_memory_usage,
+    get_output_dtypes,
+    trim_memory,
+)

 logger = logging.getLogger(__name__)

@@ -29,7 +33,7 @@ def infer_module_output_dtypes(
    """
    outputs = [node for node in module.graph.nodes if node.op == "output"]
    outputs = outputs[0].args
-    return get_output_dtypes(outputs, truncate_double)  # type: ignore[no-any-return]
+    return get_output_dtypes(outputs, truncate_double)


def interpret_module_to_result(
@@ -103,6 +107,13 @@ def convert_module(
            "Since Torch-TensorRT runtime is not available, using Python Runtime, some features may not be available"
        )

+    # Delete the frozen parameters from the module to release CPU memory
+    [delattr(module, attr) for attr in dir(module) if attr.startswith("_frozen_param")]
+    trim_memory()
+    logger.debug(
+        f"CPU memory usage after clearing frozen parameters and trimming memory: {get_cpu_memory_usage()} MB"
+    )
+
    return rt_cls(
        cuda_engine=interpreter_result.engine,
        input_binding_names=list(interpreter_result.input_names),
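
Why trim_memory() follows the delattr here: glibc's malloc can keep freed small-block pages in its arenas rather than returning them to the OS, so RSS stays inflated even after Python has released the tensors; malloc_trim(0) hands the free pages back. A Linux/glibc-only sketch of the effect (the exact numbers depend on allocator version and tunables):

import ctypes
import gc

import psutil

def rss_mb() -> float:
    return psutil.Process().memory_info().rss / 1024 / 1024

# ~200 MB of 1 KB chunks: small enough to be served from malloc arenas,
# not from mmap, so freeing them does not automatically shrink RSS.
blobs = [bytes(1024) for _ in range(200_000)]
print(f"allocated:        {rss_mb():.0f} MB")

del blobs
gc.collect()
print(f"after free:       {rss_mb():.0f} MB")  # often still high

ctypes.CDLL("libc.so.6").malloc_trim(0)  # release free arena pages to the OS
print(f"after malloc_trim: {rss_mb():.0f} MB")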

py/torch_tensorrt/dynamo/debug/_Debugger.py

Lines changed: 1 addition & 0 deletions

@@ -197,6 +197,7 @@ def get_logging_config(self, log_level: Optional[int] = None) -> dict[str, Any]:
                "class": "logging.FileHandler",
                "filename": f"{self.cfg.logging_dir}/torch_tensorrt_logging.log",
                "formatter": "standard",
+                "mode": "w",  # This will clear the previous content
            }
            config["loggers"][""]["handlers"].append("file")
        return config
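
How the added "mode": "w" takes effect: dictConfig passes any handler keys beyond class/level/formatter/filters straight to the handler constructor, and logging.FileHandler(..., mode="w") truncates the file instead of appending (the default mode is "a"). A standalone sketch with an illustrative path:

import logging
import logging.config

config = {
    "version": 1,
    "formatters": {"standard": {"format": "%(levelname)s %(message)s"}},
    "handlers": {
        "file": {
            "class": "logging.FileHandler",
            "filename": "/tmp/torch_tensorrt_logging.log",
            "formatter": "standard",
            "mode": "w",  # truncate on startup; "a" would append
        }
    },
    "root": {"handlers": ["file"], "level": "DEBUG"},
}

logging.config.dictConfig(config)
logging.debug("fresh log file on every run")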

py/torch_tensorrt/dynamo/utils.py

Lines changed: 11 additions & 0 deletions

@@ -1,5 +1,6 @@
 from __future__ import annotations

+import ctypes
 import gc
 import logging
 import warnings
@@ -8,6 +9,7 @@
 from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union

 import numpy as np
+import psutil
 import sympy
 import tensorrt as trt
 import torch
@@ -858,3 +860,12 @@ def is_thor() -> bool:
    if torch.cuda.get_device_capability() in [(11, 0)]:
        return True
    return False
+
+
+def get_cpu_memory_usage() -> Any:
+    return psutil.Process().memory_info().rss / 1024 / 1024
+
+
+def trim_memory() -> Any:
+    libc = ctypes.CDLL("libc.so.6")
+    return libc.malloc_trim(0)
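
Note that the two helpers as committed are Linux/glibc-specific: ctypes.CDLL("libc.so.6") raises OSError on macOS or Windows, and malloc_trim is a glibc extension. A defensive variant (a sketch, not the committed code) that tightens the return types and degrades to a no-op on other platforms:

import ctypes

import psutil

def get_cpu_memory_usage() -> float:
    # Resident set size of the current process, in MB.
    return psutil.Process().memory_info().rss / 1024 / 1024

def trim_memory() -> int:
    try:
        libc = ctypes.CDLL("libc.so.6")
    except OSError:
        return 0  # non-glibc platform: nothing to trim
    # malloc_trim(0) returns 1 if any memory was released back to the OS.
    return int(libc.malloc_trim(0))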
