Skip to content

Commit ef0d5eb

Browse files
angelala3252KrishPatel13abrichr
authored
feat(record): memory profiling
* tracemalloc * pympler * todo * changed position of tracemalloc stats collection * updated requirements.txt * memory leak fix and cleanup * removed todo * changed printing to logging * alphabetical order * changes to tracemalloc usage * plot memory usage * memory writer terminates with performance writer * add MemoryStat table to database * remove todo * switch from writing/reading memory using file to saving/retrieving from database * add memory legend to performance plot * prevent error from child processes terminating * style changes * moved PLOT_PERFORMANCE to config.py * only display memory legend if there is memory data * moved memory logging into function * removed unnecessary call to row2dicts * rename memory_usage to memory_usage_bytes * replaced alembic revision * remove start_time_deltas; minor refactor * fix indent --------- Co-authored-by: Krish Patel <65433817+KrishPatel13@users.noreply.github.com> Co-authored-by: Richard Abrich <richard.abrich@mldsai.com> Co-authored-by: Richard Abrich <richard.abrich@gmail.com>
1 parent 2bb8814 commit ef0d5eb

File tree

7 files changed

+216
-56
lines changed

7 files changed

+216
-56
lines changed
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
"""add MemoryStat
2+
3+
Revision ID: 607d1380b5ae
4+
Revises: 104d4a614d95
5+
Create Date: 2023-06-28 11:54:36.749072
6+
7+
"""
8+
from alembic import op
9+
import sqlalchemy as sa
10+
from openadapt.models import ForceFloat
11+
12+
13+
# revision identifiers, used by Alembic.
14+
revision = '607d1380b5ae'
15+
down_revision = '104d4a614d95'
16+
branch_labels = None
17+
depends_on = None
18+
19+
20+
def upgrade() -> None:
21+
# ### commands auto generated by Alembic - please adjust! ###
22+
op.create_table('memory_stat',
23+
sa.Column('id', sa.Integer(), nullable=False),
24+
sa.Column('recording_timestamp', sa.Integer(), nullable=True),
25+
sa.Column('memory_usage_bytes', ForceFloat(precision=10, scale=2, asdecimal=False), nullable=True),
26+
sa.Column('timestamp', ForceFloat(precision=10, scale=2, asdecimal=False), nullable=True),
27+
sa.PrimaryKeyConstraint('id', name=op.f('pk_memory_stat'))
28+
)
29+
# ### end Alembic commands ###
30+
31+
32+
def downgrade() -> None:
33+
# ### commands auto generated by Alembic - please adjust! ###
34+
op.drop_table('memory_stat')
35+
# ### end Alembic commands ###

openadapt/config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@
8383
"key_vk",
8484
"children",
8585
],
86+
"PLOT_PERFORMANCE": True,
8687
}
8788

8889
# each string in STOP_STRS should only contain strings that don't contain special characters

openadapt/crud.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
Recording,
99
WindowEvent,
1010
PerformanceStat,
11+
MemoryStat
1112
)
1213
from openadapt.config import STOP_SEQUENCES
1314

@@ -18,6 +19,8 @@
1819
screenshots = []
1920
window_events = []
2021
performance_stats = []
22+
memory_stats = []
23+
2124

2225

2326
def _insert(event_data, table, buffer=None):
@@ -100,6 +103,33 @@ def get_perf_stats(recording_timestamp):
100103
)
101104

102105

106+
def insert_memory_stat(recording_timestamp, memory_usage_bytes, timestamp):
107+
"""
108+
Insert memory stat into db
109+
"""
110+
111+
memory_stat = {
112+
"recording_timestamp": recording_timestamp,
113+
"memory_usage_bytes": memory_usage_bytes,
114+
"timestamp": timestamp,
115+
}
116+
_insert(memory_stat, MemoryStat, memory_stats)
117+
118+
119+
def get_memory_stats(recording_timestamp):
120+
"""
121+
return memory stats for a given recording
122+
"""
123+
124+
return (
125+
db
126+
.query(MemoryStat)
127+
.filter(MemoryStat.recording_timestamp == recording_timestamp)
128+
.order_by(MemoryStat.timestamp)
129+
.all()
130+
)
131+
132+
103133
def insert_recording(recording_data):
104134
db_obj = Recording(**recording_data)
105135
db.add(db_obj)

