diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py
index 0dc4654db0..ada5cdab19 100644
--- a/py/torch_tensorrt/dynamo/_compiler.py
+++ b/py/torch_tensorrt/dynamo/_compiler.py
@@ -42,6 +42,7 @@
 )
 from torch_tensorrt.dynamo.utils import (
     deallocate_module,
+    get_cpu_memory_usage,
     get_flat_args_with_check,
     get_output_metadata,
     parse_graph_io,
@@ -675,7 +676,7 @@ def compile(
         "l2_limit_for_tiling": l2_limit_for_tiling,
         "offload_module_to_cpu": offload_module_to_cpu,
     }
-
+    logger.debug(f"CPU memory usage before lowering: {get_cpu_memory_usage()} MB")
     settings = CompilationSettings(**compilation_options)
     logger.info("Compilation Settings: %s\n", settings)
     exported_program = pre_export_lowering(exported_program, settings)
@@ -689,14 +690,17 @@ def compile(

     # Apply lowering on the graph module
     gm = post_lowering(gm, settings)
+    logger.debug(f"CPU memory usage after post_lowering: {get_cpu_memory_usage()} MB")
     logger.debug("Lowered Input graph: " + str(gm.graph))

     # Move the weights in the state_dict to CPU
     if offload_module_to_cpu:
+        deallocate_module(gm, delete_module=False)
         deallocate_module(exported_program.module(), delete_module=False)
         logger.info(
             "The PyTorch model was moved to the CPU to allocate all GPU memory to TensorRT. To retain the model on the GPU, set offload_module_to_cpu=False"
         )
+        logger.debug(f"CPU memory usage after CPU offload: {get_cpu_memory_usage()} MB")
     else:
         remaining_memory, total_memory = torch.cuda.mem_get_info()
         if remaining_memory < total_memory // 2:
@@ -858,6 +862,11 @@ def preserve_module_specs(

     # Iterate over all components that can be accelerated
     # Generate the corresponding TRT Module for those
+    # Delete the frozen parameters from the graph module to release CPU memory. This does not
+    # affect the submodules; their frozen parameters are deleted later, in convert_module.
+    for attr in dir(gm):
+        if attr.startswith("_frozen_param"):
+            delattr(gm, attr)
     for name, _ in partitioned_module.named_children():
         submodule = getattr(partitioned_module, name)
         # filter on the GraphModule
@@ -1231,7 +1240,7 @@ def convert_exported_program_to_serialized_trt_engine(

     # Prepare torch_trt inputs
     trt_arg_inputs: Sequence[Input] = prepare_inputs(arg_inputs)
-    trt_kwarg_inputs: Optional[dict[Any, Any]] = prepare_inputs(kwarg_inputs)
+    trt_kwarg_inputs: Optional[dict[str, Any]] = prepare_inputs(kwarg_inputs)

     device = to_torch_tensorrt_device(device)
     enabled_precisions = {dtype._from(p) for p in enabled_precisions}
diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
index 73af09448e..c42f65a7f1 100644
--- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
+++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
@@ -50,7 +50,12 @@
 from torch_tensorrt.dynamo.debug._DebuggerConfig import DebuggerConfig
 from torch_tensorrt.dynamo.debug._supports_debugger import cls_supports_debugger
 from torch_tensorrt.dynamo.observer import Observer
-from torch_tensorrt.dynamo.utils import DYNAMIC_DIM, deallocate_module, to_torch_device
+from torch_tensorrt.dynamo.utils import (
+    DYNAMIC_DIM,
+    deallocate_module,
+    get_cpu_memory_usage,
+    to_torch_device,
+)
 from torch_tensorrt.logging import TRT_LOGGER

 _LOGGER: logging.Logger = logging.getLogger(__name__)
@@ -65,7 +70,7 @@ class UnsupportedOperatorException(RuntimeError):


 class TRTInterpreterResult(NamedTuple):
-    serialized_engine: bytes
+    engine: trt.ICudaEngine | bytes
     input_names: Sequence[str]
     output_names: Sequence[str]
     weight_name_map: Optional[dict[Any, Any]]
@@ -512,8 +517,7 @@ def _save_weight_mapping(self) -> None:
         _LOGGER.info("Building weight name mapping...")
         # Stage 1: Name mapping
         torch_device = to_torch_device(self.compilation_settings.device)
-        self.module.to(torch_device)
-        sd = self.module.state_dict()
+        sd = {k: v.to(torch_device) for k, v in self.module.state_dict().items()}
         weight_name_map: dict[str, Any] = {}
         weight_refit_map = self.ctx.weight_refit_map
         constant_mapping = {k: v for k, v in weight_refit_map.items() if v.size == 1}
@@ -592,13 +596,11 @@ def _save_weight_mapping(self) -> None:
         torch.cuda.empty_cache()

     @needs_refit  # type: ignore[misc]
-    def _insert_engine_to_cache(self, hash_val: str, serialized_engine: bytes) -> None:
+    def _insert_engine_to_cache(self, hash_val: str, engine: trt.ICudaEngine) -> None:
+        serialized_engine = engine.serialize()
         # TODO: @Evan is waiting for TRT's feature to cache the weight-stripped engine
         # if not self.compilation_settings.strip_engine_weights:
         #     # set EXCLUDE_WEIGHTS flag to strip weights
-        #     runtime = trt.Runtime(TRT_LOGGER)
-        #     engine = runtime.deserialize_cuda_engine(serialized_engine)
-        #     serialization_config = engine.create_serialization_config()
         #     serialization_config.set_flag(trt.SerializationFlag.EXCLUDE_WEIGHTS)
         #     serialized_engine = engine.serialize_with_config(
         #         serialization_config
@@ -733,6 +735,9 @@ def run(
                 return interpreter_result  # type: ignore[no-any-return]

         self._construct_trt_network_def()
+        _LOGGER.info(
+            f"CPU memory usage after network construction: {get_cpu_memory_usage()} MB"
+        )

         if not self.compilation_settings.immutable_weights:
             self._save_weight_mapping()
@@ -750,16 +755,19 @@ def run(
         self._create_timing_cache(
             builder_config, self.compilation_settings.timing_cache_path
         )
-        serialized_engine = self.builder.build_serialized_network(
+
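+        # build_engine_with_config returns a live ICudaEngine; serialization to bytes is
+        # deferred to interpret_module_to_result, after build-time CPU memory is released.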
+        cuda_engine = self.builder.build_engine_with_config(
             self.ctx.net, builder_config
         )
-        assert serialized_engine
+        assert cuda_engine
+
+        _LOGGER.debug(
+            f"CPU memory usage after engine building: {get_cpu_memory_usage()} MB"
+        )

         _LOGGER.info(
             f"Build TRT engine elapsed time: {datetime.now() - build_engine_start_time}"
         )
-        _LOGGER.info(f"TRT Engine uses: {serialized_engine.nbytes} bytes of Memory")
-
         self.ctx.clear_cpu_weights_reference_holder()

         self._save_timing_cache(
@@ -772,14 +780,10 @@ def run(
             and self.compilation_settings.cache_built_engines
             and self.engine_cache is not None
         ):
-            self._insert_engine_to_cache(hash_val, serialized_engine)
-
-        with io.BytesIO() as engine_bytes:
-            engine_bytes.write(serialized_engine)
-            engine_str = engine_bytes.getvalue()
+            self._insert_engine_to_cache(hash_val, cuda_engine)

         return TRTInterpreterResult(
-            engine_str,
+            cuda_engine,
             self._input_names,
             self._output_names,
             self.weight_name_map,
diff --git a/py/torch_tensorrt/dynamo/conversion/_conversion.py b/py/torch_tensorrt/dynamo/conversion/_conversion.py
index 35b6c26617..c446e56a99 100644
--- a/py/torch_tensorrt/dynamo/conversion/_conversion.py
+++ b/py/torch_tensorrt/dynamo/conversion/_conversion.py
@@ -1,5 +1,6 @@
 from __future__ import annotations

+import io
 import logging
 from typing import Any, List, Optional, Sequence

@@ -14,7 +15,11 @@
     TRTInterpreterResult,
 )
 from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule, TorchTensorRTModule
-from torch_tensorrt.dynamo.utils import get_output_dtypes
+from torch_tensorrt.dynamo.utils import (
+    get_cpu_memory_usage,
+    get_output_dtypes,
+    release_memory,
+)

 logger = logging.getLogger(__name__)

@@ -29,7 +34,7 @@ def infer_module_output_dtypes(
     """
     outputs = [node for node in module.graph.nodes if node.op == "output"]
     outputs = outputs[0].args
-    return get_output_dtypes(outputs, truncate_double)  # type: ignore[no-any-return]
+    return get_output_dtypes(outputs, truncate_double)


 def interpret_module_to_result(
@@ -65,6 +70,29 @@ def interpret_module_to_result(
     )
     interpreter_result = interpreter.run()
+
+    # Delete the frozen parameters from the module to release CPU memory
+    del interpreter
+    for attr in dir(module):
+        if attr.startswith("_frozen_param"):
+            delattr(module, attr)
+    release_memory()
+    logger.debug(
+        f"CPU memory usage after clearing frozen parameters and build memory in conversion: {get_cpu_memory_usage()} MB"
+    )
+
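+    # Serialize the engine and copy it into an immutable bytes object so downstream
+    # code keeps consuming a serialized engine rather than a live ICudaEngine.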
+    serialized_engine = interpreter_result.engine.serialize()
+    with io.BytesIO() as engine_bytes:
+        engine_bytes.write(serialized_engine)
+        serialized_engine = engine_bytes.getvalue()
+
+    interpreter_result = TRTInterpreterResult(
+        engine=serialized_engine,
+        input_names=interpreter_result.input_names,
+        output_names=interpreter_result.output_names,
+        weight_name_map=interpreter_result.weight_name_map,
+        requires_output_allocator=interpreter_result.requires_output_allocator,
+    )
+
     return interpreter_result
@@ -104,7 +132,7 @@ def convert_module(
     )

     return rt_cls(
-        serialized_engine=interpreter_result.serialized_engine,
+        serialized_engine=interpreter_result.engine,
         input_binding_names=list(interpreter_result.input_names),
         output_binding_names=list(interpreter_result.output_names),
         name=name,
diff --git a/py/torch_tensorrt/dynamo/debug/_Debugger.py b/py/torch_tensorrt/dynamo/debug/_Debugger.py
index ec624ffc5a..3e0ae9ee59 100644
--- a/py/torch_tensorrt/dynamo/debug/_Debugger.py
+++ b/py/torch_tensorrt/dynamo/debug/_Debugger.py
@@ -197,6 +197,7 @@ def get_logging_config(self, log_level: Optional[int] = None) -> dict[str, Any]:
                 "class": "logging.FileHandler",
                 "filename": f"{self.cfg.logging_dir}/torch_tensorrt_logging.log",
                 "formatter": "standard",
+                "mode": "w",  # Overwrite any log content left over from previous runs
             }
             config["loggers"][""]["handlers"].append("file")
         return config
diff --git a/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py b/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py
index 5ba84b09b0..9b821df906 100644
--- a/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py
+++ b/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py
@@ -37,7 +37,9 @@ def constant_fold(
     # For TRT INetwork construction the constants are moved to CPU in get_attr call.
     for node, constant in cf.node_replacements.items():
         replace_node_with_constant(
-            gm, node, torch.nn.Parameter(constant, requires_grad=False)
+            gm,
+            node,
+            torch.nn.Parameter(constant.cpu().contiguous(), requires_grad=False),
         )

     erased_params = []
diff --git a/py/torch_tensorrt/dynamo/utils.py b/py/torch_tensorrt/dynamo/utils.py
index 564250e5ae..6cfa6394ec 100644
--- a/py/torch_tensorrt/dynamo/utils.py
+++ b/py/torch_tensorrt/dynamo/utils.py
@@ -1,13 +1,16 @@
 from __future__ import annotations

+import ctypes
 import gc
 import logging
+import platform
 import warnings
 from dataclasses import fields, replace
 from enum import Enum
 from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union

 import numpy as np
+import psutil
 import sympy
 import tensorrt as trt
 import torch
@@ -858,3 +861,24 @@ def is_thor() -> bool:
     if torch.cuda.get_device_capability() in [(11, 0)]:
         return True
     return False
+
+
+def get_cpu_memory_usage() -> float:
+    return psutil.Process().memory_info().rss / 1024 / 1024
+
+
+def release_memory() -> None:
+    gc.collect()
+    if torch.cuda.is_available():
+        torch.cuda.synchronize()
+        torch.cuda.empty_cache()
+        torch.cuda.ipc_collect()
+        torch.cuda.synchronize()
+
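+    # On Linux, additionally ask glibc to return freed heap pages to the OS;
+    # malloc_trim(0) returns 1 if memory was actually released, 0 otherwise.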
+    if platform.system() == "Linux":
+        try:
+            libc = ctypes.CDLL("libc.so.6")
+            if libc.malloc_trim(0) != 1:
+                logger.warning("Failed to release CPU memory.")
+        except Exception:
+            logger.warning("Failed to release CPU memory.")
diff --git a/tests/py/dynamo/models/test_models.py b/tests/py/dynamo/models/test_models.py
index c52b732c42..13ba856d35 100644
--- a/tests/py/dynamo/models/test_models.py
+++ b/tests/py/dynamo/models/test_models.py
@@ -54,6 +54,52 @@ def test_resnet18(ir):
     torch._dynamo.reset()


+def compile_one(idx: int, ir: str):
+    model = models.resnet18(pretrained=True).eval().to("cuda")
+    input = torch.randn((idx + 1, 3, 224, 224)).to("cuda")
+
+    compile_spec = {
+        "inputs": [
+            torchtrt.Input(
+                input.shape, dtype=torch.float, format=torch.contiguous_format
+            )
+        ],
+        "device": torchtrt.Device("cuda:0"),
+        "enabled_precisions": {torch.float},
+        "ir": ir,
+        "pass_through_build_failures": True,
+        "optimization_level": 1,
+        "cache_built_engines": False,
+        "reuse_cached_engines": False,
+    }
+
+    trt_mod = torchtrt.compile(model, **compile_spec)
+    cos_sim = cosine_similarity(model(input), trt_mod(input))
+    assertions.assertTrue(
+        cos_sim > COSINE_THRESHOLD,
+        msg=f"In multiprocess compilation test, process {idx} failed: Resnet18 TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
+    )
+
+
+@pytest.mark.unit
+@unittest.skipIf(
+    not importlib.util.find_spec("torchvision"),
+    "torchvision is not installed",
+)
+def test_resnet18_multiprocess(ir):
+    import torch.multiprocessing as mp
+
+    mp.set_start_method("spawn", force=True)
+    procs = []
+    for i in range(3):
+        p = mp.Process(target=compile_one, args=(i, ir))
+        p.start()
+        procs.append(p)
+    for p in procs:
+        p.join()
+    torch._dynamo.reset()
+
+
 @pytest.mark.unit
 @unittest.skipIf(
     not importlib.util.find_spec("torchvision"),