Skip to content

Commit 385963c

Browse files
authored
feat: stop listener
* created listener for "oa.stop" sequence * fixed issue with comparing objects of diff types * moved list of sequences to config.STOP_SEQUENCES and changed code to accomadate multiple stop sequences, + minor changes to naming and logging * moved list of stop sequences to config.STOP_SEQUENCES * filter out stop sequence in crud.get_action_events * combined keyboard listeners for macOS compatability * style changes * code cleanup * special char support * change to config.STOP_STRS and split by character in record.py and crud.py * black * add todo and fix special char functionality * fix filter_stop_sequences * added SPECIAL_CHAR_STOP_SEQUENCES and STOP_SEQUENCES that combines STOP_STRS and SPECIAL_CHAR_STOP_SEQUENCES * STOP_SEQUENCES moved to config.py * black * black
1 parent 5b9f735 commit 385963c

File tree

3 files changed

+133
-26
lines changed

3 files changed

+133
-26
lines changed

openadapt/config.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,20 @@
8585
],
8686
}
8787

88+
# each string in STOP_STRS should only contain strings that don't contain special characters
89+
STOP_STRS = [
90+
"oa.stop",
91+
# TODO:
92+
# "<ctrl>+c,<ctrl>+c,<ctrl>+c"
93+
]
94+
# each list in SPECIAL_CHAR_STOP_SEQUENCES should contain sequences
95+
# containing special chars, separated by keys
96+
SPECIAL_CHAR_STOP_SEQUENCES = [["ctrl", "ctrl", "ctrl"]]
97+
# sequences that when typed, will stop the recording of ActionEvents in record.py
98+
STOP_SEQUENCES = [
99+
list(stop_str) for stop_str in STOP_STRS
100+
] + SPECIAL_CHAR_STOP_SEQUENCES
101+
88102

89103
def getenv_fallback(var_name):
90104
rval = os.getenv(var_name) or _DEFAULTS.get(var_name)

openadapt/crud.py

Lines changed: 74 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
WindowEvent,
1010
PerformanceStat,
1111
)
12-
12+
from openadapt.config import STOP_SEQUENCES
1313

1414
BATCH_SIZE = 1
1515

@@ -19,13 +19,11 @@
1919
window_events = []
2020
performance_stats = []
2121

22+
2223
def _insert(event_data, table, buffer=None):
2324
"""Insert using Core API for improved performance (no rows are returned)"""
2425

25-
db_obj = {
26-
column.name: None
27-
for column in table.__table__.columns
28-
}
26+
db_obj = {column.name: None for column in table.__table__.columns}
2927
for key in db_obj:
3028
if key in event_data:
3129
val = event_data[key]
@@ -74,6 +72,7 @@ def insert_window_event(recording_timestamp, event_timestamp, event_data):
7472
}
7573
_insert(event_data, WindowEvent, window_events)
7674

