Skip to content

Commit bc43c05

Browse files
committed
wip
1 parent e595dd3 commit bc43c05

File tree

4 files changed

+81
-2
lines changed

4 files changed

+81
-2
lines changed

experiments/describe_action.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
from pprint import pformat
2+
3+
from loguru import logger
4+
5+
from openadapt.db import crud
6+
7+
8+
def main() -> None:
9+
session = crud.get_new_session(read_only=True)
10+
recording = crud.get_latest_recording(session)
11+
action_events = recording.processed_action_events
12+
descriptions = []
13+
for action in action_events:
14+
logger.info(f"{action=}")
15+
description = action.prompt_for_description()
16+
logger.info(f"{description=}")
17+
descriptions.append(description)
18+
logger.info(f"descriptions=\n{pformat(descriptions)}")
19+
20+
21+
if __name__ == "__main__":
22+
main()

openadapt/drivers/anthropic.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from PIL import Image
66
import anthropic
77

8-
from openadapt import cache, utils
8+
from openadapt import cache
99
from openadapt.config import config
1010
from openadapt.custom_logger import logger
1111

@@ -148,3 +148,7 @@ def prompt(
148148
result = get_completion(payload)
149149
pprint(f"result=\n{result}") # Log result for debugging
150150
return result
151+
152+
153+
# avoid circular import
154+
from openadapt import utils # noqa

openadapt/models.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
from openadapt.config import config
1818
from openadapt.custom_logger import logger
19+
from openadapt.drivers import anthropic
1920
from openadapt.db import db
2021
from openadapt.privacy.base import ScrubbingProvider, TextScrubbingMixin
2122
from openadapt.privacy.providers import ScrubProvider
@@ -544,6 +545,56 @@ def next_event(self) -> Union["ActionEvent", None]:
544545

545546
return None
546547

548+
def prompt_for_description(self) -> str:
549+
"""Use the Anthropic API to describe what is happening in the action event.
550+
551+
Returns:
552+
str: The description of the action event.
553+
"""
554+
# Collect the relevant information for the prompt
555+
prompt_details = {
556+
"name": self.name,
557+
"mouse_x": self.mouse_x,
558+
"mouse_y": self.mouse_y,
559+
"mouse_dx": self.mouse_dx,
560+
"mouse_dy": self.mouse_dy,
561+
"text": self.text,
562+
}
563+
564+
# Convert the details to a readable format for the prompt
565+
action_text = "\n".join(
566+
[
567+
f"{key}: {value}"
568+
for key, value in prompt_details.items() if value is not None
569+
]
570+
)
571+
prompt_text = f"Action:\n{action_text}"
572+
573+
# Add contextual instructions for the prompt
574+
system_prompt = (
575+
"You are an assistant tasked with describing user interactions in a graphical interface. "
576+
"Based on the following Action details and the provided screenshot, describe the action taking place "
577+
"as clearly as possible. Do not describe the screenshot in its entirety, only what is relevant for describing the action."
578+
)
579+
580+
logger.info(f"system_prompt=\n{system_prompt}")
581+
logger.info(f"prompt_text=\n{prompt_text}")
582+
583+
# Call the Anthropic API
584+
try:
585+
description = anthropic.prompt(
586+
prompt=prompt_text,
587+
system_prompt=system_prompt,
588+
images=[self.screenshot.image],
589+
max_tokens=256, # Adjust token limit as necessary
590+
)
591+
except Exception as e:
592+
logger.exception(f"Error while prompting for description: {e}")
593+
description = "An error occurred while generating the description."
594+
595+
return description
596+
597+
547598

548599
class WindowEvent(db.Base):
549600
"""Class representing a window event in the database."""

openadapt/plotting.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import matplotlib.pyplot as plt
1313
import numpy as np
1414

15-
from openadapt import common, contrib, models, utils
15+
from openadapt import common, models, utils
1616
from openadapt.config import PERFORMANCE_PLOTS_DIR_PATH, config
1717
from openadapt.custom_logger import logger
1818
from openadapt.models import ActionEvent
@@ -792,6 +792,8 @@ def get_marked_image(
792792
Image.Image: The marked image, where marks and/or masks are applied based on
793793
the specified confidence and IoU thresholds and the include flags.
794794
"""
795+
from openadapt import contrib
796+
795797
image_arr = np.asarray(original_image)
796798

797799
# The rest of this function is copied from

0 commit comments

Comments
 (0)