13 changes: 11 additions & 2 deletions py/torch_tensorrt/dynamo/_compiler.py
@@ -42,6 +42,7 @@
)
from torch_tensorrt.dynamo.utils import (
deallocate_module,
get_cpu_memory_usage,
get_flat_args_with_check,
get_output_metadata,
parse_graph_io,
@@ -675,7 +676,7 @@ def compile(
"l2_limit_for_tiling": l2_limit_for_tiling,
"offload_module_to_cpu": offload_module_to_cpu,
}

logger.debug(f"CPU memory usage before lowering: {get_cpu_memory_usage()} MB")
settings = CompilationSettings(**compilation_options)
logger.info("Compilation Settings: %s\n", settings)
exported_program = pre_export_lowering(exported_program, settings)
@@ -689,14 +690,17 @@

# Apply lowering on the graph module
gm = post_lowering(gm, settings)
logger.debug(f"CPU memory usage after post_lowering: {get_cpu_memory_usage()} MB")
logger.debug("Lowered Input graph: " + str(gm.graph))

# Move the weights in the state_dict to CPU
if offload_module_to_cpu:
deallocate_module(gm, delete_module=False)
Collaborator: Aren't these the same?

deallocate_module(exported_program.module(), delete_module=False)
logger.info(
"The PyTorch model was moved to the CPU to allocate all GPU memory to TensorRT. To retain the model on the GPU, set offload_module_to_cpu=False"
)
logger.debug(f"CPU memory usage after CPU offload: {get_cpu_memory_usage()} MB")
else:
remaining_memory, total_memory = torch.cuda.mem_get_info()
if remaining_memory < total_memory // 2:
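
For context, a minimal user-facing sketch of the offload path above (illustrative only; it mirrors the compile_spec pattern used in the tests at the end of this diff, with offload_module_to_cpu added):

import torch
import torch_tensorrt as torchtrt
import torchvision.models as models

model = models.resnet18(pretrained=True).eval().to("cuda")
inputs = [torch.randn((1, 3, 224, 224)).to("cuda")]

# offload_module_to_cpu moves the PyTorch weights to host memory during engine
# building so that (nearly) all GPU memory is available to TensorRT.
trt_mod = torchtrt.compile(
    model,
    ir="dynamo",
    inputs=inputs,
    offload_module_to_cpu=True,
)
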
@@ -858,6 +862,11 @@ def preserve_module_specs(
# Iterate over all components that can be accelerated
# Generate the corresponding TRT Module for those

# Delete the frozen parameters from the graph module to release CPU memory.
# Note that this does not affect the submodules; their frozen parameters are
# deleted later, in the convert_module function.
for attr in dir(gm):
Collaborator: Let's make this opt-in, similar to malloc_trim.

Collaborator (author): Shouldn't this be cleared no matter what?

if attr.startswith("_frozen_param"):
delattr(gm, attr)
for name, _ in partitioned_module.named_children():
submodule = getattr(partitioned_module, name)
# filter on the GraphModule
@@ -1231,7 +1240,7 @@ def convert_exported_program_to_serialized_trt_engine(

# Prepare torch_trt inputs
trt_arg_inputs: Sequence[Input] = prepare_inputs(arg_inputs)
trt_kwarg_inputs: Optional[dict[Any, Any]] = prepare_inputs(kwarg_inputs)
trt_kwarg_inputs: Optional[dict[str, Any]] = prepare_inputs(kwarg_inputs)
device = to_torch_tensorrt_device(device)
enabled_precisions = {dtype._from(p) for p in enabled_precisions}

40 changes: 22 additions & 18 deletions py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
@@ -50,7 +50,12 @@
from torch_tensorrt.dynamo.debug._DebuggerConfig import DebuggerConfig
from torch_tensorrt.dynamo.debug._supports_debugger import cls_supports_debugger
from torch_tensorrt.dynamo.observer import Observer
from torch_tensorrt.dynamo.utils import DYNAMIC_DIM, deallocate_module, to_torch_device
from torch_tensorrt.dynamo.utils import (
DYNAMIC_DIM,
deallocate_module,
get_cpu_memory_usage,
to_torch_device,
)
from torch_tensorrt.logging import TRT_LOGGER

_LOGGER: logging.Logger = logging.getLogger(__name__)
@@ -65,7 +70,7 @@ class UnsupportedOperatorException(RuntimeError):


class TRTInterpreterResult(NamedTuple):
serialized_engine: bytes
engine: trt.ICudaEngine | bytes
input_names: Sequence[str]
output_names: Sequence[str]
weight_name_map: Optional[dict[Any, Any]]
@@ -512,8 +517,7 @@ def _save_weight_mapping(self) -> None:
_LOGGER.info("Building weight name mapping...")
# Stage 1: Name mapping
torch_device = to_torch_device(self.compilation_settings.device)
self.module.to(torch_device)
sd = self.module.state_dict()
sd = {k: v.to(torch_device) for k, v in self.module.state_dict().items()}
weight_name_map: dict[str, Any] = {}
weight_refit_map = self.ctx.weight_refit_map
constant_mapping = {k: v for k, v in weight_refit_map.items() if v.size == 1}
@@ -592,13 +596,11 @@ def _save_weight_mapping(self) -> None:
torch.cuda.empty_cache()

@needs_refit # type: ignore[misc]
def _insert_engine_to_cache(self, hash_val: str, serialized_engine: bytes) -> None:
def _insert_engine_to_cache(self, hash_val: str, engine: trt.ICudaEngine) -> None:
serialized_engine = engine.serialize()
# TODO: @Evan is waiting for TRT's feature to cache the weight-stripped engine
# if not self.compilation_settings.strip_engine_weights:
# # set EXCLUDE_WEIGHTS flag to strip weights
# runtime = trt.Runtime(TRT_LOGGER)
# engine = runtime.deserialize_cuda_engine(serialized_engine)

# serialization_config = engine.create_serialization_config()
# serialization_config.set_flag(trt.SerializationFlag.EXCLUDE_WEIGHTS)
# serialized_engine = engine.serialize_with_config(
@@ -733,6 +735,9 @@ def run(
return interpreter_result # type: ignore[no-any-return]

self._construct_trt_network_def()
_LOGGER.info(
f"CPU memory usage after network construction: {get_cpu_memory_usage()} MB"
)

if not self.compilation_settings.immutable_weights:
self._save_weight_mapping()
@@ -750,16 +755,19 @@
self._create_timing_cache(
builder_config, self.compilation_settings.timing_cache_path
)
serialized_engine = self.builder.build_serialized_network(

cuda_engine = self.builder.build_engine_with_config(
self.ctx.net, builder_config
)
assert serialized_engine
assert cuda_engine

_LOGGER.debug(
f"CPU memory usage after engine building: {get_cpu_memory_usage()} MB"
)

_LOGGER.info(
f"Build TRT engine elapsed time: {datetime.now() - build_engine_start_time}"
)
_LOGGER.info(f"TRT Engine uses: {serialized_engine.nbytes} bytes of Memory")

self.ctx.clear_cpu_weights_reference_holder()

self._save_timing_cache(
@@ -772,14 +780,10 @@
and self.compilation_settings.cache_built_engines
and self.engine_cache is not None
):
self._insert_engine_to_cache(hash_val, serialized_engine)

with io.BytesIO() as engine_bytes:
engine_bytes.write(serialized_engine)
engine_str = engine_bytes.getvalue()
self._insert_engine_to_cache(hash_val, cuda_engine)

return TRTInterpreterResult(
engine_str,
cuda_engine,
self._input_names,
self._output_names,
self.weight_name_map,
34 changes: 31 additions & 3 deletions py/torch_tensorrt/dynamo/conversion/_conversion.py
@@ -1,5 +1,6 @@
from __future__ import annotations

import io
import logging
from typing import Any, List, Optional, Sequence

@@ -14,7 +15,11 @@
TRTInterpreterResult,
)
from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule, TorchTensorRTModule
from torch_tensorrt.dynamo.utils import get_output_dtypes
from torch_tensorrt.dynamo.utils import (
get_cpu_memory_usage,
get_output_dtypes,
release_memory,
)

logger = logging.getLogger(__name__)

@@ -29,7 +34,7 @@ def infer_module_output_dtypes(
"""
outputs = [node for node in module.graph.nodes if node.op == "output"]
outputs = outputs[0].args
return get_output_dtypes(outputs, truncate_double) # type: ignore[no-any-return]
return get_output_dtypes(outputs, truncate_double)


def interpret_module_to_result(
@@ -65,6 +70,29 @@ def interpret_module_to_result(
)

interpreter_result = interpreter.run()
# Delete the frozen parameters from the module to release CPU memory
Collaborator: Can we gate this by the same env variable as the malloc_trim?

Collaborator (author): Do you mean something like export low_RAM_mode=1 python flux.py, or a build-time env variable?

del interpreter
for attr in dir(module):
if attr.startswith("_frozen_param"):
delattr(module, attr)
release_memory()
logger.debug(
f"CPU memory usage after clearing frozen parameters and building memory in conversion: {get_cpu_memory_usage()} MB"
)
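
One possible shape for the opt-in discussed in the thread above (a sketch reusing the module and release_memory names from the surrounding function; the environment-variable name is hypothetical, and the diff as shown clears the frozen parameters unconditionally):

import os

# Hypothetical opt-in gate; the variable name is illustrative, not part of this PR.
if os.environ.get("TORCHTRT_RELEASE_CPU_MEMORY", "0") == "1":
    for attr in dir(module):
        if attr.startswith("_frozen_param"):
            delattr(module, attr)
    release_memory()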

serialized_engine = interpreter_result.engine.serialize()
with io.BytesIO() as engine_bytes:
engine_bytes.write(serialized_engine)
serialized_engine = engine_bytes.getvalue()

interpreter_result = TRTInterpreterResult(
engine=serialized_engine,
input_names=interpreter_result.input_names,
output_names=interpreter_result.output_names,
weight_name_map=interpreter_result.weight_name_map,
requires_output_allocator=interpreter_result.requires_output_allocator,
)

return interpreter_result
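
The io.BytesIO round-trip above is what turns TensorRT's IHostMemory buffer into plain Python bytes for the runtime module; a shorter equivalent would be (a sketch, assuming the Python buffer-protocol support that IHostMemory normally exposes):

serialized_engine = bytes(interpreter_result.engine.serialize())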


@@ -104,7 +132,7 @@ def convert_module(
)

return rt_cls(
serialized_engine=interpreter_result.serialized_engine,
serialized_engine=interpreter_result.engine,
input_binding_names=list(interpreter_result.input_names),
output_binding_names=list(interpreter_result.output_names),
name=name,
1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/debug/_Debugger.py
@@ -197,6 +197,7 @@ def get_logging_config(self, log_level: Optional[int] = None) -> dict[str, Any]:
"class": "logging.FileHandler",
"filename": f"{self.cfg.logging_dir}/torch_tensorrt_logging.log",
"formatter": "standard",
"mode": "w", # This will clear the previous content
}
config["loggers"][""]["handlers"].append("file")
return config
4 changes: 3 additions & 1 deletion py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py
@@ -37,7 +37,9 @@ def constant_fold(
# For TRT INetwork construction the constants are moved to CPU in get_attr call.
for node, constant in cf.node_replacements.items():
replace_node_with_constant(
gm, node, torch.nn.Parameter(constant, requires_grad=False)
gm,
node,
torch.nn.Parameter(constant.cpu().contiguous(), requires_grad=False),
)

erased_params = []
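
Per the comment above, constants were previously moved to CPU only when the INetwork get_attr call ran; this change copies each folded constant to host memory immediately, and .contiguous() gives it its own densely packed storage instead of a view into a possibly larger buffer. A minimal illustration (illustrative tensor, requires a CUDA device):

import torch

folded = torch.randn(3, 4, device="cuda").t()  # non-contiguous CUDA tensor
param = torch.nn.Parameter(folded.cpu().contiguous(), requires_grad=False)
assert param.device.type == "cpu" and param.is_contiguous()
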
24 changes: 24 additions & 0 deletions py/torch_tensorrt/dynamo/utils.py
@@ -1,13 +1,16 @@
from __future__ import annotations

import ctypes
import gc
import logging
import platform
import warnings
from dataclasses import fields, replace
from enum import Enum
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union

import numpy as np
import psutil
import sympy
import tensorrt as trt
import torch
@@ -858,3 +861,24 @@ def is_thor() -> bool:
if torch.cuda.get_device_capability() in [(11, 0)]:
return True
return False


def get_cpu_memory_usage() -> Any:
return psutil.Process().memory_info().rss / 1024 / 1024


def release_memory() -> None:
gc.collect()
if torch.cuda.is_available():
torch.cuda.synchronize()
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
torch.cuda.synchronize()

if platform.system() == "Linux":
try:
libc = ctypes.CDLL("libc.so.6")
if libc.malloc_trim(0) != 1:
logger.warning("Failed to release CPU memory.")
except Exception:
logger.warning("Failed to release CPU memory.")
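
For reference, a minimal sketch of how the two helpers compose (illustrative; the actual call sites in this PR are the logger.debug lines in _compiler.py, _TRTInterpreter.py, and _conversion.py). Note that glibc's malloc_trim returns 1 only when memory was actually returned to the OS, so a non-1 return can also simply mean there was nothing to trim.

before_mb = get_cpu_memory_usage()
# ... memory-heavy work: lowering, TRT network construction, engine build ...
release_memory()
logger.debug(f"CPU memory usage: {get_cpu_memory_usage()} MB (was {before_mb} MB)")
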
46 changes: 46 additions & 0 deletions tests/py/dynamo/models/test_models.py
@@ -54,6 +54,52 @@ def test_resnet18(ir):
torch._dynamo.reset()


def compile_one(idx: int, ir: str):
model = models.resnet18(pretrained=True).eval().to("cuda")
input = torch.randn((idx + 1, 3, 224, 224)).to("cuda")

compile_spec = {
"inputs": [
torchtrt.Input(
input.shape, dtype=torch.float, format=torch.contiguous_format
)
],
"device": torchtrt.Device("cuda:0"),
"enabled_precisions": {torch.float},
"ir": ir,
"pass_through_build_failures": True,
"optimization_level": 1,
"cache_built_engines": False,
"reuse_cached_engines": False,
}

trt_mod = torchtrt.compile(model, **compile_spec)
cos_sim = cosine_similarity(model(input), trt_mod(input))
assertions.assertTrue(
cos_sim > COSINE_THRESHOLD,
msg=f"In multiprocess compilation test, process {idx} failed: Resnet18 TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
)


@pytest.mark.unit
@unittest.skipIf(
not importlib.util.find_spec("torchvision"),
"torchvision is not installed",
)
def test_resnet18_multiprocess(ir):
import torch.multiprocessing as mp

mp.set_start_method("spawn", force=True)
procs = []
for i in range(3):
p = mp.Process(target=compile_one, args=(i, ir))
p.start()
procs.append(p)
for p in procs:
p.join()
torch._dynamo.reset()


@pytest.mark.unit
@unittest.skipIf(
not importlib.util.find_spec("torchvision"),