Commit c21353b

Implement simple memory monitoring thread
1 parent 593b634 commit c21353b

File tree

3 files changed: 332 additions & 0 deletions


requirements.txt

Lines changed: 1 addition & 0 deletions
@@ -7,6 +7,7 @@ scipy>=1.7.0
 cvxpy>=1.3.0
 joblib>=1.4.0
 cloudpickle
+psutil
 tqdm
 matplotlib
 typing_extensions

src/pydvl/utils/monitor.py

Lines changed: 223 additions & 0 deletions
@@ -0,0 +1,223 @@
"""
This module implements a simple memory monitoring utility for the whole application.

With [start_memory_monitoring()][pydvl.utils.monitor.start_memory_monitoring] one can
monitor global memory usage, including the memory of child processes. The monitoring
runs in a separate thread and keeps track of the *maximum* memory usage observed.

Monitoring stops automatically when the process exits or receives common termination
signals (SIGINT, SIGTERM, SIGHUP). It can also be stopped manually by calling
[end_memory_monitoring()][pydvl.utils.monitor.end_memory_monitoring].

When monitoring stops, the maximum memory usage is both logged and returned (in bytes).

!!! note
    This is intended to report peak memory usage for the whole application, including
    child processes. It is not intended to be used for profiling memory usage of
    individual functions or modules. Given that there exist numerous profiling tools,
    it probably doesn't make sense to extend this module further.
"""

from __future__ import annotations

import atexit
import logging
import signal
import threading
import time
from collections import defaultdict
from itertools import chain

import psutil

__all__ = [
    "end_memory_monitoring",
    "log_memory_usage_report",
    "start_memory_monitoring",
]

logger = logging.getLogger(__name__)

__state_lock = threading.Lock()
__memory_usage = defaultdict(int)  # pid -> bytes
__peak_memory_usage = 0  # (in bytes)
__monitoring_enabled = threading.Event()
__memory_monitor_thread: threading.Thread | None = None


def _memory_monitor_thread() -> threading.Thread | None:
    """Returns the memory monitor thread. Can be None if the monitor was never started.
    This is only useful for testing purposes."""
    return __memory_monitor_thread


def start_memory_monitoring(auto_stop: bool = True):
    """Starts a memory monitoring thread.

    The monitor runs in a separate thread and keeps track of the maximum memory usage
    observed during the monitoring period.

    The monitoring stops by calling
    [end_memory_monitoring()][pydvl.utils.monitor.end_memory_monitoring] or, if
    `auto_stop` is `True`, when the process is terminated or exits.

    Args:
        auto_stop: If True, the monitoring will stop when the process exits
            normally or receives common termination signals (SIGINT, SIGTERM, SIGHUP).
    """
    global __memory_usage
    global __memory_monitor_thread
    global __peak_memory_usage

    if __monitoring_enabled.is_set():
        logger.warning("Memory monitoring is already running.")
        return

    with __state_lock:
        __memory_usage.clear()
        __peak_memory_usage = 0

    __monitoring_enabled.set()
    __memory_monitor_thread = threading.Thread(
        target=memory_monitor_run, args=(psutil.Process().pid,)
    )
    __memory_monitor_thread.start()

    if not auto_stop:
        return

    atexit.register(end_memory_monitoring)

    # Register signal handlers for common termination signals, re-raising the original
    # signal to terminate as expected

    def signal_handler(signum, frame):
        end_memory_monitoring()
        signal.signal(signum, signal.SIG_DFL)
        signal.raise_signal(signum)

    signal.signal(signal.SIGINT, signal_handler)  # Ctrl+C
    signal.signal(signal.SIGTERM, signal_handler)  # Termination request
    # SIGHUP might not be available on all platforms (e.g., Windows)
    if hasattr(signal, "SIGHUP"):
        signal.signal(signal.SIGHUP, signal_handler)  # Terminal closed


def memory_monitor_run(pid: int, interval: float = 0.1):
    """Monitors the memory usage of the process and its children.

    This function runs in a separate thread and updates the global variable
    `__peak_memory_usage` with the maximum memory usage observed during the monitoring
    period.

    The monitoring stops when the __monitoring_enabled event is cleared, which can be
    achieved either by calling
    [end_memory_monitoring()][pydvl.utils.monitor.end_memory_monitoring], or when the
    process is terminated or exits.
    """
    global __memory_usage
    global __peak_memory_usage

    try:
        proc = psutil.Process(pid)
    except psutil.NoSuchProcess:
        logger.error(f"Process {pid} not found. Monitoring cannot start.")
        return

    while __monitoring_enabled.is_set():
        total_mem = 0
        try:
            for p in chain([proc], proc.children(recursive=True)):
                try:
                    pid = p.pid
                    rss = p.memory_info().rss
                    total_mem += rss
                    with __state_lock:
                        __memory_usage[pid] = max(__memory_usage[pid], rss)
                except psutil.NoSuchProcess:
                    continue
        except psutil.NoSuchProcess:  # Catch invalid proc / proc.children
            break

        with __state_lock:
            __peak_memory_usage = max(__peak_memory_usage, total_mem)

        time.sleep(interval)


def end_memory_monitoring(log_level=logging.DEBUG) -> tuple[int, dict[int, int]]:
    """Ends the memory monitoring thread and logs the maximum memory usage.

    Args:
        log_level: The logging level to use.

    Returns:
        A tuple with the maximum memory usage observed globally, and for each pid
        separately as a dict. The dict will be empty if monitoring is disabled.
    """
    global __memory_usage
    global __peak_memory_usage

    if not __monitoring_enabled.is_set():
        return 0, {}

    __monitoring_enabled.clear()
    __memory_monitor_thread.join()

    with __state_lock:
        peak_mem = __peak_memory_usage
        mem_usage = __memory_usage.copy()
        __memory_usage.clear()
        __peak_memory_usage = 0

    log_memory_usage_report(peak_mem, mem_usage, log_level)
    return peak_mem, mem_usage


def log_memory_usage_report(
    peak_mem: int, mem_usage: dict[int, int], log_level=logging.DEBUG
):
    """
    Generates a nicely tabulated memory usage report and logs it.

    Args:
        peak_mem: The maximum memory usage observed during the monitoring period.
        mem_usage: A dictionary mapping process IDs (pid) to memory usage in bytes.
        log_level: The log level used for logging the report.
    """
    if not mem_usage:
        logger.log(log_level, "No memory usage data available.")
        return

    headers = ("PID", "Memory (Bytes)", "Memory (MB)")
    col_widths = (10, 20, 15)

    header_line = (
        f"{headers[0]:>{col_widths[0]}} "
        f"{headers[1]:>{col_widths[1]}} "
        f"{headers[2]:>{col_widths[2]}}"
    )
    separator = "-" * (sum(col_widths) + 2)

    summary = (
        f"Memory monitor: {len(mem_usage)} processes monitored. "
        f"Peak memory usage: {peak_mem / (2**20):.2f} MB"
    )

    lines = [header_line, separator, summary]

    for pid, bytes_used in sorted(
        mem_usage.items(), key=lambda item: item[1], reverse=True
    ):
        mb_used = bytes_used / (1024 * 1024)
        line = (
            f"{pid:>{col_widths[0]}} "
            f"{bytes_used:>{col_widths[1]},} "
            f"{mb_used:>{col_widths[2]}.2f}"
        )
        lines.append(line)

    lines.append(separator)

    logger.log(log_level, "\n".join(lines))
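
For reference, the workflow described in the module docstring is: start the monitor, run the workload, then stop the monitor to collect the peak. A minimal usage sketch follows; the logging setup and the placeholder workload are illustrative and not part of this commit:

import logging

from pydvl.utils.monitor import end_memory_monitoring, start_memory_monitoring

logging.basicConfig(level=logging.INFO)

# Start the background sampler; with auto_stop=True (the default) it also
# registers atexit and signal handlers so monitoring ends with the process.
start_memory_monitoring()

# ... run the memory-hungry workload here (possibly spawning child processes) ...

# Stop monitoring explicitly: this logs a report and returns the observed peaks.
peak_bytes, per_pid_bytes = end_memory_monitoring(log_level=logging.INFO)
print(f"Peak usage: {peak_bytes / 2**20:.2f} MB across {len(per_pid_bytes)} processes")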

tests/utils/test_monitor.py

Lines changed: 108 additions & 0 deletions
@@ -0,0 +1,108 @@
import multiprocessing
import time

import psutil
import pytest

from pydvl.utils.monitor import (
    __monitoring_enabled,
    _memory_monitor_thread,
    end_memory_monitoring,
    start_memory_monitoring,
)


@pytest.fixture(autouse=True)
def cleanup_monitor():
    """
    A fixture to ensure that monitoring is stopped after each test.
    """
    yield
    if __monitoring_enabled.is_set():
        end_memory_monitoring()


def test_double_start(caplog):
    start_memory_monitoring(auto_stop=False)
    # Attempt a second start; should log a warning.
    start_memory_monitoring(auto_stop=False)
    assert "already running" in caplog.text


def test_end_without_start():
    result = end_memory_monitoring()
    assert result == (0, {}), f"Expected (0, {{}}) when not monitoring, got {result}"


def test_thread_cleanup():
    start_memory_monitoring(auto_stop=False)
    time.sleep(0.2)  # Allow some time for the thread to start.

    end_memory_monitoring()
    time.sleep(0.1)  # Wait a bit more to ensure the join has completed.
    thread = _memory_monitor_thread()
    assert thread is not None and not thread.is_alive(), (
        "Monitoring thread should have terminated"
    )


def memory_allocating_child(size_mb) -> int:
    """Child process that allocates approximately `size_mb` MB of memory."""
    data = bytearray(size_mb * 1024 * 1024)
    time.sleep(1)  # Ensure the memory monitor has time to sample.
    return len(data)  # Prevent potential optimization


@pytest.mark.timeout(5)
def test_integration_memory_usage():
    baseline = psutil.Process().memory_info().rss
    start_memory_monitoring(auto_stop=False)

    proc = multiprocessing.Process(target=memory_allocating_child, args=(3,))
    proc.start()
    proc.join()

    peak_mem, mem_usage = end_memory_monitoring()

    mem_increase = peak_mem - baseline
    threshold = 3 * 1024 * 1024
    assert mem_increase >= threshold, (
        f"Expected memory increase of at least 3 MB, but got {mem_increase / 1024 / 1024:.2f} MB"
    )

    total_mem = sum(mem_usage.values())
    assert total_mem >= peak_mem, (
        "Expected aggregated memory usage to be at least the peak usage"
    )


@pytest.mark.timeout(5)
def test_integration_multiple_children():
    baseline = psutil.Process().memory_info().rss
    start_memory_monitoring(auto_stop=False)

    processes = [
        multiprocessing.Process(target=memory_allocating_child, args=(1,)),
        multiprocessing.Process(target=memory_allocating_child, args=(3,)),
    ]

    for p in processes:
        p.start()
    for p in processes:
        p.join()

    peak_mem, mem_usage = end_memory_monitoring()

    mem_increase = peak_mem - baseline
    threshold = 4 * 1024 * 1024
    assert mem_increase >= threshold, (
        f"Expected combined memory increase of at least 4 MB, but got {mem_increase / 1024 / 1024:.2f} MB"
    )
    assert len(mem_usage) == len(processes) + 1, (
        f"Expected memory usage for {len(processes) + 1} processes, but got {len(mem_usage)}"
    )

    total_mem = sum(mem_usage.values())
    assert total_mem >= peak_mem, (
        "Expected aggregated memory usage to be at least the peak usage"
    )
