
Commit 81074c4

feat(vanilla): Implement vanilla baseline (#702)

* add vanilla.py; update README.md
* update module docstring
* add TODO; remove unused code
* tabs -> spaces
* modify prompt adapter API to accept images instead of base64 images
* fix typo
* add ActionEvent/WindowEvent.to_prompt_dict
* add prompts/describe_recording.j2; prompts/generate_action_event.j2
* fixes; log action history
* fixes; black; flake8
* add tests/openadapt/adapters
* add missing test assets
* DATA_DIR_PATH -> PARENT_DIR_PATH
* get_completion dev_mode
* anthropic dev_mode
* fix tests
* black
* flake8
* ignore .cache in flake8
* vanilla.INCLUDE_WINDOW_DATA; utils.clean_data/filter_keys
* flake8

1 parent 53f40fb commit 81074c4

File tree

23 files changed: +643 −158 lines

.github/workflows/main.yml

Lines changed: 1 addition & 1 deletion

@@ -63,4 +63,4 @@ jobs:
       run: poetry run black --preview --check . --exclude '/(alembic|\.venv)/'

     - name: Run Flake8
-      run: poetry run flake8 --exclude=alembic,.venv
+      run: poetry run flake8 --exclude=alembic,.venv,*/.cache

README.md

Lines changed: 1 addition & 0 deletions

@@ -191,6 +191,7 @@ python -m openadapt.replay NaiveReplayStrategy
 Other replay strategies include:

 - [`StatefulReplayStrategy`](https://github.com/OpenAdaptAI/OpenAdapt/blob/main/openadapt/strategies/stateful.py): Proof-of-concept which uses the OpenAI GPT-4 API with prompts constructed via OS-level window data.
+- [`VanillaReplayStrategy`](https://github.com/OpenAdaptAI/OpenAdapt/blob/main/openadapt/strategies/vanilla.py): If AGI or GPT6 happens, this script should be able to suddenly do the work. --LunjunZhang
 - [`VisualReplayStrategy`](https://github.com/OpenAdaptAI/OpenAdapt/blob/main/openadapt/strategies/visual.py): Uses [Fast Segment Anything Model (FastSAM)](https://github.com/CASIA-IVA-Lab/FastSAM) to segment active window. Accepts an "instructions" parameter that is used to modify the recording, e.g.:

 ```
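Judging from the `python -m openadapt.replay NaiveReplayStrategy` invocation shown in the hunk header above, the new strategy would presumably be run the same way; this exact command is an assumption, since the diff does not show it:

```
python -m openadapt.replay VanillaReplayStrategy
```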

openadapt/adapters/anthropic.py

Lines changed: 22 additions & 18 deletions

@@ -3,9 +3,10 @@
 from pprint import pprint

 from loguru import logger
+from PIL import Image
 import anthropic

-from openadapt import cache
+from openadapt import cache, utils
 from openadapt.config import config

 MAX_TOKENS = 4096

@@ -18,7 +19,7 @@
 def create_payload(
     prompt: str,
     system_prompt: str | None = None,
-    base64_images: list[tuple[str, str]] | None = None,
+    images: list[Image.Image] | None = None,
     model: str = MODEL_NAME,
     max_tokens: int | None = None,
 ) -> dict:

@@ -33,10 +34,12 @@ def create_payload(
         max_tokens = MAX_TOKENS

     # Add base64 encoded images to the user message content
-    if base64_images:
-        for image_data in base64_images:
+    if images:
+        for image in images:
+            image_base64 = utils.image2utf8(image)
             # Extract media type and base64 data
-            media_type, base64_str = image_data.split(";base64,", 1)
+            # TODO: don't add it to begin with
+            media_type, image_base64_data = image_base64.split(";base64,", 1)
             media_type = media_type.split(":")[-1]  # Remove 'data:' prefix

             user_message_content.append(

@@ -45,7 +48,7 @@ def create_payload(
                     "source": {
                         "type": "base64",
                         "media_type": media_type,
-                        "data": base64_str,
+                        "data": image_base64_data,
                     },
                 }
             )

@@ -80,19 +83,22 @@ def create_payload(
     return payload


-client = anthropic.Anthropic(api_key=config.ANTHROPIC_API_KEY)
-
-
 @cache.cache()
-def get_completion(payload: dict) -> str:
+def get_completion(
+    payload: dict, dev_mode: bool = False, api_key: str = config.ANTHROPIC_API_KEY
+) -> str:
     """Sends a request to the Anthropic API and returns the response."""
+    client = anthropic.Anthropic(api_key=api_key)
     try:
         response = client.messages.create(**payload)
     except Exception as exc:
         logger.exception(exc)
-        import ipdb
+        if dev_mode:
+            import ipdb

-        ipdb.set_trace()
+            ipdb.set_trace()
+        else:
+            raise
     """
     Message(
         id='msg_01L55ai2A9q92687mmjMSch3',

@@ -125,19 +131,17 @@ def get_completion(payload: dict) -> str:
 def prompt(
     prompt: str,
     system_prompt: str | None = None,
-    base64_images: list[str] | None = None,
+    images: list[Image.Image] | None = None,
     max_tokens: int | None = None,
 ) -> str:
     """Public method to get a response from the Anthropic API with image support."""
-    if len(base64_images) > MAX_IMAGES:
+    if len(images) > MAX_IMAGES:
         # XXX TODO handle this
-        raise Exception(
-            f"{len(base64_images)=} > {MAX_IMAGES=}. Use a different adapter."
-        )
+        raise Exception(f"{len(images)=} > {MAX_IMAGES=}. Use a different adapter.")
     payload = create_payload(
         prompt,
         system_prompt,
-        base64_images,
+        images,
         max_tokens=max_tokens,
     )
     # pprint(f"payload=\n{payload}")  # Log payload for debugging
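To illustrate the API change this diff makes (PIL images in, base64 handled internally by the adapter), here is a minimal caller sketch; the screenshot path and prompt text are hypothetical, and the new `dev_mode`/`api_key` parameters are left at their defaults:

```python
from PIL import Image

from openadapt.adapters import anthropic

# The adapter now accepts PIL images and base64-encodes them itself
# via utils.image2utf8, so callers no longer pre-encode anything.
screenshot = Image.open("screenshot.png")  # hypothetical path
text = anthropic.prompt(
    "Describe the action shown in this screenshot.",  # hypothetical prompt
    system_prompt="You are a GUI automation assistant.",
    images=[screenshot],
)
print(text)
```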

openadapt/adapters/google.py

Lines changed: 19 additions & 23 deletions

@@ -3,13 +3,14 @@
 See https://ai.google.dev/tutorials/python_quickstart for documentation.
 """

-from pprint import pprint
+from pprint import pformat

+from loguru import logger
 from PIL import Image
 import fire
 import google.generativeai as genai

-from openadapt import cache, utils
+from openadapt import cache
 from openadapt.config import config

 MAX_TOKENS = 2**20  # 1048576

@@ -28,46 +29,41 @@
 def prompt(
     prompt: str,
     system_prompt: str | None = None,
-    base64_images: list[str] | None = None,
+    images: list[Image.Image] | None = None,
     # max_tokens: int | None = None,
     model_name: str = MODEL_NAME,
+    timeout: int = 10,
 ) -> str:
     """Public method to get a response from the Google API with image support."""
     full_prompt = "\n\n###\n\n".join([s for s in (system_prompt, prompt) if s])
     # HACK
     full_prompt += "\nWhen responding in JSON, you MUST use double quotes around keys."

-    # TODO: modify API across all adapters to accept PIL.Image
-    images = (
-        [utils.utf82image(base64_image) for base64_image in base64_images]
-        if base64_images
-        else []
-    )
-
     genai.configure(api_key=config.GOOGLE_API_KEY)
     model = genai.GenerativeModel(model_name)
-    response = model.generate_content([full_prompt] + images)
+    response = model.generate_content(
+        [full_prompt] + images, request_options={"timeout": timeout}
+    )
     response.resolve()
-    pprint(f"response=\n{response}")  # Log response for debugging
+    logger.info(f"response=\n{pformat(response)}")
     return response.text


 def main(text: str, image_path: str | None = None) -> None:
     """Prompt Google Gemini with text and a path to an image."""
     if image_path:
-        with Image.open(image_path) as img:
-            # Convert image to RGB if it's RGBA (to remove alpha channel)
-            if img.mode in ("RGBA", "LA") or (
-                img.mode == "P" and "transparency" in img.info
-            ):
-                img = img.convert("RGB")
-            base64_image = utils.image2utf8(img)
+        image = Image.open(image_path)
+        # Convert image to RGB if it's RGBA (to remove alpha channel)
+        if image.mode in ("RGBA", "LA") or (
+            image.mode == "P" and "transparency" in image.info
+        ):
+            image = image.convert("RGB")
     else:
-        base64_image = None
+        image = None

-    base64_images = [base64_image] if base64_image else None
-    output = prompt(text, base64_images=base64_images)
-    print(output)
+    images = [image] if image else None
+    output = prompt(text, images=images)
+    logger.info(output)


 if __name__ == "__main__":
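The Google adapter gets the same `images:` signature plus a new `timeout` parameter (default 10) that is forwarded to `generate_content` as `request_options={"timeout": ...}`. A minimal sketch of a call, with a hypothetical image path and prompt:

```python
from PIL import Image

from openadapt.adapters import google

image = Image.open("screenshot.png").convert("RGB")  # hypothetical path
text = google.prompt(
    "Summarize the visible window contents as JSON.",  # hypothetical prompt
    images=[image],
    timeout=30,  # forwarded to generate_content via request_options
)
print(text)
```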

openadapt/adapters/openai.py

Lines changed: 41 additions & 15 deletions

@@ -3,18 +3,23 @@
 https://platform.openai.com/docs/guides/vision
 """

+from copy import deepcopy
 from pprint import pformat
+from typing import Any

 from loguru import logger
+from PIL import Image
 import requests

-from openadapt import cache
+from openadapt import cache, utils
 from openadapt.config import config

 MODEL_NAME = [
     "gpt-4-vision-preview",
     "gpt-4-turbo-2024-04-09",
+    "gpt-4o",
 ][-1]
+# TODO XXX: per model
 MAX_TOKENS = 4096
 # TODO XXX undocumented
 MAX_IMAGES = None

@@ -23,7 +28,7 @@
 def create_payload(
     prompt: str,
     system_prompt: str | None = None,
-    base64_images: list[str] | None = None,
+    images: list[Image.Image] | None = None,
     model: str = MODEL_NAME,
     detail: str = "high",  # "low" or "high"
     max_tokens: int | None = None,

@@ -33,7 +38,7 @@ def create_payload(
     Args:
         prompt: the prompt
         system_prompt: the system prompt
-        base64_images: list of base64 encoded images
+        images: list of images
         model: name of OpenAI model
         detail: detail level of images, "low" or "high"
         max_tokens: maximum number of tokens

@@ -59,8 +64,9 @@ def create_payload(
         },
     ]

-    base64_images = base64_images or []
-    for base64_image in base64_images:
+    images = images or []
+    for image in images:
+        base64_image = utils.image2utf8(image)
         messages[0]["content"].append(
             {
                 "type": "image_url",

@@ -94,18 +100,22 @@
 @cache.cache()
-def get_response(payload: dict) -> requests.Response:
+def get_response(
+    payload: dict,
+    api_key: str = config.OPENAI_API_KEY,
+) -> requests.Response:
     """Sends a request to the OpenAI API and returns the response.

     Args:
         payload: dictionary returned by create_payload
+        api_key (str): api key

     Returns:
         response from OpenAI API
     """
     headers = {
         "Content-Type": "application/json",
-        "Authorization": f"Bearer {config.OPENAI_API_KEY}",
+        "Authorization": f"Bearer {api_key}",
     }
     response = requests.post(
         "https://api.openai.com/v1/chat/completions",

@@ -115,14 +125,15 @@ def get_response(payload: dict) -> requests.Response:
     return response


-def get_completion(payload: dict) -> str:
+def get_completion(payload: dict, dev_mode: bool = False) -> str:
     """Sends a request to the OpenAI API and returns the first message.

     Args:
-        pyalod: dictionary returned by create_payload
+        payload (dict): dictionary returned by create_payload
+        dev_mode (bool): whether to launch a debugger on error

     Returns:
-        string containing the first message from the response
+        (str) first message from the response
     """
     response = get_response(payload)
     result = response.json()

@@ -133,22 +144,37 @@ def get_completion(payload: dict) -> str:
     # TODO: fail after maximum number of attempts
     if "retry your request" in message:
         return get_completion(payload)
-    else:
+    elif dev_mode:
         import ipdb

         ipdb.set_trace()
         # TODO: handle more errors
+    else:
+        raise ValueError(result["error"]["message"])
     choices = result["choices"]
     choice = choices[0]
     message = choice["message"]
     content = message["content"]
     return content


+def log_payload(payload: dict[Any, Any]) -> None:
+    """Logs a payload after removing base-64 encoded values recursively."""
+    # TODO: detect base64 encoded strings dynamically
+    # messages["content"][{"image_url": ...
+    # payload["messages"][1]["content"][9]["image_url"]
+    payload_copy = deepcopy(payload)
+    for message in payload_copy["messages"]:
+        for content in message["content"]:
+            if "image_url" in content:
+                content["image_url"]["url"] = "[REDACTED]"
+    logger.info(f"payload=\n{pformat(payload_copy)}")
+
+
 def prompt(
     prompt: str,
     system_prompt: str | None = None,
-    base64_images: list[str] | None = None,
+    images: list[Image.Image] | None = None,
     max_tokens: int | None = None,
     detail: str = "high",
 ) -> str:

@@ -157,7 +183,7 @@ def prompt(
     Args:
         prompt: the prompt
         system_prompt: the system prompt
-        base64_images: list of base64 encoded images
+        images: list of images
         model: name of OpenAI model
         detail: detail level of images, "low" or "high"
         max_tokens: maximum number of tokens

@@ -168,11 +194,11 @@
     payload = create_payload(
         prompt,
         system_prompt,
-        base64_images,
+        images,
         max_tokens=max_tokens,
         detail=detail,
     )
-    logger.info(f"payload=\n{pformat(payload)}")
+    log_payload(payload)
     result = get_completion(payload)
     logger.info(f"result=\n{pformat(result)}")
     return result
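All three adapters now delegate encoding to `utils.image2utf8`, whose implementation is not part of this diff. The anthropic hunk splits its return value on `";base64,"` and strips a `data:` prefix, so the helper presumably returns a data URI; here is a rough sketch under that assumption, not the actual implementation:

```python
import base64
import io

from PIL import Image


def image2utf8(image: Image.Image) -> str:
    """Encode a PIL image as a data URI string (assumed behavior)."""
    buffer = io.BytesIO()
    image.save(buffer, format="PNG")
    encoded = base64.b64encode(buffer.getvalue()).decode("utf-8")
    return f"data:image/png;base64,{encoded}"
```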

openadapt/capture/_macos.py

Lines changed: 2 additions & 0 deletions

@@ -9,6 +9,7 @@
 import os

 from Foundation import NSURL, NSObject  # type: ignore # noqa
+from loguru import logger
 from Quartz import CGMainDisplayID  # type: ignore # noqa
 import AVFoundation as AVF  # type: ignore # noqa
 import objc  # type: ignore # noqa

@@ -56,6 +57,7 @@ def start(self, audio: bool = False, camera: bool = False) -> None:
                 datetime.now().strftime("%Y-%m-%d-%H-%M-%S") + ".mov",
             )
         )
+        logger.info(f"{self.file_url=}")
         if audio and self.session.canAddInput_(self.audio_input[0]):
             self.session.addInput_(self.audio_input[0])
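The added log line uses Python's f-string `=` specifier, which prints both the expression and its repr. An illustrative standalone snippet (the `.mov` path is a made-up value, not from the capture module):

```python
from loguru import logger

file_url = "/Users/me/captures/2024-05-27-12-00-00.mov"  # made-up value
logger.info(f"{file_url=}")
# message emitted: file_url='/Users/me/captures/2024-05-27-12-00-00.mov'
```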

openadapt/config.py

Lines changed: 3 additions & 2 deletions

@@ -25,6 +25,7 @@
 ).absolute()

 ROOT_DIR_PATH = get_root_dir_path()
+PARENT_DIR_PATH = ROOT_DIR_PATH.parent
 DATA_DIR_PATH = (ROOT_DIR_PATH / "data").absolute()
 CONFIG_FILE_PATH = (DATA_DIR_PATH / "config.json").absolute()
 RECORDING_DIR_PATH = (DATA_DIR_PATH / "recordings").absolute()

@@ -136,7 +137,7 @@ class SegmentationAdapter(str, Enum):
     OPENAI_MODEL_NAME: str = "gpt-3.5-turbo"

     # Record and replay
-    RECORD_WINDOW_DATA: bool = False
+    RECORD_WINDOW_DATA: bool = True
     RECORD_READ_ACTIVE_ELEMENT_STATE: bool = False
     RECORD_VIDEO: bool
     RECORD_AUDIO: bool

@@ -407,7 +408,7 @@ def print_config() -> None:
     if is_running_from_executable():
         is_reporting_branch = True
     else:
-        active_branch_name = git.Repo(ROOT_DIR_PATH.parent).active_branch.name
+        active_branch_name = git.Repo(PARENT_DIR_PATH).active_branch.name
         logger.info(f"{active_branch_name=}")
         is_reporting_branch = (
             active_branch_name == config.ERROR_REPORTING_BRANCH
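The new `PARENT_DIR_PATH` constant simply names the existing `ROOT_DIR_PATH.parent` expression, i.e. the directory that contains the package and the git checkout. A small sketch of the relationship, with a hypothetical checkout path:

```python
from pathlib import Path

import git  # GitPython, as already used in config.py

ROOT_DIR_PATH = Path("/home/me/OpenAdapt/openadapt")  # hypothetical checkout
PARENT_DIR_PATH = ROOT_DIR_PATH.parent  # /home/me/OpenAdapt

# print_config() now resolves the branch via the named constant:
branch_name = git.Repo(PARENT_DIR_PATH).active_branch.name
```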