75+
7776
def insert_perf_stat(recording_timestamp, event_type, start_time, end_time):
7877
"""
7978
Insert event performance stat into db
@@ -87,19 +86,20 @@ def insert_perf_stat(recording_timestamp, event_type, start_time, end_time):
8786
}
8887
_insert(event_perf_stat, PerformanceStat, performance_stats)
8988

89+
9090
def get_perf_stats(recording_timestamp):
9191
"""
9292
return performance stats for a given recording
9393
"""
9494

9595
return (
96-
db
97-
.query(PerformanceStat)
96+
db.query(PerformanceStat)
9897
.filter(PerformanceStat.recording_timestamp == recording_timestamp)
9998
.order_by(PerformanceStat.start_time)
10099
.all()
101100
)
102101

102+
103103
def insert_recording(recording_data):
104104
db_obj = Recording(**recording_data)
105105
db.add(db_obj)
@@ -109,36 +109,87 @@ def insert_recording(recording_data):
109109

110110

111111
def get_latest_recording():
112-
return (
113-
db
114-
.query(Recording)
115-
.order_by(sa.desc(Recording.timestamp))
116-
.limit(1)
117-
.first()
118-
)
112+
return db.query(Recording).order_by(sa.desc(Recording.timestamp)).limit(1).first()
119113

120114

121115
def get_recording(timestamp):
122-
return (
123-
db
124-
.query(Recording)
125-
.filter(Recording.timestamp == timestamp)
126-
.first()
127-
)
116+
return db.query(Recording).filter(Recording.timestamp == timestamp).first()
128117

129118

130119
def _get(table, recording_timestamp):
131120
return (
132-
db
133-
.query(table)
121+
db.query(table)
134122
.filter(table.recording_timestamp == recording_timestamp)
135123
.order_by(table.timestamp)
136124
.all()
137125
)
138126

139127

140128
def get_action_events(recording):
141-
return _get(ActionEvent, recording.timestamp)
129+
action_events = _get(ActionEvent, recording.timestamp)
130+
# filter out stop sequences listed in STOP_SEQUENCES and Ctrl + C
131+
filter_stop_sequences(action_events)
132+
return action_events
133+
134+
135+
def filter_stop_sequences(action_events):
136+
# check for ctrl c first
137+
# TODO: want to handle sequences like ctrl c the same way as normal sequences
138+
if len(action_events) >= 2:
139+
if (
140+
action_events[-1].canonical_key_char == "c"
141+
and action_events[-2].canonical_key_name == "ctrl"
142+
):
143+
# remove ctrl c
144+
# ctrl c must be held down at same time, so no release event
145+
action_events.pop()
146+
action_events.pop()
147+
return
148+
149+
# create list of indices for sequence detection
150+
# one index for each stop sequence in STOP_SEQUENCES
151+
# start from the back of the sequence
152+
stop_sequence_indices = [len(sequence) - 1 for sequence in STOP_SEQUENCES]
153+
154+
# index of sequence to remove, -1 if none found
155+
sequence_to_remove = -1
156+
# number of events to remove
157+
num_to_remove = 0
158+
159+
for i in range(0, len(STOP_SEQUENCES)):
160+
# iterate backwards through list of action events
161+
for j in range(len(action_events) - 1, -1, -1):
162+
# never go past 1st action event, so if a sequence is longer than
163+
# len(action_events), it can't have been in the recording
164+
if (
165+
action_events[j].canonical_key_char
166+
== STOP_SEQUENCES[i][stop_sequence_indices[i]]
167+
or action_events[j].canonical_key_name
168+
== STOP_SEQUENCES[i][stop_sequence_indices[i]]
169+
) and action_events[j].name == "press":
170+
# for press events, compare the characters
171+
stop_sequence_indices[i] -= 1
172+
num_to_remove += 1
173+
elif action_events[j].name == "release" and (
174+
action_events[j].canonical_key_char in STOP_SEQUENCES[i]
175+
or action_events[j].canonical_key_name in STOP_SEQUENCES[i]
176+
):
177+
# can consider any release event with any sequence char as part of the sequence
178+
num_to_remove += 1
179+
else:
180+
# not part of the sequence, so exit inner loop
181+
break
182+
183+
if stop_sequence_indices[i] == -1:
184+
# completed whole sequence, so set sequence_to_remove to
185+
# current sequence and exit outer loop
186+
sequence_to_remove = i
187+
break
188+
189+
if sequence_to_remove != -1:
190+
# remove that sequence
191+
for _ in range(0, num_to_remove):
192+
action_events.pop()
142193

143194

144195
def get_screenshots(recording, precompute_diffs=False):

openadapt/record.py

Lines changed: 45 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,11 @@
3636
}
3737
PLOT_PERFORMANCE = False
3838

39-
4039
Event = namedtuple("Event", ("timestamp", "type", "data"))
4140

41+
global sequence_detected # Flag to indicate if a stop sequence is detected
42+
STOP_SEQUENCES = config.STOP_SEQUENCES
43+
4244
utils.configure_logging(logger, LOG_LEVEL)
4345

4446

@@ -407,6 +409,7 @@ def read_window_events(
407409
window_data = window.get_active_window_data()
408410
if not window_data:
409411
continue
412+
410413
if window_data["title"] != prev_window_data.get("title") or window_data[
411414
"window_id"
412415
] != prev_window_data.get("window_id"):
@@ -433,7 +436,7 @@ def read_window_events(
433436

434437

435438
@trace(logger)
436-
def performance_stats_writer (
439+
def performance_stats_writer(
437440
perf_q: multiprocessing.Queue,
438441
recording_timestamp: float,
439442
terminate_event: multiprocessing.Event,
@@ -506,12 +509,46 @@ def read_keyboard_events(
506509
terminate_event: multiprocessing.Event,
507510
recording_timestamp: float,
508511
) -> None:
512+
# create list of indices for sequence detection
513+
# one index for each stop sequence in STOP_SEQUENCES
514+
stop_sequence_indices = [0 for _ in STOP_SEQUENCES]
515+
509516
def on_press(event_q, key, injected):
510517
canonical_key = keyboard_listener.canonical(key)
511518
logger.debug(f"{key=} {injected=} {canonical_key=}")
512519
if not injected:
513520
handle_key(event_q, "press", key, canonical_key)
514521

522+
# stop sequence code
523+
nonlocal stop_sequence_indices
524+
global sequence_detected
525+
canonical_key_name = getattr(canonical_key, "name", None)
526+
527+
for i in range(0, len(STOP_SEQUENCES)):
528+
# check each stop sequence
529+
stop_sequence = STOP_SEQUENCES[i]
530+
# stop_sequence_indices[i] is the index for this stop sequence
531+
# get canonical KeyCode of current letter in this sequence
532+
canonical_sequence = keyboard_listener.canonical(
533+
keyboard.KeyCode.from_char(stop_sequence[stop_sequence_indices[i]])
534+
)
535+
536+
# Check if the pressed key matches the current key in this sequence
537+
if (
538+
canonical_key == canonical_sequence
539+
or canonical_key_name == stop_sequence[stop_sequence_indices[i]]
540+
):
541+
# increment this index
542+
stop_sequence_indices[i] += 1
543+
else:
544+
# Reset index since pressed key doesn't match sequence key
545+
stop_sequence_indices[i] = 0
546+
547+
# Check if the entire sequence has been entered correctly
548+
if stop_sequence_indices[i] == len(stop_sequence):
549+
logger.info("Stop sequence entered! Stopping recording now.")
550+
sequence_detected = True # Set global flag to end recording
551+
515552
def on_release(event_q, key, injected):
516553
canonical_key = keyboard_listener.canonical(key)
517554
logger.debug(f"{key=} {injected=} {canonical_key=}")
@@ -659,9 +696,14 @@ def record(
659696

660697
# TODO: discard events until everything is ready
661698

699+
global sequence_detected
700+
sequence_detected = False
701+
662702
try:
663-
while True:
703+
while not sequence_detected:
664704
time.sleep(1)
705+
706+
terminate_event.set()
665707
except KeyboardInterrupt:
666708
terminate_event.set()
667709

0 commit comments

Comments
 (0)