Skip to content

Commit 037826c

Browse files
committed
add ActionEvent.prompt_for_description; fix plotting; add experiments/describe_action.py; RECORD_AUDIO false; get_latest_recording joinedload; anthropic.py MODEL_NAME claude-3-5-sonnet-20241022; image2utf8 PNG; python<3.12
1 parent e595dd3 commit 037826c

File tree

9 files changed

+279
-51
lines changed

9 files changed

+279
-51
lines changed

experiments/describe_action.py

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
from pprint import pformat
2+
3+
from loguru import logger
4+
import cv2
5+
import numpy as np
6+
7+
from openadapt.db import crud
8+
9+
10+
def embed_description(image, description, x=None, y=None):
    """Embed a description into an image at the specified location.

    The description is greedily word-wrapped, and every wrapped line is drawn
    in white on top of an opaque black box. Each line is horizontally centered
    on ``x`` and clamped so it stays within the image width. Lines are stacked
    downwards starting at ``y``. The image is annotated in place.

    Args:
        image (np.ndarray): The image to annotate (modified in place).
        description (str): The text to embed.
        x (int, optional): The x-coordinate. Defaults to None (centered).
        y (int, optional): The y-coordinate. Defaults to None (centered).

    Returns:
        np.ndarray: The annotated image (the same object that was passed in).
    """
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 1
    font_color = (255, 255, 255)  # White
    background_color = (0, 0, 0)  # Black
    # This value is consumed as the stroke *thickness* by both
    # cv2.getTextSize and cv2.putText (it is not the OpenCV lineType).
    thickness = 1

    # Geometry of the background box drawn behind each line, relative to the
    # text origin (left end of the baseline).
    pad_x = 15
    pad_above = 25
    pad_below = 15
    # Advance by the full box height so that the box of one line never paints
    # over the text of the line above it.
    line_height = pad_above + pad_below

    # Split description into multiple lines
    max_width = 60  # Maximum characters per line
    lines = []
    current_line = []
    for word in description.split():
        candidate = current_line + [word]
        # A word longer than max_width gets a line of its own; never emit an
        # empty line in front of it.
        if not current_line or len(" ".join(candidate)) <= max_width:
            current_line = candidate
        else:
            lines.append(" ".join(current_line))
            current_line = [word]
    if current_line:
        lines.append(" ".join(current_line))

    # Default each missing coordinate to the image center independently, so
    # that an explicitly provided coordinate is never discarded.
    if x is None:
        x = image.shape[1] // 2
    if y is None:
        y = image.shape[0] // 2

    # Draw opaque background and text, one wrapped line at a time
    for i, line in enumerate(lines):
        text_size, _ = cv2.getTextSize(line, font, font_scale, thickness)
        text_width = text_size[0]
        # Center on x, but keep the text inside the horizontal image bounds.
        text_x = max(0, min(x - text_width // 2, image.shape[1] - text_width))
        text_y = y + i * line_height

        # Draw background
        cv2.rectangle(
            image,
            (text_x - pad_x, text_y - pad_above),
            (text_x + text_width + pad_x, text_y + pad_below),
            background_color,
            -1,  # negative thickness -> filled rectangle
        )

        # Draw text
        cv2.putText(
            image,
            line,
            (text_x, text_y),
            font,
            font_scale,
            font_color,
            thickness,
        )

    return image
73+
74+
75+
def main() -> None:
    """Describe every processed action of the latest recording and preview it.

    For each action event, a description is requested via
    ``ActionEvent.prompt_for_description`` and drawn onto the image that was
    sent to the model. The annotated image is shown in an OpenCV window until
    a key is pressed. All descriptions are logged at the end.
    """
    with crud.get_new_session(read_only=True) as session:
        recording = crud.get_latest_recording(session)
        if recording is None:
            # The underlying query yields None when no recording exists.
            logger.error("No recording found; nothing to describe.")
            return

        action_events = recording.processed_action_events
        descriptions = []
        try:
            for action in action_events:
                description, image = action.prompt_for_description(
                    return_image=True
                )

                # Convert image to numpy array for OpenCV compatibility
                image = np.array(image)
                # NOTE(review): assumes the returned image is a PIL image, so
                # the array is in RGB(A) channel order; cv2.imshow interprets
                # arrays as BGR, hence the conversion. Confirm against
                # openadapt.plotting.display_event.
                if image.ndim == 3 and image.shape[2] == 4:
                    image = cv2.cvtColor(image, cv2.COLOR_RGBA2BGR)
                elif image.ndim == 3 and image.shape[2] == 3:
                    image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)

                if action.mouse_x is not None and action.mouse_y is not None:
                    # Use the mouse coordinates for mouse events
                    # NOTE(review): the factor of 2 presumably maps recorded
                    # mouse coordinates to image pixels on a 2x (HiDPI)
                    # display — TODO confirm / derive from the recording.
                    annotated_image = embed_description(
                        image,
                        description,
                        x=int(action.mouse_x) * 2,
                        y=int(action.mouse_y) * 2,
                    )
                else:
                    # Center the text for other events
                    annotated_image = embed_description(image, description)

                logger.info(f"{action=}")
                logger.info(f"{description=}")
                cv2.imshow("Annotated Image", annotated_image)
                cv2.waitKey(0)  # block until any key is pressed
                descriptions.append(description)
        finally:
            # Always close the preview window, even if a prompt fails.
            cv2.destroyAllWindows()

        logger.info(f"descriptions=\n{pformat(descriptions)}")
105+
106+
107+
if __name__ == "__main__":
108+
main()

openadapt/config.defaults.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
"RECORD_READ_ACTIVE_ELEMENT_STATE": false,
1818
"REPLAY_STRIP_ELEMENT_STATE": true,
1919
"RECORD_VIDEO": true,
20-
"RECORD_AUDIO": true,
20+
"RECORD_AUDIO": false,
2121
"RECORD_BROWSER_EVENTS": false,
2222
"RECORD_FULL_VIDEO": false,
2323
"RECORD_IMAGES": false,

openadapt/db/crud.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -337,16 +337,18 @@ def get_all_scrubbed_recordings(
337337

338338

339339
def get_latest_recording(session: SaSession) -> Recording:
    """Get the latest recording with preloaded relationships.

    The recording's screenshots, window events and action events are loaded
    eagerly via joined loads, together with each action event's screenshot
    and that screenshot's recording.

    Args:
        session (sa.orm.Session): The database session.

    Returns:
        Recording: The recording with the greatest timestamp (the result of
            ``Query.first()``, which is None when there are no rows).
    """
    eager_loads = (
        sa.orm.joinedload(Recording.screenshots),
        sa.orm.joinedload(Recording.action_events)
        .joinedload(ActionEvent.screenshot)
        .joinedload(Screenshot.recording),
        sa.orm.joinedload(Recording.window_events),
    )
    newest_first = sa.desc(Recording.timestamp)
    query = session.query(Recording).options(*eager_loads).order_by(newest_first)
    return query.first()
351353

352354

openadapt/drivers/anthropic.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,14 @@
55
from PIL import Image
66
import anthropic
77

8-
from openadapt import cache, utils
8+
from openadapt import cache
99
from openadapt.config import config
1010
from openadapt.custom_logger import logger
1111

1212
MAX_TOKENS = 4096
1313
# from https://docs.anthropic.com/claude/docs/vision
1414
MAX_IMAGES = 20
15-
MODEL_NAME = "claude-3-opus-20240229"
15+
MODEL_NAME = "claude-3-5-sonnet-20241022"
1616

1717

1818
@cache.cache()
@@ -24,6 +24,8 @@ def create_payload(
2424
max_tokens: int | None = None,
2525
) -> dict:
2626
"""Creates the payload for the Anthropic API request with image support."""
27+
from openadapt import utils
28+
2729
messages = []
2830

2931
user_message_content = []
@@ -36,7 +38,7 @@ def create_payload(
3638
# Add base64 encoded images to the user message content
3739
if images:
3840
for image in images:
39-
image_base64 = utils.image2utf8(image)
41+
image_base64 = utils.image2utf8(image, "PNG")
4042
# Extract media type and base64 data
4143
# TODO: don't add it to begin with
4244
media_type, image_base64_data = image_base64.split(";base64,", 1)
@@ -90,7 +92,7 @@ def get_completion(
9092
"""Sends a request to the Anthropic API and returns the response."""
9193
client = anthropic.Anthropic(api_key=api_key)
9294
try:
93-
response = client.messages.create(**payload)
95+
response = client.beta.messages.create(**payload)
9496
except Exception as exc:
9597
logger.exception(exc)
9698
if dev_mode:

openadapt/models.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import copy
88
import io
99
import sys
10+
import textwrap
1011

1112
from bs4 import BeautifulSoup
1213
from pynput import keyboard
@@ -16,6 +17,7 @@
1617

1718
from openadapt.config import config
1819
from openadapt.custom_logger import logger
20+
from openadapt.drivers import anthropic
1921
from openadapt.db import db
2022
from openadapt.privacy.base import ScrubbingProvider, TextScrubbingMixin
2123
from openadapt.privacy.providers import ScrubProvider
@@ -110,6 +112,9 @@ def processed_action_events(self) -> list:
110112
if not self._processed_action_events:
111113
session = crud.get_new_session(read_only=True)
112114
self._processed_action_events = events.get_events(session, self)
115+
# Preload screenshots to avoid lazy loading later
116+
for event in self._processed_action_events:
117+
event.screenshot
113118
return self._processed_action_events
114119

115120
def scrub(self, scrubber: ScrubbingProvider) -> None:
@@ -125,6 +130,7 @@ class ActionEvent(db.Base):
125130
"""Class representing an action event in the database."""
126131

127132
__tablename__ = "action_event"
133+
_repr_ignore_attrs = ["reducer_names"]
128134

129135
_segment_description_separator = ";"
130136

@@ -333,6 +339,10 @@ def canonical_text(self, value: str) -> None:
333339
if not value == self.canonical_text:
334340
logger.warning(f"{value=} did not match {self.canonical_text=}")
335341

342+
@property
def raw_text(self) -> str:
    """Return the event text with every action text separator removed."""
    separator = config.ACTION_TEXT_SEP
    fragments = self.text.split(separator)
    return "".join(fragments)
345+
336346
def __str__(self) -> str:
337347
"""Return a string representation of the action event."""
338348
attr_names = [
@@ -544,6 +554,75 @@ def next_event(self) -> Union["ActionEvent", None]:
544554

545555
return None
546556

557+
def prompt_for_description(
    self, return_image: bool = False
) -> Union[str, tuple]:
    """Use the Anthropic API to describe what is happening in the action event.

    Events that carry text are described locally as ``Type '<raw text>'``
    without calling the API. For all other events the model is asked to name
    the user interface element at the highlighted location, and that name is
    combined with the event name (move / scroll / click).

    Args:
        return_image (bool): Whether to return the image sent to the model.

    Returns:
        str: The description of the action event, if ``return_image`` is
            False.
        tuple: ``(description, image)`` if ``return_image`` is True, where
            ``image`` is the rendering produced by ``display_event``.

    Raises:
        ValueError: If the event has no text and its name is not "move",
            "scroll", or a name containing "click".
    """
    from openadapt.plotting import display_event

    image = display_event(
        self,
        marker_width_pct=0.05,
        marker_height_pct=0.05,
        darken_outside=0.7,
        display_text=False,
        marker_fill_transparency=0,
    )

    if self.text:
        description = f"Type '{self.raw_text}'"
    else:
        # Validate the event name *before* calling the API, so unsupported
        # events fail fast with the intended ValueError (also when the name
        # is None) instead of after a paid request.
        name = self.name or ""
        is_click = "click" in name
        if not (is_click or name in ("move", "scroll")):
            raise ValueError(f"Unhandled {self.name=} {self}")

        prompt = (
            "What user interface element is contained in the highlighted circle "
            "of the image?"
        )
        # TODO: disambiguate
        system_prompt = textwrap.dedent(
            """
            Briefly describe the user interface element in the screenshot at the
            highlighted location.
            For example:
            - "OK button"
            - "URL bar"
            - "Down arrow"
            DO NOT DESCRIBE ANYTHING OUTSIDE THE HIGHLIGHTED AREA.
            Do not append anything like "is contained within the highlighted circle
            in the calculator interface." Just name the user interface element.
            """
        )

        logger.info(f"system_prompt=\n{system_prompt}")
        logger.info(f"prompt=\n{prompt}")

        # Call the Anthropic API
        element = anthropic.prompt(
            prompt=prompt,
            system_prompt=system_prompt,
            images=[image],
        )

        if name == "move":
            description = f"Move mouse to '{element}'"
        elif name == "scroll":
            # TODO: "scroll to", dx/dy
            description = f"Scroll mouse on '{element}'"
        else:
            # Click-like event, e.g. "Left click 'OK button'". Tolerate a
            # missing button name rather than crashing on None.capitalize().
            button = (self.mouse_button_name or "").capitalize()
            description = " ".join(
                part for part in (button, name, f"'{element}'") if part
            )

    if return_image:
        return description, image
    return description
625+
547626

548627
class WindowEvent(db.Base):
549628
"""Class representing a window event in the database."""

0 commit comments

Comments
 (0)