Skip to content

Commit b288c07

Browse files
authored
feat(action verification): implement replay action verification (#857)
* add is_action_event_complete * retry_with_exceptions in apply_replay_instructions * fix parse_code_snippet * add error_reporting.py * refactor video.py * black/flake8 * add module docstring * CHECK_ACTION_COMPLETE
1 parent 001c8fa commit b288c07

File tree

11 files changed

+225
-91
lines changed

11 files changed

+225
-91
lines changed

openadapt/config.py

Lines changed: 1 addition & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -8,17 +8,13 @@
88
import os
99
import pathlib
1010
import shutil
11-
import webbrowser
1211

1312
from loguru import logger
1413
from pydantic import field_validator
1514
from pydantic.fields import FieldInfo
1615
from pydantic_settings import BaseSettings, PydanticBaseSettingsSource
17-
from PySide6.QtWidgets import QMessageBox, QPushButton
18-
import git
19-
import sentry_sdk
2016

21-
from openadapt.build_utils import get_root_dir_path, is_running_from_executable
17+
from openadapt.build_utils import get_root_dir_path
2218

2319
CONFIG_DEFAULTS_FILE_PATH = (
2420
pathlib.Path(__file__).parent / "config.defaults.json"
@@ -411,47 +407,3 @@ def print_config() -> None:
411407
if not key.startswith("_") and key.isupper():
412408
val = maybe_obfuscate(key, val)
413409
logger.info(f"{key}={val}")
414-
415-
if config.ERROR_REPORTING_ENABLED:
416-
if is_running_from_executable():
417-
is_reporting_branch = True
418-
else:
419-
active_branch_name = git.Repo(PARENT_DIR_PATH).active_branch.name
420-
logger.info(f"{active_branch_name=}")
421-
is_reporting_branch = (
422-
active_branch_name == config.ERROR_REPORTING_BRANCH
423-
)
424-
logger.info(f"{is_reporting_branch=}")
425-
if is_reporting_branch:
426-
427-
def show_alert() -> None:
428-
"""Show an alert to the user."""
429-
msg = QMessageBox()
430-
msg.setIcon(QMessageBox.Warning)
431-
msg.setText("""
432-
An error has occurred. The development team has been notified.
433-
Please join the discord server to get help or send an email to
434-
435-
""")
436-
discord_button = QPushButton("Join the discord server")
437-
discord_button.clicked.connect(
438-
lambda: webbrowser.open("https://discord.gg/yF527cQbDG")
439-
)
440-
msg.addButton(discord_button, QMessageBox.ActionRole)
441-
msg.addButton(QMessageBox.Ok)
442-
msg.exec()
443-
444-
def before_send_event(event: Any, hint: Any) -> Any:
445-
"""Handle the event before sending it to Sentry."""
446-
try:
447-
show_alert()
448-
except Exception:
449-
pass
450-
return event
451-
452-
sentry_sdk.init(
453-
dsn=config.ERROR_REPORTING_DSN,
454-
traces_sample_rate=1.0,
455-
before_send=before_send_event,
456-
ignore_errors=[KeyboardInterrupt],
457-
)

openadapt/entrypoint.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
multiprocessing.freeze_support()
88

99
from openadapt.build_utils import redirect_stdout_stderr
10+
from openadapt.error_reporting import configure_error_reporting
1011
from openadapt.custom_logger import logger
1112

1213

@@ -19,6 +20,7 @@ def run_openadapt() -> None:
1920
from openadapt.config import print_config
2021

2122
print_config()
23+
configure_error_reporting()
2224
load_alembic_context()
2325
tray._run()
2426
except Exception as exc:

openadapt/error_reporting.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
"""Module for error reporting logic."""
2+
3+
from typing import Any
4+
5+
from loguru import logger
6+
from PySide6.QtGui import QIcon
7+
from PySide6.QtWidgets import QMessageBox, QPushButton
8+
import git
9+
import sentry_sdk
10+
import webbrowser
11+
12+
from openadapt.build_utils import is_running_from_executable
13+
from openadapt.config import PARENT_DIR_PATH, config
14+
15+
16+
def configure_error_reporting() -> None:
17+
"""Configure error reporting."""
18+
logger.info(f"{config.ERROR_REPORTING_ENABLED=}")
19+
if not config.ERROR_REPORTING_ENABLED:
20+
return
21+
22+
if is_running_from_executable():
23+
is_reporting_branch = True
24+
else:
25+
active_branch_name = git.Repo(PARENT_DIR_PATH).active_branch.name
26+
logger.info(f"{active_branch_name=}")
27+
is_reporting_branch = active_branch_name == config.ERROR_REPORTING_BRANCH
28+
logger.info(f"{is_reporting_branch=}")
29+
30+
if is_reporting_branch:
31+
sentry_sdk.init(
32+
dsn=config.ERROR_REPORTING_DSN,
33+
traces_sample_rate=1.0,
34+
before_send=before_send_event,
35+
ignore_errors=[KeyboardInterrupt],
36+
)
37+
38+
39+
def show_alert() -> None:
40+
"""Show an alert to the user."""
41+
# TODO: move to config
42+
from openadapt.app.tray import ICON_PATH
43+
44+
msg = QMessageBox()
45+
msg.setIcon(QMessageBox.Warning)
46+
msg.setWindowIcon(QIcon(ICON_PATH))
47+
msg.setText("""
48+
An error has occurred. The development team has been notified.
49+
Please join the discord server to get help or send an email to
50+
51+
""")
52+
discord_button = QPushButton("Join the discord server")
53+
discord_button.clicked.connect(
54+
lambda: webbrowser.open("https://discord.gg/yF527cQbDG")
55+
)
56+
msg.addButton(discord_button, QMessageBox.ActionRole)
57+
msg.addButton(QMessageBox.Ok)
58+
msg.exec()
59+
60+
61+
def before_send_event(event: Any, hint: Any) -> Any:
62+
"""Handle the event before sending it to Sentry."""
63+
try:
64+
show_alert()
65+
except Exception:
66+
pass
67+
return event

openadapt/models.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -556,7 +556,14 @@ def to_prompt_dict(self, include_data: bool = True) -> dict[str, Any]:
556556
if "state" in window_dict:
557557
if include_data:
558558
key_suffixes = [
559-
"value", "h", "w", "x", "y", "description", "title", "help",
559+
"value",
560+
"h",
561+
"w",
562+
"x",
563+
"y",
564+
"description",
565+
"title",
566+
"help",
560567
]
561568
if sys.platform == "win32":
562569
logger.warning(

openadapt/prompts/apply_replay_instructions.j2

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,4 +16,13 @@ Do NOT provide available_segment_descriptions in your response.
1616

1717
Respond with json and nothing else.
1818

19+
{% if exceptions.length %}
20+
Your previous attempts at this produced the following exceptions:
21+
{% for exception in exceptions %}
22+
<exception>
23+
{{ exception }}
24+
</exception>
25+
{% endfor %}
26+
{% endif %}
27+
1928
My career depends on this. Lives are at stake.
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
Consider the actions that you previously generated:
2+
3+
```json
4+
{{ actions }}
5+
```
6+
7+
The attached image is a screenshot of the current state of the system, immediately
8+
after the last action in the sequence was played.
9+
10+
Your task is to:
11+
1. Describe what you would expect to see in the screenshot after the last action in the
12+
sequence is complete, and
13+
2. Determine whether the the last action has completed by looking at the attached
14+
screenshot. For example, if you expect that the sequence of actions would result in
15+
opening a particular application, you should determine whether that application has
16+
finished opening.
17+
18+
Respond with JSON and nothing else. The JSON should have the following keys:
19+
- "expected_state": Natural language description of what you would expect to see.
20+
- "is_complete": Boolean indicating whether the last action is complete or not.
21+
22+
My career depends on this. Lives are at stake.

openadapt/replay.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from openadapt import utils
2020
from openadapt.config import CAPTURE_DIR_PATH, print_config
2121
from openadapt.db import crud
22+
from openadapt.error_reporting import configure_error_reporting
2223
from openadapt.models import Recording
2324

2425
LOG_LEVEL = "INFO"
@@ -50,6 +51,7 @@ def replay(
5051
"""
5152
utils.configure_logging(logger, LOG_LEVEL)
5253
print_config()
54+
configure_error_reporting()
5355
posthog.capture(event="replay.started", properties={"strategy_name": strategy_name})
5456

5557
if status_pipe:

openadapt/strategies/base.py

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,10 @@
77
from oa_pynput import keyboard, mouse
88
import numpy as np
99

10-
from openadapt import models, playback, utils
10+
from openadapt import adapters, models, playback, utils
1111
from openadapt.custom_logger import logger
1212

13+
CHECK_ACTION_COMPLETE = True
1314
MAX_FRAME_TIMES = 1000
1415

1516

@@ -55,6 +56,16 @@ def run(self) -> None:
5556
mouse_controller = mouse.Controller()
5657
while True:
5758
screenshot = models.Screenshot.take_screenshot()
59+
60+
# check if previous action is complete
61+
if CHECK_ACTION_COMPLETE:
62+
is_action_complete = prompt_is_action_complete(
63+
screenshot,
64+
self.action_events,
65+
)
66+
if not is_action_complete:
67+
continue
68+
5869
self.screenshots.append(screenshot)
5970
window_event = models.WindowEvent.get_active_window_event()
6071
self.window_events.append(window_event)
@@ -108,3 +119,42 @@ def log_fps(self) -> None:
108119
logger.info(f"{fps=:.2f}")
109120
if len(self.frame_times) > self.max_frame_times:
110121
self.frame_times.pop(0)
122+
123+
124+
def prompt_is_action_complete(
125+
current_screenshot: models.Screenshot,
126+
played_actions: list[models.ActionEvent],
127+
) -> bool:
128+
"""Determine whether the the last action is complete.
129+
130+
Args:
131+
current_screenshot (models.Screenshot): The current Screenshot.
132+
played_actions (list[models.ActionEvent]: The list of previously played
133+
ActionEvents.
134+
135+
Returns:
136+
(bool) whether or not the last played action has completed.
137+
"""
138+
if not played_actions:
139+
return True
140+
system_prompt = utils.render_template_from_file(
141+
"prompts/system.j2",
142+
)
143+
actions_dict = {
144+
"actions": [action.to_prompt_dict() for action in played_actions],
145+
}
146+
prompt = utils.render_template_from_file(
147+
"prompts/is_action_complete.j2",
148+
actions=actions_dict,
149+
)
150+
prompt_adapter = adapters.get_default_prompt_adapter()
151+
content = prompt_adapter.prompt(
152+
prompt,
153+
system_prompt=system_prompt,
154+
images=[current_screenshot.image],
155+
)
156+
content_dict = utils.parse_code_snippet(content)
157+
expected_state = content_dict["expected_state"]
158+
is_complete = content_dict["is_complete"]
159+
logger.info(f"{expected_state=} {is_complete=}")
160+
return is_complete

openadapt/strategies/visual.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,16 +111,19 @@ def add_active_segment_descriptions(action_events: list[models.ActionEvent]) ->
111111
action.available_segment_descriptions = window_segmentation.descriptions
112112

113113

114+
@utils.retry_with_exceptions()
114115
def apply_replay_instructions(
115116
action_events: list[models.ActionEvent],
116117
replay_instructions: str,
117-
# retain_window_events: bool = False,
118+
exceptions: list[Exception],
118119
) -> None:
119120
"""Modify the given ActionEvents according to the given replay instructions.
120121
121122
Args:
122123
action_events: list of action events to be modified in place.
123124
replay_instructions: instructions for how action events should be modified.
125+
exceptions: list of exceptions that were produced attempting to run this
126+
function.
124127
"""
125128
action_dicts = [action.to_prompt_dict() for action in action_events]
126129
actions_dict = {"actions": action_dicts}
@@ -131,6 +134,7 @@ def apply_replay_instructions(
131134
"prompts/apply_replay_instructions.j2",
132135
actions=actions_dict,
133136
replay_instructions=replay_instructions,
137+
exceptions=exceptions,
134138
)
135139
prompt_adapter = adapters.get_default_prompt_adapter()
136140
content = prompt_adapter.prompt(

0 commit comments

Comments
 (0)