Skip to content

Commit 5118b39

Browse files
feat: add_centroid in WindowEvent.to_prompt_dict (#840)
* add centroid in window_dict output * black --------- Co-authored-by: Richard Abrich <[email protected]> Co-authored-by: Richard Abrich <[email protected]>
1 parent b288c07 commit 5118b39

File tree

1 file changed

+35
-8
lines changed

1 file changed

+35
-8
lines changed

openadapt/models.py

Lines changed: 35 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -533,7 +533,12 @@ def scrub(self, scrubber: ScrubbingProvider | TextScrubbingMixin) -> None:
533533
if self.state is not None:
534534
self.state = scrubber.scrub_dict(self.state)
535535

536-
def to_prompt_dict(self, include_data: bool = True) -> dict[str, Any]:
536+
def to_prompt_dict(
537+
self,
538+
include_data: bool = True,
539+
add_centroid: bool = True,
540+
remove_bbox: bool = False,
541+
) -> dict[str, Any]:
537542
"""Convert into a dict, excluding properties not necessary for prompting.
538543
539544
Args:
@@ -553,6 +558,29 @@ def to_prompt_dict(self, include_data: bool = True) -> dict[str, Any]:
553558
# and not isinstance(getattr(models.WindowEvent, key), property)
554559
}
555560
)
561+
562+
if add_centroid:
563+
left = window_dict["left"]
564+
top = window_dict["top"]
565+
width = window_dict["width"]
566+
height = window_dict["height"]
567+
568+
# Compute the centroid of the bounding box
569+
centroid_x = left + width / 2
570+
centroid_y = top + height / 2
571+
572+
# Add centroid in the prompt dict { "centroid": }
573+
window_dict["centroid"] = {
574+
"x": centroid_x,
575+
"y": centroid_y,
576+
}
577+
578+
if remove_bbox:
579+
window_dict.pop("left")
580+
window_dict.pop("top")
581+
window_dict.pop("width")
582+
window_dict.pop("height")
583+
556584
if "state" in window_dict:
557585
if include_data:
558586
key_suffixes = [
@@ -574,14 +602,13 @@ def to_prompt_dict(self, include_data: bool = True) -> dict[str, Any]:
574602
# from pprint import pformat
575603
# logger.info(f"window_dict=\n{pformat(window_dict)}")
576604
# import ipdb; ipdb.set_trace()
577-
if "state" in window_dict:
578-
window_state = window_dict["state"]
579-
window_state["data"] = utils.clean_dict(
580-
utils.filter_keys(
581-
window_state["data"],
582-
key_suffixes,
583-
)
605+
window_state = window_dict["state"]
606+
window_state["data"] = utils.clean_dict(
607+
utils.filter_keys(
608+
window_state["data"],
609+
key_suffixes,
584610
)
611+
)
585612
else:
586613
window_dict["state"].pop("data")
587614
window_dict["state"].pop("meta")

0 commit comments

Comments
 (0)