openadapt/models.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -269,11 +269,11 @@ def take_screenshot(cls):
269269
sct_img = utils.take_screenshot()
270270
screenshot = Screenshot(sct_img=sct_img)
271271
return screenshot
272-
272+
273273
def crop_active_window(self, action_event):
274274
window_event = action_event.window_event
275275
width_ratio, height_ratio = utils.get_scale_ratios(action_event)
276-
276+
277277
x0 = window_event.left * width_ratio
278278
y0 = window_event.top * height_ratio
279279
x1 = x0 + window_event.width * width_ratio
@@ -314,3 +314,12 @@ class PerformanceStat(db.Base):
314314
start_time = sa.Column(sa.Integer)
315315
end_time = sa.Column(sa.Integer)
316316
window_id = sa.Column(sa.String)
317+
318+
319+
class MemoryStat(db.Base):
320+
__tablename__ = "memory_stat"
321+
322+
id = sa.Column(sa.Integer, primary_key=True)
323+
recording_timestamp = sa.Column(sa.Integer)
324+
memory_usage_bytes = sa.Column(ForceFloat)
325+
timestamp = sa.Column(ForceFloat)

openadapt/record.py

Lines changed: 87 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
"""
88

99
from collections import namedtuple
10-
from functools import partial
10+
from functools import partial, wraps
1111
from typing import Any, Callable, Dict
1212
import multiprocessing
1313
import os
@@ -16,16 +16,18 @@
1616
import sys
1717
import threading
1818
import time
19+
import tracemalloc
1920

2021
from loguru import logger
22+
from pympler import tracker
2123
from pynput import keyboard, mouse
2224
import fire
2325
import mss.tools
26+
import psutil
2427

2528
from openadapt import config, crud, utils, window
2629

27-
import functools
28-
30+
Event = namedtuple("Event", ("timestamp", "type", "data"))
2931

3032
EVENT_TYPES = ("screen", "action", "window")
3133
LOG_LEVEL = "INFO"
@@ -34,16 +36,39 @@
3436
"action": True,
3537
"window": True,
3638
}
37-
PLOT_PERFORMANCE = False
38-
39-
Event = namedtuple("Event", ("timestamp", "type", "data"))
40-
41-
global sequence_detected # Flag to indicate if a stop sequence is detected
39+
PLOT_PERFORMANCE = config.PLOT_PERFORMANCE
40+
NUM_MEMORY_STATS_TO_LOG = 3
4241
STOP_SEQUENCES = config.STOP_SEQUENCES
4342

43+
stop_sequence_detected = False
44+
performance_snapshots = []
45+
tracker = tracker.SummaryTracker()
46+
tracemalloc.start()
4447
utils.configure_logging(logger, LOG_LEVEL)
4548

4649

50+
def collect_stats():
51+
performance_snapshots.append(tracemalloc.take_snapshot())
52+
53+
54+
def log_memory_usage():
55+
assert len(performance_snapshots) == 2, performance_snapshots
56+
first_snapshot, last_snapshot = performance_snapshots
57+
stats = last_snapshot.compare_to(first_snapshot, "lineno")
58+
59+
for stat in stats[:NUM_MEMORY_STATS_TO_LOG]:
60+
new_KiB = stat.size_diff / 1024
61+
total_KiB = stat.size / 1024
62+
new_blocks = stat.count_diff
63+
total_blocks = stat.count
64+
source = stat.traceback.format()[0].strip()
65+
logger.info(f"{source=}")
66+
logger.info(f"\t{new_KiB=} {total_KiB=} {new_blocks=} {total_blocks=}")
67+
68+
trace_str = "\n".join(list(tracker.format_diff()))
69+
logger.info(f"trace_str=\n{trace_str}")
70+
71+
4772
def args_to_str(*args):
4873
return ", ".join(map(str, args))
4974

@@ -54,7 +79,7 @@ def kwargs_to_str(**kwargs):
5479

5580
def trace(logger):
5681
def decorator(func):
57-
@functools.wraps(func)
82+
@wraps(func)
5883
def wrapper_logging(*args, **kwargs):
5984
func_name = func.__qualname__
6085
func_args = args_to_str(*args)
@@ -160,6 +185,7 @@ def process_events(
160185
prev_saved_window_timestamp = prev_window_event.timestamp
161186
else:
162187
raise Exception(f"unhandled {event.type=}")
188+
del prev_event
163189
prev_event = event
164190
logger.info("done")
165191

@@ -470,6 +496,41 @@ def performance_stats_writer(
470496
logger.info("performance stats writer done")
471497

472498

499+
def memory_writer(
500+
recording_timestamp: float, terminate_event: multiprocessing.Event, record_pid: int
501+
):
502+
utils.configure_logging(logger, LOG_LEVEL)
503+
utils.set_start_time(recording_timestamp)
504+
logger.info("Memory writer starting")
505+
signal.signal(signal.SIGINT, signal.SIG_IGN)
506+
process = psutil.Process(record_pid)
507+
508+
while not terminate_event.is_set():
509+
memory_usage_bytes = 0
510+
511+
memory_info = process.memory_info()
512+
rss = memory_info.rss # Resident Set Size: non-swapped physical memory
513+
memory_usage_bytes += rss
514+
515+
for child in process.children(recursive=True):
516+
# after ctrl+c, children may terminate before the next line
517+
try:
518+
child_memory_info = child.memory_info()
519+
except psutil.NoSuchProcess:
520+
continue
521+
child_rss = child_memory_info.rss
522+
rss += child_rss
523+
524+
timestamp = utils.get_timestamp()
525+
526+
crud.insert_memory_stat(
527+
recording_timestamp,
528+
rss,
529+
timestamp,
530+
)
531+
logger.info("Memory writer done")
532+
533+
473534
@trace(logger)
474535
def create_recording(
475536
task_description: str,
@@ -521,7 +582,7 @@ def on_press(event_q, key, injected):
521582

522583
# stop sequence code
523584
nonlocal stop_sequence_indices
524-
global sequence_detected
585+
global stop_sequence_detected
525586
canonical_key_name = getattr(canonical_key, "name", None)
526587

527588
for i in range(0, len(STOP_SEQUENCES)):
@@ -547,7 +608,7 @@ def on_press(event_q, key, injected):
547608
# Check if the entire sequence has been entered correctly
548609
if stop_sequence_indices[i] == len(stop_sequence):
549610
logger.info("Stop sequence entered! Stopping recording now.")
550-
sequence_detected = True # Set global flag to end recording
611+
stop_sequence_detected = True
551612

552613
def on_release(event_q, key, injected):
553614
canonical_key = keyboard_listener.canonical(key)
@@ -694,19 +755,30 @@ def record(
694755
)
695756
perf_stat_writer.start()
696757

758+
if PLOT_PERFORMANCE:
759+
record_pid = os.getpid()
760+
mem_plotter = multiprocessing.Process(
761+
target=memory_writer,
762+
args=(recording_timestamp, terminate_perf_event, record_pid),
763+
)
764+
mem_plotter.start()
765+
697766
# TODO: discard events until everything is ready
698767

699-
global sequence_detected
700-
sequence_detected = False
768+
collect_stats()
769+
global stop_sequence_detected
701770

702771
try:
703-
while not sequence_detected:
772+
while not stop_sequence_detected:
704773
time.sleep(1)
705774

706775
terminate_event.set()
707776
except KeyboardInterrupt:
708777
terminate_event.set()
709778

779+
collect_stats()
780+
log_memory_usage()
781+
710782
logger.info(f"joining...")
711783
keyboard_event_reader.join()
712784
mouse_event_reader.join()
@@ -720,6 +792,7 @@ def record(
720792
terminate_perf_event.set()
721793

722794
if PLOT_PERFORMANCE:
795+
mem_plotter.join()
723796
utils.plot_performance(recording_timestamp)
724797

725798
logger.info(f"saved {recording_timestamp=}")

0 commit comments

Comments
 (0)