Skip to content

Commit 7ef115a

Browse files
authored
feat(SegmentReplayStrategy, drivers): add strategies.replay; refactor adapters -> drivers + adapters (#714)
* implemented * add get_active_window_data parameter include_window_data; fix ActionEvent.from_dict to handle multiple separators; add test_models.py * update get_default_prompt_adapter * add TODO * tests.openadapt.adapters -> drivers * utils.get_marked_image, .extract_code_block; .strip_backticks * working segment.py (misses final click in calculator task) * include_replay_instructions; dev_mode=False * fix test_openai.py: ValueError -> Exception * replay.py --record -> --capture * black/flake8 * remove import * INCLUDE_CURRENT_SCREENSHOT; handle mouse events without x/y * add models.Replay; print_config in replay.py
1 parent c674678 commit 7ef115a

File tree

22 files changed

+740
-100
lines changed

22 files changed

+740
-100
lines changed

openadapt/adapters/__init__.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,22 +4,20 @@
44

55
from openadapt.config import config
66

7-
from . import anthropic, google, openai, replicate, som, ultralytics
7+
from . import prompt, replicate, som, ultralytics
88

99

10+
# TODO: remove
def get_default_prompt_adapter() -> ModuleType:
    """Returns the default prompt adapter module.

    Returns:
        The module corresponding to the default prompt adapter.
    """
    # Prompting is now routed through the unified adapters.prompt module
    # (which tries each driver in priority order), replacing the previous
    # config.DEFAULT_ADAPTER-keyed lookup of openai/anthropic/google; this
    # indirection is slated for removal per the TODO above.
    return prompt
2118

2219

20+
# TODO: refactor to follow adapters.prompt
2321
def get_default_segmentation_adapter() -> ModuleType:
2422
"""Returns the default image segmentation adapter module.
2523

openadapt/adapters/prompt.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
"""Adapter for prompting foundation models."""
2+
3+
from loguru import logger
4+
from typing import Type
5+
from PIL import Image
6+
7+
8+
from openadapt.drivers import anthropic, google, openai
9+
10+
11+
# Define a list of drivers in the order they should be tried
12+
DRIVER_ORDER: list[Type] = [openai, google, anthropic]
13+
14+
15+
def prompt(
16+
text: str,
17+
images: list[Image.Image] | None = None,
18+
system_prompt: str | None = None,
19+
) -> str:
20+
"""Attempt to fetch a prompt completion from various services in order of priority.
21+
22+
Args:
23+
text: The main text prompt.
24+
images: list of images to include in the prompt.
25+
system_prompt: An optional system-level prompt.
26+
27+
Returns:
28+
The result from the first successful driver.
29+
"""
30+
text = text.strip()
31+
for driver in DRIVER_ORDER:
32+
try:
33+
logger.info(f"Trying driver: {driver.__name__}")
34+
return driver.prompt(text, images=images, system_prompt=system_prompt)
35+
except Exception as e:
36+
logger.exception(e)
37+
logger.error(f"Driver {driver.__name__} failed with error: {e}")
38+
import ipdb
39+
40+
ipdb.set_trace()
41+
continue
42+
raise Exception("All drivers failed to provide a response")
43+
44+
45+
if __name__ == "__main__":
46+
# This could be extended to use command-line arguments or other input methods
47+
print(prompt("Describe the solar system."))

openadapt/adapters/ultralytics.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -77,10 +77,26 @@ def do_fastsam(
7777
retina_masks: bool = True,
7878
imgsz: int | tuple[int, int] | None = 1024,
7979
# threshold below which boxes will be filtered out
80-
conf: float = 0.4,
80+
min_confidence_threshold: float = 0.4,
8181
# discards all overlapping boxes with IoU > iou_threshold
82-
iou: float = 0.9,
82+
max_iou_threshold: float = 0.9,
8383
) -> Image:
84+
"""Get segmented image via FastSAM.
85+
86+
For usage of thresholds see:
87+
github.com/ultralytics/ultralytics/blob/dacbd48fcf8407098166c6812eeb751deaac0faf
88+
/ultralytics/utils/ops.py#L164
89+
90+
Args:
91+
TODO
92+
min_confidence_threshold (float, optional): The minimum confidence score
93+
that a detection must meet or exceed to be considered valid. Detections
94+
below this threshold will not be marked. Defaults to 0.4.
95+
max_iou_threshold (float, optional): The maximum allowed Intersection over
96+
Union (IoU) value for overlapping detections. Detections that exceed this
97+
IoU threshold are considered for suppression, keeping only the
98+
detection with the highest confidence. Defaults to 0.9.
99+
"""
84100
model = FastSAM(model_name)
85101

86102
imgsz = imgsz or image.size
@@ -91,8 +107,8 @@ def do_fastsam(
91107
device=device,
92108
retina_masks=retina_masks,
93109
imgsz=imgsz,
94-
conf=conf,
95-
iou=iou,
110+
conf=min_confidence_threshold,
111+
iou=max_iou_threshold,
96112
)
97113

98114
# Prepare a Prompt Process object
Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,12 @@ def get_response(
123123
headers=headers,
124124
json=payload,
125125
)
126-
return response
126+
result = response.json()
127+
if "error" in result:
128+
error = result["error"]
129+
message = error["message"]
130+
raise Exception(message)
131+
return result
127132

128133

129134
def get_completion(payload: dict, dev_mode: bool = False) -> str:
@@ -136,23 +141,19 @@ def get_completion(payload: dict, dev_mode: bool = False) -> str:
136141
Returns:
137142
(str) first message from the response
138143
"""
139-
response = get_response(payload)
140-
response.raise_for_status()
141-
result = response.json()
142-
logger.info(f"result=\n{pformat(result)}")
143-
if "error" in result:
144-
error = result["error"]
145-
message = error["message"]
146-
# TODO: fail after maximum number of attempts
147-
if "retry your request" in message:
144+
try:
145+
result = get_response(payload)
146+
except Exception as exc:
147+
if "retry your request" in str(exc):
148148
return get_completion(payload)
149149
elif dev_mode:
150150
import ipdb
151151

152152
ipdb.set_trace()
153153
# TODO: handle more errors
154154
else:
155-
raise ValueError(result["error"]["message"])
155+
raise exc
156+
logger.info(f"result=\n{pformat(result)}")
156157
choices = result["choices"]
157158
choice = choices[0]
158159
message = choice["message"]

openadapt/models.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -388,11 +388,11 @@ def from_dict(
388388
suffix_len = len(name_suffix)
389389

390390
key_names = utils.split_by_separators(
391-
action_dict["text"][prefix_len:-suffix_len],
391+
action_dict.get("text", "")[prefix_len:-suffix_len],
392392
key_seps,
393393
)
394394
canonical_key_names = utils.split_by_separators(
395-
action_dict["canonical_text"][prefix_len:-suffix_len],
395+
action_dict.get("canonical_text", "")[prefix_len:-suffix_len],
396396
key_seps,
397397
)
398398
logger.info(f"{key_names=}")
@@ -920,6 +920,18 @@ def asdict(self) -> dict:
920920
}
921921

922922

923+
class Replay(db.Base):
    """Class representing a replay in the database.

    Appears to record one replay invocation — the strategy used, its
    arguments, when it ran, and the git revision — presumably for
    reproducibility; confirm against replay.py, which creates these rows.
    """

    __tablename__ = "replay"

    # Surrogate primary key.
    id = sa.Column(sa.Integer, primary_key=True)
    # When the replay ran; ForceFloat is a project column type — presumably
    # coerces timestamps to float on storage, verify in the db module.
    timestamp = sa.Column(ForceFloat)
    # Name of the replay strategy that was executed (e.g. a strategies.* class).
    strategy_name = sa.Column(sa.String)
    # Arguments the strategy was invoked with, serialized as JSON.
    strategy_args = sa.Column(sa.JSON)
    # Git revision of the codebase at replay time.
    git_hash = sa.Column(sa.String)
933+
934+
923935
def copy_sa_instance(sa_instance: db.Base, **kwargs: dict) -> db.Base:
924936
"""Copy a SQLAlchemy instance.
925937

openadapt/playback.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,10 @@ def play_mouse_event(event: ActionEvent, mouse_controller: mouse.Controller) ->
2727
pressed = event.mouse_pressed
2828
logger.debug(f"{name=} {x=} {y=} {dx=} {dy=} {button_name=} {pressed=}")
2929

30-
mouse_controller.position = (x, y)
30+
if all([val is not None for val in (x, y)]):
31+
mouse_controller.position = (x, y)
32+
else:
33+
logger.warning(f"{x=} {y=}")
3134
if name == "move":
3235
pass
3336
elif name == "click":

openadapt/plotting.py

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
import matplotlib.pyplot as plt
1414
import numpy as np
1515

16-
from openadapt import common, models, utils
16+
from openadapt import common, contrib, models, utils
1717
from openadapt.config import PERFORMANCE_PLOTS_DIR_PATH, config
1818
from openadapt.models import ActionEvent
1919

@@ -764,3 +764,57 @@ def plot_segments(
764764
plt.imshow(image)
765765
plt.axis("off")
766766
plt.show()
767+
768+
769+
def get_marked_image(
    original_image: Image.Image,
    masks: list[np.ndarray],
    include_masks: bool = True,
    include_marks: bool = True,
) -> Image.Image:
    """Get a Set-of-Mark image using the original SoM visualizer.

    Args:
        original_image (Image.Image): The original PIL image.
        masks (list[np.ndarray]): A list of masks representing segments in the
            original image.
        include_masks (bool, optional): If True, masks will be included in the
            output visualizations. Defaults to True.
        include_marks (bool, optional): If True, marks will be included in the
            output visualizations. Defaults to True.

    Returns:
        Image.Image: The marked image, with marks and/or masks applied
            according to the include flags. If masks is empty, an unmodified
            copy of the original image is returned.
    """
    # Guard: with no masks the loop below never binds `demo`, and
    # `demo.get_image()` would raise NameError. Nothing to mark — return a
    # copy of the input unchanged.
    if not masks:
        return original_image.copy()

    image_arr = np.asarray(original_image)

    # The rest of this function is copied from
    # github.com/microsoft/SoM/blob/main/task_adapter/sam/tasks/inference_sam_m2m_auto.py

    # metadata = MetadataCatalog.get('coco_2017_train_panoptic')
    metadata = None
    visual = contrib.som.visualizer.Visualizer(image_arr, metadata=metadata)
    mask_map = np.zeros(image_arr.shape, dtype=np.uint8)
    label_mode = "1"
    alpha = 0.1
    anno_mode = []
    if include_masks:
        anno_mode.append("Mask")
    if include_marks:
        anno_mode.append("Mark")
    for i, mask in enumerate(masks):
        # Labels are 1-based so 0 can remain "background" in mask_map.
        label = i + 1
        demo = visual.draw_binary_mask_with_number(
            mask,
            text=str(label),
            label_mode=label_mode,
            alpha=alpha,
            anno_mode=anno_mode,
        )
        mask_map[mask == 1] = label

    # `demo` accumulates all drawn masks; only the final state is rendered.
    im = demo.get_image()
    marked_image = Image.fromarray(im)

    return marked_image
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
Consider the actions in the recording and states of the active window immediately
2+
before each action was taken:
3+
4+
```json
5+
{{ action_windows }}
6+
```
7+
8+
Consider the attached screenshots taken immediately before each action. The order of
9+
the screenshots matches the order of the actions above.
10+
11+
Provide a detailed natural language description of everything that happened
12+
in this recording. This description will be embedded in the context for a future prompt
13+
to replay the recording (subject to proposed modifications in natural language) on a
14+
live system, so make sure to include everything you will need to know.
15+
16+
My career depends on this. Lives are at stake.

0 commit comments

Comments
 (0)