116 changes: 116 additions & 0 deletions experiments/describe_action.py
@@ -0,0 +1,116 @@
"""Generate action descriptions."""

from pprint import pformat

from loguru import logger
import cv2
import numpy as np

from openadapt.db import crud


def embed_description(
image: np.ndarray,
description: str,
    x: int | None = None,
    y: int | None = None,
) -> np.ndarray:
"""Embed a description into an image at the specified location.

Args:
image (np.ndarray): The image to annotate.
description (str): The text to embed.
x (int, optional): The x-coordinate. Defaults to None (centered).
y (int, optional): The y-coordinate. Defaults to None (centered).

Returns:
np.ndarray: The annotated image.
"""
font = cv2.FONT_HERSHEY_SIMPLEX
font_scale = 1
font_color = (255, 255, 255) # White
line_type = 1

# Split description into multiple lines
max_width = 60 # Maximum characters per line
words = description.split()
lines = []
current_line = []
for word in words:
if len(" ".join(current_line + [word])) <= max_width:
current_line.append(word)
else:
lines.append(" ".join(current_line))
current_line = [word]
if current_line:
lines.append(" ".join(current_line))

# Default to center if coordinates are not provided
if x is None or y is None:
x = image.shape[1] // 2
y = image.shape[0] // 2

    # Draw a filled background rectangle and the text for each line
for i, line in enumerate(lines):
text_size, _ = cv2.getTextSize(line, font, font_scale, line_type)
text_x = max(0, min(x - text_size[0] // 2, image.shape[1] - text_size[0]))
text_y = y + i * 20

# Draw background
cv2.rectangle(
image,
(text_x - 15, text_y - 25),
(text_x + text_size[0] + 15, text_y + 15),
(0, 0, 0),
-1,
)

# Draw text
cv2.putText(
image,
line,
(text_x, text_y),
font,
font_scale,
font_color,
line_type,
)

return image


def main() -> None:
"""Main function."""
with crud.get_new_session(read_only=True) as session:
recording = crud.get_latest_recording(session)
action_events = recording.processed_action_events
descriptions = []
for action in action_events:
description, image = action.prompt_for_description(return_image=True)

# Convert image to numpy array for OpenCV compatibility
image = np.array(image)

if action.mouse_x is not None and action.mouse_y is not None:
                # Use the mouse coordinates for mouse events
                # (the 2x factor is assumed to compensate for high-DPI/retina scaling)
annotated_image = embed_description(
image,
description,
x=int(action.mouse_x) * 2,
y=int(action.mouse_y) * 2,
)
else:
# Center the text for other events
annotated_image = embed_description(image, description)

logger.info(f"{action=}")
logger.info(f"{description=}")
cv2.imshow("Annotated Image", annotated_image)
cv2.waitKey(0)
descriptions.append(description)

logger.info(f"descriptions=\n{pformat(descriptions)}")


if __name__ == "__main__":
main()
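For quick local testing, embed_description can be exercised without a recording. A minimal sketch on a synthetic canvas (the canvas size and sample text are arbitrary, and the import assumes the repository root is on sys.path); note the greedy wrapping loop above behaves roughly like textwrap.wrap(description, max_width):

import cv2
import numpy as np

from experiments.describe_action import embed_description

# Blank black canvas standing in for a recording screenshot
canvas = np.zeros((480, 640, 3), dtype=np.uint8)
annotated = embed_description(
    canvas,
    "Click the 'OK' button to dismiss the dialog and return to the editor",
    x=320,
    y=240,
)
cv2.imshow("embed_description demo", annotated)
cv2.waitKey(0)
cv2.destroyAllWindows()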
2 changes: 1 addition & 1 deletion openadapt/config.defaults.json
@@ -17,7 +17,7 @@
"RECORD_READ_ACTIVE_ELEMENT_STATE": false,
"REPLAY_STRIP_ELEMENT_STATE": true,
"RECORD_VIDEO": true,
"RECORD_AUDIO": true,
"RECORD_AUDIO": false,
"RECORD_BROWSER_EVENTS": false,
"RECORD_FULL_VIDEO": false,
"RECORD_IMAGES": false,
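These defaults surface as attributes on the global config object (the same attribute-access pattern as config.ACTION_TEXT_SEP in openadapt/models.py below), so the new default can be checked directly — a sketch, assuming no local config override:

from openadapt.config import config

# Audio capture is now off unless overridden in the local config
print(config.RECORD_AUDIO)  # False by default after this change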
20 changes: 11 additions & 9 deletions openadapt/db/crud.py
@@ -337,16 +337,18 @@ def get_all_scrubbed_recordings(


def get_latest_recording(session: SaSession) -> Recording:
"""Get the latest recording.

Args:
session (sa.orm.Session): The database session.

Returns:
Recording: The latest recording object.
"""
"""Get the latest recording with preloaded relationships."""
return (
session.query(Recording).order_by(sa.desc(Recording.timestamp)).limit(1).first()
session.query(Recording)
.options(
sa.orm.joinedload(Recording.screenshots),
sa.orm.joinedload(Recording.action_events)
.joinedload(ActionEvent.screenshot)
.joinedload(Screenshot.recording),
sa.orm.joinedload(Recording.window_events),
)
.order_by(sa.desc(Recording.timestamp))
.first()
)


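The joinedload options matter because callers (e.g. experiments/describe_action.py above) access relationships after their read-only session has closed; with eager loading the rows are already in memory instead of triggering a lazy SELECT on a detached instance. A sketch of the pattern this enables:

from openadapt.db import crud

with crud.get_new_session(read_only=True) as session:
    recording = crud.get_latest_recording(session)

# Safe after the session closes: screenshots, action_events (with their
# screenshots), and window_events were preloaded by the joinedload options.
print(len(recording.screenshots))
print(len(recording.action_events))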
10 changes: 6 additions & 4 deletions openadapt/drivers/anthropic.py
@@ -5,14 +5,14 @@
from PIL import Image
import anthropic

-from openadapt import cache, utils
+from openadapt import cache
from openadapt.config import config
from openadapt.custom_logger import logger

MAX_TOKENS = 4096
# from https://docs.anthropic.com/claude/docs/vision
MAX_IMAGES = 20
MODEL_NAME = "claude-3-opus-20240229"
MODEL_NAME = "claude-3-5-sonnet-20241022"


@cache.cache()
@@ -24,6 +24,8 @@ def create_payload(
max_tokens: int | None = None,
) -> dict:
"""Creates the payload for the Anthropic API request with image support."""
+    from openadapt import utils
+
messages = []

user_message_content = []
@@ -36,7 +38,7 @@
# Add base64 encoded images to the user message content
if images:
for image in images:
-            image_base64 = utils.image2utf8(image)
+            image_base64 = utils.image2utf8(image, "PNG")
# Extract media type and base64 data
# TODO: don't add it to begin with
media_type, image_base64_data = image_base64.split(";base64,", 1)
@@ -90,7 +92,7 @@
"""Sends a request to the Anthropic API and returns the response."""
client = anthropic.Anthropic(api_key=api_key)
try:
-        response = client.messages.create(**payload)
+        response = client.beta.messages.create(**payload)
except Exception as exc:
logger.exception(exc)
if dev_mode:
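The split on ";base64," in create_payload assumes image2utf8 returns a data URI. A sketch of that assumed format (image_to_data_uri below is a hypothetical stand-in, not OpenAdapt's actual utils.image2utf8):

import base64
import io

from PIL import Image


def image_to_data_uri(image: Image.Image, image_format: str = "PNG") -> str:
    """Hypothetical stand-in for utils.image2utf8(image, "PNG")."""
    buffer = io.BytesIO()
    image.save(buffer, format=image_format)
    encoded = base64.b64encode(buffer.getvalue()).decode()
    return f"data:image/{image_format.lower()};base64,{encoded}"


uri = image_to_data_uri(Image.new("RGB", (4, 4)))
media_type, image_base64_data = uri.split(";base64,", 1)  # as in create_payload
media_type = media_type.removeprefix("data:")  # "image/png"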
80 changes: 80 additions & 0 deletions openadapt/models.py
@@ -7,6 +7,7 @@
import copy
import io
import sys
import textwrap

from bs4 import BeautifulSoup
from pynput import keyboard
@@ -16,6 +17,7 @@

from openadapt.config import config
from openadapt.custom_logger import logger
from openadapt.drivers import anthropic
from openadapt.db import db
from openadapt.privacy.base import ScrubbingProvider, TextScrubbingMixin
from openadapt.privacy.providers import ScrubProvider
@@ -110,6 +112,9 @@ def processed_action_events(self) -> list:
if not self._processed_action_events:
session = crud.get_new_session(read_only=True)
self._processed_action_events = events.get_events(session, self)
# Preload screenshots to avoid lazy loading later
for event in self._processed_action_events:
event.screenshot
return self._processed_action_events

def scrub(self, scrubber: ScrubbingProvider) -> None:
@@ -125,6 +130,7 @@ class ActionEvent(db.Base):
"""Class representing an action event in the database."""

__tablename__ = "action_event"
_repr_ignore_attrs = ["reducer_names"]

_segment_description_separator = ";"

@@ -333,6 +339,11 @@ def canonical_text(self, value: str) -> None:
if not value == self.canonical_text:
logger.warning(f"{value=} did not match {self.canonical_text=}")

@property
def raw_text(self) -> str:
"""Return a string containing the raw action text (without separators)."""
return "".join(self.text.split(config.ACTION_TEXT_SEP))

def __str__(self) -> str:
"""Return a string representation of the action event."""
attr_names = [
@@ -544,6 +555,75 @@ def next_event(self) -> Union["ActionEvent", None]:

return None

    def prompt_for_description(self, return_image: bool = False) -> str | tuple:
        """Use the Anthropic API to describe what is happening in the action event.

        Args:
            return_image (bool): Whether to return the image sent to the model.

        Returns:
            str | tuple: The description, or a (description, image) tuple if
                return_image is True.
        """
from openadapt.plotting import display_event

image = display_event(
self,
marker_width_pct=0.05,
marker_height_pct=0.05,
darken_outside=0.7,
display_text=False,
marker_fill_transparency=0,
)

if self.text:
description = f"Type '{self.raw_text}'"
else:
prompt = (
"What user interface element is contained in the highlighted circle "
"of the image?"
)
# TODO: disambiguate
system_prompt = textwrap.dedent(
"""
Briefly describe the user interface element in the screenshot at the
highlighted location.
For example:
- "OK button"
- "URL bar"
- "Down arrow"
DO NOT DESCRIBE ANYTHING OUTSIDE THE HIGHLIGHTED AREA.
Do not append anything like "is contained within the highlighted circle
in the calculator interface." Just name the user interface element.
"""
)

logger.info(f"system_prompt=\n{system_prompt}")
logger.info(f"prompt=\n{prompt}")

# Call the Anthropic API
element = anthropic.prompt(
prompt=prompt,
system_prompt=system_prompt,
images=[image],
)

if self.name == "move":
description = f"Move mouse to '{element}'"
elif self.name == "scroll":
# TODO: "scroll to", dx/dy
description = f"Scroll mouse on '{element}'"
elif "click" in self.name:
description = (
f"{self.mouse_button_name.capitalize()} {self.name} '{element}'"
)
else:
raise ValueError(f"Unhandled {self.name=} {self}")

if return_image:
return description, image
else:
return description


class WindowEvent(db.Base):
"""Class representing a window event in the database."""
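End to end, the new method is consumed as in experiments/describe_action.py above; a condensed sketch (the printed output is illustrative, not a captured result):

from openadapt.db import crud

with crud.get_new_session(read_only=True) as session:
    recording = crud.get_latest_recording(session)
    for action in recording.processed_action_events:
        # Typing events are described from their text; mouse events prompt
        # the Anthropic API to name the highlighted UI element.
        print(action.prompt_for_description())  # e.g. "Left singleclick 'OK button'"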