Skip to content

Commit d92e0bf

Browse files
Merge pull request #262 from ServiceNow/deep_debug
Deep debug
2 parents a67b22b + 6a4c808 commit d92e0bf

File tree

5 files changed

+148
-16
lines changed

5 files changed

+148
-16
lines changed

main_workarena_debug.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
"""
Note: This script is a convenience script to launch experiments instead of using
the command line.

Copy this script and modify at will, but don't push your changes to the
repository.
"""

import logging
from copy import deepcopy

import bgym

from agentlab.agents.tool_use_agent.tool_use_agent import (
    DEFAULT_PROMPT_CONFIG,
    GPT_4_1,
    ToolUseAgentArgs,
)
from agentlab.experiments.study import Study

logging.getLogger().setLevel(logging.INFO)

# Start from the default prompt config and tweak only what this debug run needs.
config = deepcopy(DEFAULT_PROMPT_CONFIG)
# config.keep_last_n_obs = 1
config.obs.use_som = True  # enable set-of-marks overlays in the observation


agent_configs = [
    ToolUseAgentArgs(
        model_args=GPT_4_1,
        config=config,
    ),
    # ToolUseAgentArgs(
    #     model_args=GPT_4_1,
    #     config=config,
    # ),
]

for agent_config in agent_configs:
    agent_config.config.action_subsets = ("workarena",)  # use the workarena action set


# ## select the benchmark to run on
# benchmark = "miniwob_tiny_test"
benchmark = "workarena_l1"


benchmark = bgym.DEFAULT_BENCHMARKS[benchmark](n_repeats=4)  # type: bgym.Benchmark
# Restrict the run to "create" tasks only; drop this line to run the full L1 set.
benchmark = benchmark.subset_from_glob("task_name", "*create*")

# for env_args in benchmark.env_args_list:
#     print(env_args.task_name)
#     env_args.max_steps = 15

# Set to True to resume the most recent study instead of starting a new one.
relaunch = False

## Number of parallel jobs
n_jobs = 10  # Make sure to use 1 job when debugging in VSCode
parallel_backend = "ray"
# parallel_backend = "sequential"  # activate sequential backend for debugging in VSCode

if __name__ == "__main__":  # necessary for dask backend

    if relaunch:
        # relaunch an existing study
        study = Study.load_most_recent(contains=None)
        study.find_incomplete(include_errors=True)

    else:
        study = Study(agent_configs, benchmark, logging_level_stdout=logging.WARNING)

    study.run(
        n_jobs=n_jobs,
        parallel_backend=parallel_backend,  # "ray", "joblib" or "sequential"
        strict_reproducibility=False,
        n_relaunch=3,
    )

src/agentlab/agents/tool_use_agent/tool_use_agent.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -127,8 +127,10 @@ class Goal(Block):
127127

128128
goal_as_system_msg: bool = True
129129

130-
def apply(self, llm, discussion: StructuredDiscussion, obs: dict) -> dict:
131-
system_message = llm.msg.system().add_text(SYS_MSG)
130+
def apply(
131+
self, llm, discussion: StructuredDiscussion, obs: dict, sys_msg: str = SYS_MSG
132+
) -> dict:
133+
system_message = llm.msg.system().add_text(sys_msg)
132134
discussion.append(system_message)
133135

134136
if self.goal_as_system_msg:
@@ -441,7 +443,13 @@ def get_action(self, obs: Any) -> float:
441443
self.llm.reset_stats()
442444
if not self.discussion.is_goal_set():
443445
self.discussion.new_group("goal")
444-
self.config.goal.apply(self.llm, self.discussion, obs)
446+
447+
if self.config.multiaction:
448+
sys_msg = SYS_MSG + "\nYou can take multiple actions in a single step, if needed."
449+
else:
450+
sys_msg = SYS_MSG + "\nYou can only take one action at a time."
451+
self.config.goal.apply(self.llm, self.discussion, obs, sys_msg)
452+
445453
self.config.summarizer.apply_init(self.llm, self.discussion)
446454
self.config.general_hints.apply(self.llm, self.discussion)
447455
self.task_hint.apply(self.llm, self.discussion, self.task_name)
@@ -489,7 +497,7 @@ def get_action(self, obs: Any) -> float:
489497
return action, agent_info
490498

