
Commit 247a9a7

add memory monitor and import auto-scheme on demand (#1049)
1 parent c640c72 commit 247a9a7

File tree

6 files changed (+142 -14 lines changed)

auto_round/auto_scheme/__init__.py

Lines changed: 15 additions & 5 deletions

@@ -2,6 +2,7 @@
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
+
 # You may obtain a copy of the License at
 #
 #     http://www.apache.org/licenses/LICENSE-2.0
@@ -11,12 +12,21 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
 from auto_round.logger import logger
 
 from auto_round.auto_scheme.gen_auto_scheme import AutoScheme
-from auto_round.auto_scheme.register import AUTO_SCHEME_METHODS
 
-try:
-    import auto_round.auto_scheme.default_alg
-except ImportError:
-    logger.warning("AutoScheme is currently supported only on Linux.")
+
+def __getattr__(name):
+    if name == "AUTO_SCHEME_METHODS":
+        try:
+            import auto_round.auto_scheme.default_alg
+        except ImportError:
+            logger.warning("AutoScheme is currently supported only on Linux.")
+
+        from auto_round.auto_scheme.register import AUTO_SCHEME_METHODS
+
+        return AUTO_SCHEME_METHODS
+
+    raise AttributeError(f"auto-scheme has no attribute '{name}'")
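
Note: this change moves AutoScheme's platform-specific import from package-import time to first use via a module-level __getattr__ (PEP 562): importing auto_round.auto_scheme stays cheap, and default_alg is only imported the first time AUTO_SCHEME_METHODS is looked up. Below is a minimal, self-contained sketch of the same pattern; the heavy_backend module and REGISTRY name are hypothetical and used only for illustration.

# lazy_pkg.py - illustrative sketch of PEP 562 lazy loading; not part of auto-round
import logging

logger = logging.getLogger(__name__)

_REGISTRY: dict = {}  # stands in for the real method registry


def __getattr__(name):
    # Called only when normal attribute lookup on the module fails.
    if name == "REGISTRY":
        try:
            import heavy_backend  # noqa: F401  # hypothetical platform-specific module
        except ImportError:
            logger.warning("heavy_backend is unavailable on this platform.")
        return _REGISTRY
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

A plain import of lazy_pkg never touches heavy_backend; only accessing lazy_pkg.REGISTRY does, which mirrors how AUTO_SCHEME_METHODS now behaves.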

auto_round/auto_scheme/gen_auto_scheme.py

Lines changed: 3 additions & 2 deletions

@@ -17,7 +17,6 @@
 
 import torch
 
-from auto_round.auto_scheme.register import AUTO_SCHEME_METHODS
 from auto_round.auto_scheme.utils import compute_avg_bits_for_scheme
 from auto_round.compressors.utils import gguf_type_fallback
 from auto_round.export.export_to_gguf.config import GGUF_INNER_CONFIG
@@ -103,7 +102,9 @@ def _check_configs(self) -> None:
 
     def get_layer_config(self) -> dict[str, dict]:
         method_name = self.auto_scheme.method
-        method_func = AUTO_SCHEME_METHODS[method_name]
+        from auto_round import auto_scheme
+
+        method_func = auto_scheme.AUTO_SCHEME_METHODS[method_name]
         if self.auto_scheme.low_gpu_mem_usage:
             self.enable_torch_compile = False
 
auto_round/compressors/base.py

Lines changed: 15 additions & 5 deletions

@@ -87,6 +87,7 @@
     is_fp8_model,
     is_hpex_available,
     llm_load_model,
+    memory_monitor,
     mv_module_from_gpu,
     normalize_input,
     set_amax_for_all_moe_layers,
@@ -1025,6 +1026,7 @@ def quantize_and_save(
             self.save_quantized(save_folder, format=format, inplace=inplace, **kwargs)
 
             folders.append(save_folder)
+        memory_monitor.log_summary()
 
         return model, folders
 
@@ -1513,6 +1515,7 @@ def _quantize_via_rtn_blockwise(self, all_to_quantized_module_names: list[str])
                     all_to_quantized_module_names.remove(m.tmp_name)
             if not self.immediate_saving:
                 mv_module_from_gpu(block)
+            memory_monitor.log_summary()
             pbar.update(1)
 
         pbar.close()
@@ -1752,6 +1755,8 @@ def _quantize_layers(self, layer_names: list, layer_inputs: dict) -> None:
                 layer.cpu()
                 layer_names.remove(layer_name)
         if len(layer_names) == 0:
+            memory_monitor.update()
+            memory_monitor.log_summary()
             return
         q_layer_inputs = None
         enable_quanted_input = self.enable_quanted_input
@@ -1770,7 +1775,7 @@ def _quantize_layers(self, layer_names: list, layer_inputs: dict) -> None:
         if hasattr(self.model, "hf_device_map") and len(self.model.hf_device_map) > 1:
             accelerate.hooks.remove_hook_from_submodules(
                 self.model
-            ) ##self.model.hf_device_map has not been changed
+            ) # self.model.hf_device_map has not been changed
         if not self.immediate_saving:
             self.model = mv_module_from_gpu(self.model)
         clear_memory(device_list=self.device_list)
@@ -1789,13 +1794,14 @@ def _quantize_layers(self, layer_names: list, layer_inputs: dict) -> None:
                 immediate_saving(self, m, name=layer_name, last_group=True)
             del layer_input
             clear_memory(q_layer_input, device_list=self.device_list)
+        memory_monitor.log_summary()
 
     @torch.no_grad()
     def _get_block_outputs(
         self,
         block: torch.nn.Module,
-        input_ids: torch.Tensor,
-        input_others: torch.Tensor,
+        input_ids: torch.Tensor | list[torch.Tensor],
+        input_others: torch.Tensor | dict,
         bs: int,
         device: Union[str, torch.device],
         cache_device: Union[str, torch.device],
@@ -2805,7 +2811,7 @@ def _quantize_block(
             f"quantized {len(quantized_layer_names)}/{(len(quantized_layer_names) + len(unquantized_layer_names))} "
             f"layers in the block, loss iter 0: {init_loss:.6f} -> iter {best_iter}: {last_loss:.6f}"
         )
-        logger.info(dump_info)
+
         if self.low_gpu_mem_usage:
             clear_memory(device_list=self.device_list) # clear cached memory during training
         if len(unquantized_layer_names) != 0:
@@ -2833,6 +2839,8 @@
                 mv_module_from_gpu(block)
 
             clear_memory(input_ids)
+            memory_info_summary = memory_monitor.get_summary()
+            logger.infoclean(dump_info + "," + memory_info_summary)
 
             return q_outputs, output
         else:
@@ -2841,6 +2849,8 @@
             if auto_offload:
                 mv_module_from_gpu(block)
             clear_memory(input_ids)
+            memory_info_summary = memory_monitor.get_summary()
+            logger.infoclean(dump_info + "," + memory_info_summary)
 
             return None, output
 
@@ -3174,7 +3184,7 @@ def _sampling_inputs(
         cls,
         input_ids: Union[list[torch.Tensor], dict],
         input_others: dict,
-        indices: list[int],
+        indices: list[int] | torch.Tensor,
         seqlen: int,
         batch_dim: int = 0,
         share_cache_keys: tuple = (),
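
Most of the wiring above follows one pattern: sample the monitor at a checkpoint, fold its summary into the per-block status line, and emit that line through the prefix-free infoclean level instead of logger.info. A condensed sketch of that call sequence follows; the finish_block helper is hypothetical and only illustrates the pattern shown in the diff.

from auto_round.logger import logger
from auto_round.utils.device import memory_monitor


def finish_block(quantized: int, total: int, init_loss: float, last_loss: float) -> None:
    # Hypothetical helper condensing the end-of-block logging shown above.
    dump_info = (
        f"quantized {quantized}/{total} layers in the block, "
        f"loss iter 0: {init_loss:.6f} -> final: {last_loss:.6f}"
    )
    memory_info_summary = memory_monitor.get_summary()        # peak RAM/VRAM so far
    logger.infoclean(dump_info + "," + memory_info_summary)   # one clean line, no prefix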

auto_round/logger.py

Lines changed: 13 additions & 1 deletion

@@ -53,7 +53,18 @@ def trace(self, message, *args):
 
 
 # Add the trace method to the Logger class
+
 logging.Logger.trace = trace
+INFOCLEAN_LEVEL = 21
+logging.addLevelName(INFOCLEAN_LEVEL, "INFOCLEAN")
+
+
+def infoclean(self, message, *args, **kwargs):
+    if self.isEnabledFor(INFOCLEAN_LEVEL):
+        self._log(INFOCLEAN_LEVEL, message, args, **kwargs)
+
+
+logging.Logger.infoclean = infoclean
 
 
 class AutoRoundFormatter(logging.Formatter):
@@ -65,10 +76,11 @@ class AutoRoundFormatter(logging.Formatter):
     cyan = "\x1b[36;1m"
     blue = "\x1b[34;1m"
     _format = "%(asctime)s %(levelname)s %(filename)s L%(lineno)d: %(message)s"
-
+    _format_clean = "%(message)s"
     FORMATS = {
         logging.DEBUG: blue + _format + reset,
         logging.INFO: grey + _format + reset,
+        INFOCLEAN_LEVEL: grey + _format_clean + reset,
         logging.WARNING: yellow + _format + reset,
         logging.ERROR: bold_red + _format + reset,
         logging.CRITICAL: bold_red + _format + reset,
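
INFOCLEAN sits between INFO (20) and WARNING (30), so it is emitted whenever the logger runs at INFO level, but its records are rendered as the bare message. The snippet below is a self-contained sketch of the same stdlib technique, independent of auto-round; the CLEAN level name and demo logger are illustrative.

import logging

CLEAN = 21  # custom level between INFO (20) and WARNING (30)
logging.addLevelName(CLEAN, "CLEAN")


def clean(self, message, *args, **kwargs):
    if self.isEnabledFor(CLEAN):
        self._log(CLEAN, message, args, **kwargs)


logging.Logger.clean = clean


class SwitchingFormatter(logging.Formatter):
    """Full prefix for normal records, bare message for CLEAN records."""

    def format(self, record):
        fmt = "%(message)s" if record.levelno == CLEAN else "%(asctime)s %(levelname)s: %(message)s"
        return logging.Formatter(fmt).format(record)


handler = logging.StreamHandler()
handler.setFormatter(SwitchingFormatter())
log = logging.getLogger("demo")
log.addHandler(handler)
log.setLevel(logging.INFO)

log.info("regular line")  # rendered with timestamp and level
log.clean("bare line")    # rendered as the message only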

auto_round/utils/device.py

Lines changed: 95 additions & 0 deletions

@@ -16,9 +16,11 @@
 import re
 from functools import lru_cache
 from itertools import combinations
+from threading import Lock
 from typing import Callable, Union
 
 import cpuinfo
+import psutil
 import torch
 
 from auto_round.logger import logger
@@ -442,6 +444,7 @@ def _clear_memory_for_cpu_and_cuda(
 def clear_memory(tensor: torch.Tensor | None | list[torch.Tensor] = None, device_list: list | tuple | None = None):
     from auto_round.utils.device import is_hpex_available
 
+    memory_monitor.update(device_list=device_list)
     if is_hpex_available():
         # hpu does not have empty_cache
         return
@@ -1308,3 +1311,95 @@ def parse_available_devices(device_map: Union[str, torch.device, int, dict, None
         return sorted(devices)
 
     raise TypeError(f"Unsupported device_map type: {type(device_map)}")
+
+
+class MemoryMonitor:
+    """Global memory monitor for tracking peak RAM and VRAM usage."""
+
+    _instance = None
+    _lock = Lock()
+    _initialized = False
+
+    def __new__(cls):
+        if cls._instance is None:
+            with cls._lock:
+                if cls._instance is None:
+                    cls._instance = super().__new__(cls)
+                    cls._instance._initialized = False
+        return cls._instance
+
+    def __init__(self):
+        if self._initialized:
+            return
+        self._initialized = True
+        self.peak_ram = 0.0  # GB
+        self.peak_vram = {}  # {device_id: peak_gb}
+        self.enabled = True
+
+    def update(self, device_list=None):
+        """Update current memory usage and track peaks."""
+        if not self.enabled:
+            return
+        # Track RAM
+        process = psutil.Process()
+        current_ram = process.memory_info().rss / 1024**3  # GB
+        self.peak_ram = max(self.peak_ram, current_ram)
+        if device_list is None:  # TODO: this has an issue; remove once every clear_memory call passes device_list
+            device_list = [0]
+        if device_list is not None:
+            if not isinstance(device_list, (list, tuple)):
+                device_list = [device_list]
+        else:
+            if torch.cuda.is_available():
+                device_list = list(range(torch.cuda.device_count()))
+            elif torch.xpu.is_available():
+                device_list = list(range(torch.xpu.device_count()))
+
+        for device in device_list:
+            if device == "cpu":
+                continue
+            if torch.cuda.is_available():
+                current_vram = torch.cuda.memory_reserved(device) / 1024**3  # GB
+            elif torch.xpu.is_available():
+                current_vram = torch.xpu.memory_reserved(device) / 1024**3  # GB
+            else:
+                return
+
+            device = str(device).split(":")[-1]
+            if current_vram > 0:
+                if device not in self.peak_vram:
+                    self.peak_vram[device] = 0.0
+
+                self.peak_vram[device] = max(self.peak_vram[device], current_vram)
+
+    def update_cpu(self):
+        if not self.enabled:
+            return
+        process = psutil.Process()
+        current_ram = process.memory_info().rss / 1024**3  # GB
+        self.peak_ram = max(self.peak_ram, current_ram)
+
+    def reset(self):
+        """Reset all statistics."""
+        self.peak_ram = 0.0
+        self.peak_vram = {}
+
+    def get_summary(self):
+        """Get summary of peak memory usage."""
+        summary = f"'peak_ram': {round(self.peak_ram, 2)}GB"
+        if len(self.peak_vram) > 0:
+            sorted_items = sorted(self.peak_vram.items())
+            items_str = ", ".join([f"'{k}': {round(v, 2)}GB" for k, v in sorted_items])
+            summary += f", 'peak_vram': {{{items_str}}}"
+        return summary
+
+    def log_summary(self):
+        """Log memory usage summary."""
+        summary = self.get_summary()
+        logger.info(summary)
+
+        return summary
+
+
+# Global singleton instance
+memory_monitor = MemoryMonitor()
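
Because memory_monitor is a process-wide singleton, any call site can import it and sample peaks without passing state around, which is exactly what clear_memory now does. A short usage sketch under that assumption; the matmul workload is only a placeholder.

import torch

from auto_round.utils.device import memory_monitor

memory_monitor.reset()                        # start peak tracking from zero

x = torch.randn(4096, 4096)                   # placeholder CPU workload
if torch.cuda.is_available():
    x = x.to("cuda:0")
    _ = x @ x                                 # reserves some VRAM on device 0
    memory_monitor.update(device_list=[0])    # sample RAM and VRAM peaks
else:
    memory_monitor.update_cpu()               # RAM-only sampling

print(memory_monitor.get_summary())           # e.g. "'peak_ram': 1.23GB, 'peak_vram': {'0': 0.25GB}"
memory_monitor.log_summary()                  # same summary through logger.info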

docs/step_by_step.md

Lines changed: 1 addition & 1 deletion

@@ -725,7 +725,7 @@ If not explicitly specify '--task', the default value will be used (typically co
 ~~~
 The last format will be used in evaluation if multiple formats have been exported.
 
-Note: To use the vllm backend, please add `--vllm` into the upper command.
+Note: To use the vllm backend, please add `--eval_backend vllm` to the command above.
 
 ### Eval the Quantized model
 