CPU memory optimization #3845
base: moe-support
Changes from all commits
2540824
c7f8b12
711446c
35d5861
503f320
6b1950c
1e2e669
33ca588
d99f183
66b40bd
92775f6
File 1:
@@ -42,6 +42,7 @@
 )
 from torch_tensorrt.dynamo.utils import (
     deallocate_module,
+    get_cpu_memory_usage,
     get_flat_args_with_check,
     get_output_metadata,
     parse_graph_io,
@@ -675,7 +676,7 @@ def compile(
         "l2_limit_for_tiling": l2_limit_for_tiling,
         "offload_module_to_cpu": offload_module_to_cpu,
     }

+    logger.debug(f"CPU memory usage before lowering: {get_cpu_memory_usage()} MB")
     settings = CompilationSettings(**compilation_options)
     logger.info("Compilation Settings: %s\n", settings)
     exported_program = pre_export_lowering(exported_program, settings)
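get_cpu_memory_usage is imported from torch_tensorrt.dynamo.utils, but its body is not part of this diff. As a rough sketch of what such a helper usually amounts to (the resident set size of the current process in MB, here via psutil, which is an assumption rather than the PR's actual implementation):

    import os

    import psutil  # assumed dependency; the real helper may use another mechanism


    def get_cpu_memory_usage() -> float:
        """Return the resident set size (RSS) of the current process in MB."""
        return psutil.Process(os.getpid()).memory_info().rss / (1024 * 1024)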
@@ -689,14 +690,17 @@ def compile(

     # Apply lowering on the graph module
     gm = post_lowering(gm, settings)
+    logger.debug(f"CPU memory usage after post_lowering: {get_cpu_memory_usage()} MB")
     logger.debug("Lowered Input graph: " + str(gm.graph))

     # Move the weights in the state_dict to CPU
     if offload_module_to_cpu:
         deallocate_module(gm, delete_module=False)
         deallocate_module(exported_program.module(), delete_module=False)
         logger.info(
             "The PyTorch model was moved to the CPU to allocate all GPU memory to TensorRT. To retain the model on the GPU, set offload_module_to_cpu=False"
         )
+        logger.debug(f"CPU memory usage after CPU offload: {get_cpu_memory_usage()} MB")
     else:
         remaining_memory, total_memory = torch.cuda.mem_get_info()
         if remaining_memory < total_memory // 2:
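For context on how this option is meant to be used: with offload_module_to_cpu=True the source module's weights are moved to host memory during conversion so the GPU can be devoted to the TensorRT build. A minimal usage sketch follows; the toy model and shapes are illustrative, not from the PR:

    import torch
    import torch_tensorrt

    model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.ReLU()).eval().cuda()
    inputs = [torch.randn(8, 64).cuda()]

    exp_program = torch.export.export(model, tuple(inputs))
    trt_module = torch_tensorrt.dynamo.compile(
        exp_program,
        arg_inputs=inputs,
        # Offload the PyTorch weights to host memory so the whole GPU is
        # available to the TensorRT builder; set to False to keep them on the GPU.
        offload_module_to_cpu=True,
    )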
@@ -858,6 +862,11 @@ def preserve_module_specs(
     # Iterate over all components that can be accelerated
     # Generate the corresponding TRT Module for those

+    # Here we delete the frozen parameters from the graph module. Note this does not affect the submodules. We are going to delete the frozen parameters from the submodules in the convert_module function.
+    # This is done to release CPU memory.
+    for attr in dir(gm):
+        if attr.startswith("_frozen_param"):
+            delattr(gm, attr)
     for name, _ in partitioned_module.named_children():
         submodule = getattr(partitioned_module, name)
         # filter on the GraphModule

Review comments on this hunk:
- "Let's make this opt-in, similar to malloc trim."
- "Shouldn't this be cleared no matter what?"
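The first review comment asks for this cleanup to be opt-in, gated the same way as malloc trim. The thread does not settle on a flag name; a sketch of one possible gate follows, where the environment variable TORCHTRT_RELEASE_CPU_MEMORY is hypothetical and gm refers to the same graph module as in the hunk above:

    import os

    # Hypothetical opt-in flag; the discussion above does not fix a name for it.
    if os.environ.get("TORCHTRT_RELEASE_CPU_MEMORY", "0") == "1":
        # Drop the frozen parameters that constant folding attached to the
        # top-level graph module; the per-submodule copies are removed later
        # in convert_module.
        for attr in dir(gm):
            if attr.startswith("_frozen_param"):
                delattr(gm, attr)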
@@ -1231,7 +1240,7 @@ def convert_exported_program_to_serialized_trt_engine(

     # Prepare torch_trt inputs
     trt_arg_inputs: Sequence[Input] = prepare_inputs(arg_inputs)
-    trt_kwarg_inputs: Optional[dict[Any, Any]] = prepare_inputs(kwarg_inputs)
+    trt_kwarg_inputs: Optional[dict[str, Any]] = prepare_inputs(kwarg_inputs)
     device = to_torch_tensorrt_device(device)
     enabled_precisions = {dtype._from(p) for p in enabled_precisions}
File 2:
@@ -1,5 +1,6 @@
 from __future__ import annotations

+import io
 import logging
 from typing import Any, List, Optional, Sequence
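The new io import supports copying the serialized engine out of TensorRT's own buffer, as used later in interpret_module_to_result. The pattern looks roughly like this, where engine stands for a built tensorrt.ICudaEngine; this is an illustration of the idiom, not code from the PR:

    import io

    # serialize() returns an IHostMemory buffer owned by TensorRT.
    host_memory = engine.serialize()

    # Copying it into a plain bytes object lets the engine and its buffers be
    # freed while the serialized plan remains usable later on.
    with io.BytesIO() as engine_bytes:
        engine_bytes.write(host_memory)
        serialized_plan = engine_bytes.getvalue()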
@@ -14,7 +15,11 @@
     TRTInterpreterResult,
 )
 from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule, TorchTensorRTModule
-from torch_tensorrt.dynamo.utils import get_output_dtypes
+from torch_tensorrt.dynamo.utils import (
+    get_cpu_memory_usage,
+    get_output_dtypes,
+    release_memory,
+)

 logger = logging.getLogger(__name__)
@@ -29,7 +34,7 @@ def infer_module_output_dtypes(
     """
     outputs = [node for node in module.graph.nodes if node.op == "output"]
     outputs = outputs[0].args
-    return get_output_dtypes(outputs, truncate_double)  # type: ignore[no-any-return]
+    return get_output_dtypes(outputs, truncate_double)


 def interpret_module_to_result(
@@ -65,6 +70,29 @@ def interpret_module_to_result(
     )

     interpreter_result = interpreter.run()
+    # Delete the frozen parameters from the module to release CPU memory
+    del interpreter
+    for attr in dir(module):
+        if attr.startswith("_frozen_param"):
+            delattr(module, attr)
+    release_memory()
+    logger.debug(
+        f"CPU memory usage after clearing frozen parameters and building memory in conversion: {get_cpu_memory_usage()} MB"
+    )
+
+    serialized_engine = interpreter_result.engine.serialize()
+    with io.BytesIO() as engine_bytes:
+        engine_bytes.write(serialized_engine)
+        serialized_engine = engine_bytes.getvalue()
+
+    interpreter_result = TRTInterpreterResult(
+        engine=serialized_engine,
+        input_names=interpreter_result.input_names,
+        output_names=interpreter_result.output_names,
+        weight_name_map=interpreter_result.weight_name_map,
+        requires_output_allocator=interpreter_result.requires_output_allocator,
+    )

     return interpreter_result

Review comments on the frozen-parameter cleanup above:
- "Can we gate this by the same env variable as the malloc_trim?"
- "Do you mean something like …"
- "I would say something like …"
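release_memory is imported from torch_tensorrt.dynamo.utils and its body is not shown in this diff; the review thread associates it with malloc trim. A rough sketch of what such a helper could do on Linux with glibc follows, as an illustration under those assumptions rather than the PR's implementation:

    import ctypes
    import gc
    import logging

    import torch

    logger = logging.getLogger(__name__)


    def release_memory() -> None:
        """Best-effort return of freed CPU memory to the operating system."""
        gc.collect()  # drop unreachable Python objects first
        if torch.cuda.is_available():
            torch.cuda.empty_cache()  # also release cached CUDA allocations
        try:
            # glibc-only: ask the allocator to hand free heap pages back to the OS.
            ctypes.CDLL("libc.so.6").malloc_trim(0)
        except (OSError, AttributeError):
            logger.debug("malloc_trim is unavailable on this platform")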
@@ -104,7 +132,7 @@ def convert_module(
     )

     return rt_cls(
-        serialized_engine=interpreter_result.serialized_engine,
+        serialized_engine=interpreter_result.engine,
         input_binding_names=list(interpreter_result.input_names),
         output_binding_names=list(interpreter_result.output_names),
         name=name,
Review comment on this change:
- "Aren't these the same?"