491499

492-
OPENAI_MODEL_CONFIG = OpenAIResponseModelArgs(
500+
GPT_4_1 = OpenAIResponseModelArgs(
493501
model_name="gpt-4.1",
494502
max_total_tokens=200_000,
495503
max_input_tokens=200_000,

src/agentlab/analyze/agent_xray.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1164,7 +1164,7 @@ def get_directory_contents(results_dir: Path):
11641164
most_recent_summary = max(summary_files, key=os.path.getctime)
11651165
summary_df = pd.read_csv(most_recent_summary)
11661166

1167-
if len(summary_df) == 0 or summary_df["avg_reward"].isna().all():
1167+
if len(summary_df) == 0:
11681168
continue # skip empty summaries — NOTE(review): comment is stale, the all-NaN avg_reward check was removed in this commit
11691169

11701170
# get row with max avg_reward

src/agentlab/analyze/overlay_utils.py

Lines changed: 41 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
import ast
22
import inspect
3+
import math
34
from dataclasses import dataclass
45
from typing import Any, Union
56

67
import matplotlib.pyplot as plt
8+
import PIL
79
from browsergym.core.action.highlevel import ACTION_SUBSETS
810
from PIL import Image, ImageDraw
911

@@ -289,17 +291,54 @@ def overlay_rectangle(
289291
bbox: tuple[float, float, float, float],
290292
color: Union[str, tuple[int, int, int]] = "red",
291293
width: int = 1,
294+
dashed: bool = True,
292295
) -> Image.Image:
293296
draw = ImageDraw.Draw(img)
294297

295298
x, y, w, h = bbox
296299

297-
# Draw rectangle outline
298-
draw.rectangle([x, y, x + w, y + h], outline=color, width=width)
300+
if dashed:
301+
# Draw dashed rectangle
302+
print("Drawing dashed rectangle")
303+
linedashed(draw, x, y, x + w, y, color, width)
304+
linedashed(draw, x + w, y, x + w, y + h, color, width)
305+
linedashed(draw, x + w, y + h, x, y + h, color, width)
306+
linedashed(draw, x, y + h, x, y, color, width)
307+
else:
308+
draw.rectangle([x, y, x + w, y + h], outline=color, width=width)
299309

300310
return img
301311

302312

313+
# Adapted from https://stackoverflow.com/questions/51908563/dotted-or-dashed-line-with-python-pillow/58885306#58885306
def linedashed(
    draw: "PIL.ImageDraw.Draw", x0, y0, x1, y1, fill, width, dash_length=4, nodash_length=8
):
    """Draw a dashed straight line from (x0, y0) to (x1, y1).

    Args:
        draw: a PIL ``ImageDraw`` object (anything exposing a ``line`` method).
        x0, y0, x1, y1: endpoints of the line, in pixels.
        fill: color forwarded to ``draw.line``.
        width: stroke width in pixels, forwarded to ``draw.line``.
        dash_length: length of each drawn dash segment, in pixels.
        nodash_length: gap between consecutive dashes, in pixels.
    """
    line_dx = x1 - x0  # delta x (can be negative)
    line_dy = y1 - y0  # delta y (can be negative)
    line_length = math.hypot(line_dx, line_dy)  # line length (always >= 0)
    if line_length == 0:
        return  # degenerate line: nothing to draw; avoids division by zero below
    # Unit direction vector along the line.
    pixel_dx = line_dx / line_length
    pixel_dy = line_dy / line_length
    dash_start = 0
    while dash_start < line_length:
        # Clip the final dash so it never overshoots the endpoint.
        dash_end = min(dash_start + dash_length, line_length)
        draw.line(
            (
                round(x0 + pixel_dx * dash_start),
                round(y0 + pixel_dy * dash_start),
                round(x0 + pixel_dx * dash_end),
                round(y0 + pixel_dy * dash_end),
            ),
            fill=fill,
            width=width,
        )
        dash_start += dash_length + nodash_length
303342
def annotate_action(
304343
img: Image.Image, action_string: str, properties: dict[str, tuple], colormap: str = "tab10"
305344
) -> str:

src/agentlab/llm/response_api.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -313,7 +313,6 @@ def __init__(
313313
**kwargs,
314314
):
315315
self.tools = kwargs.pop("tools", None)
316-
self.tool_choice = kwargs.pop("tool_choice", None)
317316
super().__init__(
318317
model_name=model_name,
319318
api_key=api_key,
@@ -324,7 +323,9 @@ def __init__(
324323
)
325324
self.client = OpenAI(api_key=api_key)
326325

327-
def _call_api(self, messages: list[Any | MessageBuilder], **kwargs) -> dict:
326+
def _call_api(
327+
self, messages: list[Any | MessageBuilder], tool_choice: str = "auto", **kwargs
328+
) -> dict:
328329
input = []
329330
for msg in messages:
330331
input.extend(msg.prepare_message() if isinstance(msg, MessageBuilder) else [msg])
@@ -339,8 +340,10 @@ def _call_api(self, messages: list[Any | MessageBuilder], **kwargs) -> dict:
339340

340341
if self.tools is not None:
341342
api_params["tools"] = self.tools
342-
if self.tool_choice is not None:
343-
api_params["tool_choice"] = self.tool_choice
343+
if tool_choice in ("any", "required"):
344+
tool_choice = "required"
345+
346+
api_params["tool_choice"] = tool_choice
344347

345348
# api_params |= kwargs # Merge any additional parameters passed
346349
response = call_openai_api_with_retries(
@@ -388,7 +391,6 @@ def __init__(
388391
):
389392

390393
self.tools = self.format_tools_for_chat_completion(kwargs.pop("tools", None))
391-
self.tool_choice = kwargs.pop("tool_choice", None)
392394

393395
super().__init__(
394396
model_name=model_name,
@@ -403,7 +405,9 @@ def __init__(
403405
**client_args
404406
) # Ensures client_args is a dict or defaults to an empty dict
405407

406-
def _call_api(self, messages: list[dict | MessageBuilder]) -> openai.types.chat.ChatCompletion:
408+
def _call_api(
409+
self, messages: list[dict | MessageBuilder], tool_choice: str = "auto"
410+
) -> openai.types.chat.ChatCompletion:
407411
input = []
408412
for msg in messages:
409413
input.extend(msg.prepare_message() if isinstance(msg, MessageBuilder) else [msg])
@@ -416,8 +420,10 @@ def _call_api(self, messages: list[dict | MessageBuilder]) -> openai.types.chat.
416420
}
417421
if self.tools is not None:
418422
api_params["tools"] = self.tools
419-
if self.tool_choice is not None:
420-
api_params["tool_choice"] = self.tool_choice
423+
424+
if tool_choice in ("any", "required"):
425+
tool_choice = "required"
426+
api_params["tool_choice"] = tool_choice
421427

422428
response = call_openai_api_with_retries(self.client.chat.completions.create, api_params)
423429

@@ -517,7 +523,6 @@ def __init__(
517523
**kwargs,
518524
):
519525
self.tools = kwargs.pop("tools", None)
520-
self.tool_choice = kwargs.pop("tool_choice", None)
521526

522527
super().__init__(
523528
model_name=model_name,
@@ -543,6 +548,9 @@ def _call_api(
543548
temp = self.apply_cache_breakpoints(msg, temp)
544549
input.extend(temp)
545550

551+
if tool_choice in ("any", "required"):
552+
tool_choice = "any" # Claude API expects "any" and gpt expects "required"
553+
546554
api_params: Dict[str, Any] = {
547555
"model": self.model_name,
548556
"messages": input,

0 commit comments

Comments
 (0)