Skip to content

Commit 33ca588

Browse files
committed
Added testcases and try catch
1 parent 1e2e669 commit 33ca588

File tree

4 files changed

+70
-8
lines changed

4 files changed

+70
-8
lines changed

py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -864,7 +864,9 @@ def preserve_module_specs(
864864

865865
# Here we delete the frozen parameters from the graph module. Note this does not affect the submodules. We are going to delete the frozen parameters from the submodules in the convert_module function.
866866
# This is done to release CPU memory.
867-
[delattr(gm, attr) for attr in dir(gm) if attr.startswith("_frozen_param")]
867+
for attr in dir(gm):
868+
if attr.startswith("_frozen_param"):
869+
delattr(gm, attr)
868870
for name, _ in partitioned_module.named_children():
869871
submodule = getattr(partitioned_module, name)
870872
# filter on the GraphModule
@@ -1238,7 +1240,7 @@ def convert_exported_program_to_serialized_trt_engine(
12381240

12391241
# Prepare torch_trt inputs
12401242
trt_arg_inputs: Sequence[Input] = prepare_inputs(arg_inputs)
1241-
trt_kwarg_inputs: Optional[dict[Any, Any]] = prepare_inputs(kwarg_inputs)
1243+
trt_kwarg_inputs: Optional[dict[str, Any]] = prepare_inputs(kwarg_inputs)
12421244
device = to_torch_tensorrt_device(device)
12431245
enabled_precisions = {dtype._from(p) for p in enabled_precisions}
12441246

py/torch_tensorrt/dynamo/conversion/_conversion.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from torch_tensorrt.dynamo.utils import (
1818
get_cpu_memory_usage,
1919
get_output_dtypes,
20-
trim_memory,
20+
release_memory,
2121
)
2222

2323
logger = logging.getLogger(__name__)
@@ -108,8 +108,10 @@ def convert_module(
108108
)
109109

110110
# Delete the frozen parameters from the module to release CPU memory
111-
[delattr(module, attr) for attr in dir(module) if attr.startswith("_frozen_param")]
112-
trim_memory()
111+
for attr in dir(module):
112+
if attr.startswith("_frozen_param"):
113+
delattr(module, attr)
114+
release_memory()
113115
logger.debug(
114116
f"CPU memory usage after clearing frozen parameters and building memory: {get_cpu_memory_usage()} MB"
115117
)

py/torch_tensorrt/dynamo/utils.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import ctypes
44
import gc
55
import logging
6+
import platform
67
import warnings
78
from dataclasses import fields, replace
89
from enum import Enum
@@ -866,6 +867,17 @@ def get_cpu_memory_usage() -> Any:
866867
return psutil.Process().memory_info().rss / 1024 / 1024
867868

868869

869-
def trim_memory() -> Any:
870-
libc = ctypes.CDLL("libc.so.6")
871-
return libc.malloc_trim(0)
870+
def release_memory() -> None:
    """Best-effort release of cached GPU memory and unused CPU heap memory.

    When CUDA is available, the device is synchronized, the CUDA caching
    allocator and CUDA IPC memory are flushed, and the device is synchronized
    again. On Linux, glibc's ``malloc_trim(0)`` is then invoked so free heap
    pages can be handed back to the operating system.

    This function never raises because of the CPU trim step: a missing libc or
    symbol is logged as a warning (with the cause) and otherwise ignored.
    """
    if torch.cuda.is_available():
        # Synchronize first so pending kernels have finished with their
        # buffers before the caches are emptied.
        torch.cuda.synchronize()
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
        torch.cuda.synchronize()

    # malloc_trim is a glibc extension; there is nothing to do elsewhere.
    if platform.system() != "Linux":
        return

    try:
        libc = ctypes.CDLL("libc.so.6")
        released = libc.malloc_trim(0)
    except Exception as e:
        # Deliberately broad: memory trimming is an optimization and must not
        # break compilation (e.g. non-glibc libc without malloc_trim).
        logger.warning("Failed to release CPU memory: %s", e)
        return

    # malloc_trim returns 1 when memory was actually returned to the system
    # and 0 when there was nothing it could release; the latter is not an
    # error, so it is only reported at debug level.
    if released != 1:
        logger.debug("malloc_trim(0) did not release any CPU memory.")

tests/py/dynamo/models/test_models.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,52 @@ def test_resnet18(ir):
5454
torch._dynamo.reset()
5555

5656

57+
def compile_one(idx: int, ir: str):
    """Compile ResNet18 with Torch-TensorRT and check its output.

    Intended as the target of a worker process: builds a pretrained ResNet18
    on CUDA, compiles it for a batch of ``idx + 1`` random images using the
    given ``ir``, and asserts that the cosine similarity between the eager
    and compiled outputs exceeds ``COSINE_THRESHOLD``.
    """
    batch_shape = (idx + 1, 3, 224, 224)
    sample = torch.randn(batch_shape).to("cuda")
    model = models.resnet18(pretrained=True).eval().to("cuda")

    input_spec = torchtrt.Input(
        sample.shape, dtype=torch.float, format=torch.contiguous_format
    )

    # Engine caching is disabled so that each call performs its own build.
    trt_mod = torchtrt.compile(
        model,
        inputs=[input_spec],
        device=torchtrt.Device("cuda:0"),
        enabled_precisions={torch.float},
        ir=ir,
        pass_through_build_failures=True,
        optimization_level=1,
        cache_built_engines=False,
        reuse_cached_engines=False,
    )

    cos_sim = cosine_similarity(model(sample), trt_mod(sample))
    assertions.assertTrue(
        cos_sim > COSINE_THRESHOLD,
        msg=f"In multiprocess compilation test, process {idx} failed: Resnet18 TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
    )
82+
83+
84+
@pytest.mark.unit
@unittest.skipIf(
    not importlib.util.find_spec("torchvision"),
    "torchvision is not installed",
)
def test_resnet18_multiprocess(ir):
    """Compile ResNet18 concurrently in three spawned processes.

    Each worker runs ``compile_one`` with a different batch size. The test
    fails if any worker exits with a non-zero code, which is how an assertion
    failure or crash inside a child process becomes visible to the parent.
    """
    import torch.multiprocessing as mp

    # Use a local "spawn" context instead of set_start_method(force=True) so
    # the start method of the rest of the test session is left untouched.
    ctx = mp.get_context("spawn")
    procs = [ctx.Process(target=compile_one, args=(i, ir)) for i in range(3)]
    for p in procs:
        p.start()
    for p in procs:
        p.join()
    torch._dynamo.reset()

    # Exceptions raised in a child do not propagate to the parent; the exit
    # code is the only signal, so it must be checked explicitly.
    exit_codes = [p.exitcode for p in procs]
    assertions.assertEqual(
        exit_codes,
        [0] * len(procs),
        msg=f"Multiprocess compilation failed; child exit codes: {exit_codes}",
    )
101+
102+
57103
@pytest.mark.unit
58104
@unittest.skipIf(
59105
not importlib.util.find_spec("torchvision"),

0 commit comments

Comments
 (0)