Commit c32400f

Merge pull request #258 from ServiceNow/new_experiments
New experiments
2 parents 8523f83 + a0153be commit c32400f

14 files changed: +443 additions, -40 deletions


.gitignore

Lines changed: 1 addition & 2 deletions

@@ -174,5 +174,4 @@ miniwob-plusplus/
 debugging_results/
 
 # working files
-main_miniwob_debug.py
-main_workarena_debug.py
+experiments/*

src/agentlab/agents/generic_agent/__init__.py

Lines changed: 9 additions & 6 deletions

@@ -9,20 +9,23 @@
 from .agent_configs import (
     AGENT_3_5,
     AGENT_8B,
+    AGENT_37_SONNET,
+    AGENT_CLAUDE_SONNET_35,
+    AGENT_CLAUDE_SONNET_35_VISION,
     AGENT_CUSTOM,
-    AGENT_LLAMA4_17B_INSTRUCT,
     AGENT_LLAMA3_70B,
+    AGENT_LLAMA4_17B_INSTRUCT,
     AGENT_LLAMA31_70B,
+    CHAT_MODEL_ARGS_DICT,
     RANDOM_SEARCH_AGENT,
     AGENT_4o,
     AGENT_4o_MINI,
-    AGENT_CLAUDE_SONNET_35,
-    AGENT_37_SONNET,
-    AGENT_CLAUDE_SONNET_35_VISION,
-    AGENT_4o_VISION,
     AGENT_4o_MINI_VISION,
-    AGENT_o3_MINI,
+    AGENT_4o_VISION,
     AGENT_o1_MINI,
+    AGENT_o3_MINI,
+    FLAGS_GPT_4o,
+    GenericAgentArgs,
 )
 
 __all__ = [

src/agentlab/agents/tool_use_agent/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -1,4 +1,6 @@
 import sys
 
+from agentlab.agents.tool_use_agent.tool_use_agent import *
+
 # for backward compatibility of unpickling
 sys.modules[__name__ + ".multi_tool_agent"] = sys.modules[__name__]
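
The sys.modules aliasing in this shim is the usual trick for keeping previously pickled objects loadable after a module rename. A self-contained illustration of the pattern, with made-up module names (not from this commit):

import pickle
import sys
import types

# Hypothetical stand-ins: "agent_pkg_new" replaced "agent_pkg_old".
new_mod = types.ModuleType("agent_pkg_new")

class AgentState:
    pass

AgentState.__module__ = "agent_pkg_old"   # simulate a class that old pickles refer to
new_mod.AgentState = AgentState
sys.modules["agent_pkg_new"] = new_mod
sys.modules["agent_pkg_old"] = new_mod    # the aliasing trick used in the shim above

old_blob = pickle.dumps(AgentState())     # the blob references agent_pkg_old.AgentState
restored = pickle.loads(old_blob)         # resolves through the alias to the same class
print(type(restored).__name__)            # AgentState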

src/agentlab/agents/tool_use_agent/tool_use_agent.py

Lines changed: 11 additions & 2 deletions

@@ -147,7 +147,7 @@ def apply(self, llm, discussion: StructuredDiscussion, obs: dict) -> dict:
 
 AXTREE_NOTE = """
 AXTree extracts most of the interactive elements of the DOM in a tree structure. It may also contain information that is not visible in the screenshot.
-A line starting with [bid] is a node in the AXTree. It is a unique alpha-numeric identifier to be used when calling tools.
+A line starting with [bid] is a node in the AXTree. It is a unique alpha-numeric identifier to be used when calling tools, e.g., click(bid="a253"). Make sure to include letters and numbers in the bid.
 """
 
 

@@ -347,7 +347,7 @@ class PromptConfig:
     task_hint: TaskHint = None
     keep_last_n_obs: int = 1
     multiaction: bool = False
-    action_subsets: tuple[str] = field(default_factory=lambda: ("coord",))
+    action_subsets: tuple[str] = None
 
 
 @dataclass

@@ -498,6 +498,15 @@ def get_action(self, obs: Any) -> float:
     vision_support=True,
 )
 
+GPT_4_1_MINI = OpenAIResponseModelArgs(
+    model_name="gpt-4.1-mini",
+    max_total_tokens=200_000,
+    max_input_tokens=200_000,
+    max_new_tokens=2_000,
+    temperature=0.1,
+    vision_support=True,
+)
+
 OPENAI_CHATAPI_MODEL_CONFIG = OpenAIChatModelArgs(
     model_name="gpt-4o-2024-08-06",
     max_total_tokens=200_000,
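
The new GPT_4_1_MINI constant is a plain configuration object. Assuming OpenAIResponseModelArgs is a dataclass like the AnthropicModelArgs added elsewhere in this PR (the commit does not show its definition), a cheaper variant for quick local runs could be derived roughly like this; the field values are illustrative:

from dataclasses import replace

from agentlab.agents.tool_use_agent.tool_use_agent import GPT_4_1_MINI

# Derive a low-budget copy of the new config without mutating the shared constant.
# replace() assumes the config is a dataclass, which is not shown in this diff.
GPT_4_1_MINI_SHORT = replace(GPT_4_1_MINI, max_new_tokens=500, temperature=0.0)
print(GPT_4_1_MINI_SHORT)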

src/agentlab/analyze/agent_xray.py

Lines changed: 7 additions & 1 deletion

@@ -601,7 +601,13 @@ def get_screenshot(
     if annotate:
         action_str = step_info.action
         properties = step_info.obs.get("extra_element_properties", None)
-        action_colored = annotate_action(img, action_string=action_str, properties=properties)
+        try:
+            action_colored = annotate_action(
+                img, action_string=action_str, properties=properties
+            )
+        except Exception as e:
+            warning(f"Failed to annotate action: {e}")
+            action_colored = action_str
     else:
         action_colored = None
     return img, action_colored

Lines changed: 122 additions & 0 deletions

@@ -0,0 +1,122 @@
+import os
+from dataclasses import dataclass
+from pathlib import Path
+
+import pandas as pd
+from tqdm import tqdm
+
+from agentlab.analyze import inspect_results
+from agentlab.experiments.exp_utils import RESULTS_DIR
+from agentlab.experiments.study import Study
+
+
+@dataclass
+class StudyInfo:
+    study_dir: Path
+    study: Study
+    summary_df: pd.DataFrame
+    should_delete: bool = False
+    reason: str = ""
+
+
+def search_for_reasons_to_archive(result_dir: Path, min_study_size: int = 0) -> list[StudyInfo]:
+
+    study_info_list = []
+    study_dirs = list(result_dir.iterdir())
+    progress = tqdm(study_dirs, desc="Processing studies")
+    for study_dir in progress:
+
+        progress.set_postfix({"study_dir": study_dir})
+        if not study_dir.is_dir():
+            progress.set_postfix({"status": "skipped"})
+            continue
+
+        try:
+            study = Study.load(study_dir)
+        except Exception:
+            study = None
+        # get summary*.csv files and find the most recent
+        summary_files = list(study_dir.glob("summary*.csv"))
+
+        if len(summary_files) != 0:
+            most_recent_summary = max(summary_files, key=os.path.getctime)
+            summary_df = pd.read_csv(most_recent_summary)
+
+        else:
+            try:
+                result_df = inspect_results.load_result_df(study_dir, progress_fn=None)
+                summary_df = inspect_results.summarize_study(result_df)
+            except Exception as e:
+                print(f" Error processing {study_dir}: {e}")
+                continue
+
+        study_info = StudyInfo(
+            study_dir=study_dir,
+            study=study,
+            summary_df=summary_df,
+        )
+
+        if len(study_info.summary_df) == 0:
+            study_info.should_delete = True
+            study_info.reason = "Empty summary DataFrame"
+
+        n_completed, n_total, n_err = 0, 0, 0
+
+        for _, row in study_info.summary_df.iterrows():
+            n_comp, n_tot = row["n_completed"].split("/")
+            n_completed += int(n_comp)
+            n_total += int(n_tot)
+            n_err += int(row.get("n_err"))
+
+        n_finished = n_completed - n_err
+
+        # print(summary_df)
+        # print(f" {n_completed} / {n_total}, {n_err} errors")
+
+        if "miniwob-tiny-test" in study_dir.name:
+            study_info.should_delete = True
+            study_info.reason += "Miniwob tiny test\n"
+        if n_total == 0:
+            study_info.should_delete = True
+            study_info.reason += "No tasks\n"
+        if n_completed == 0:
+            study_info.should_delete = True
+            study_info.reason += "No tasks completed\n"
+        if float(n_finished) / float(n_total) < 0.5:
+            study_info.should_delete = True
+            study_info.reason += f"Less than 50% tasks finished, n_err: {n_err}, n_total: {n_total}, n_finished: {n_finished}, n_completed: {n_completed}\n"
+
+        if n_total <= min_study_size:
+            study_info.should_delete = True
+            study_info.reason += (
+                f"Too few tasks. n_total ({n_total}) <= min_study_size ({min_study_size})\n"
+            )
+
+        study_info_list.append(study_info)
+    return study_info_list
+
+
+if __name__ == "__main__":
+    study_list_info = search_for_reasons_to_archive(RESULTS_DIR, min_study_size=5)
+    archive_dir = RESULTS_DIR.parent / "archived_agentlab_results"  # type: Path
+    archive_dir.mkdir(parents=True, exist_ok=True)
+
+    # Comment out the line below to actually move studies to the archive dir
+    archive_dir = None
+
+    for study_info in study_list_info:
+        if not study_info.should_delete:
+            continue
+
+        print(f"Study: {study_info.study_dir.name}")
+        print(f" Reason: {study_info.reason}")
+        print(study_info.summary_df)
+        print()
+
+        if archive_dir is not None:
+            # move to new dir
+            new_path = archive_dir / study_info.study_dir.name
+            study_info.study_dir.rename(new_path)
+            # save reason in a file
+            reason_file = new_path / "reason_to_archive.txt"
+            reason_file.write_text(study_info.reason)
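
A rough way to dry-run the new archiving helper against a results folder of your own; the capture does not show which file this script lives in, so the import path below is hypothetical:

from pathlib import Path

# Hypothetical module path: the commit capture does not name the new file.
from agentlab.analyze.archive_studies import search_for_reasons_to_archive

infos = search_for_reasons_to_archive(Path("~/agentlab_results").expanduser(), min_study_size=5)
for info in infos:
    if info.should_delete:
        print(f"{info.study_dir.name}: {info.reason.strip()}")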

src/agentlab/experiments/graph_execution_ray.py

Lines changed: 16 additions & 10 deletions

@@ -1,7 +1,3 @@
-# import os
-
-# # Disable Ray log deduplication
-# os.environ["RAY_DEDUP_LOGS"] = "0"
 import logging
 import time
 

@@ -90,12 +86,22 @@ def poll_for_timeout(tasks: dict[str, ray.ObjectRef], timeout: float, poll_inter
 
 
 def get_elapsed_time(task_ref: ray.ObjectRef):
-    task_id = task_ref.task_id().hex()
-    task_info = state.get_task(task_id, address="auto")
-    if task_info and task_info.start_time_ms is not None:
-        start_time_s = task_info.start_time_ms / 1000.0  # Convert ms to s
+    try:
+        task_id = task_ref.task_id().hex()
+        task_info = state.get_task(task_id, address="auto")
+        if not task_info:
+            return None
+        if not isinstance(task_info, list):
+            task_info = [task_info]
+
+        start_times_ms = [getattr(t, "start_time_ms", None) for t in task_info]
+        start_time_s = max([t / 1000.0 if t is not None else -1 for t in start_times_ms])
+        if start_time_s < 0:
+            return None  # Task has not started yet
+
         current_time_s = time.time()
         elapsed_time = current_time_s - start_time_s
         return elapsed_time
-    else:
-        return None  # Task has not started yet
+    except Exception as e:
+        logger.warning(f"Could not get elapsed time for task {task_id}: {e}")
+        return None
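
The key change above is normalizing the result of state.get_task, which can be a single record or a list of task attempts, and keeping the latest start time. The Ray-free sketch below illustrates that normalization with a made-up TaskInfo stand-in (not the Ray API):

from dataclasses import dataclass
from typing import Optional, Union

@dataclass
class FakeTaskInfo:  # stand-in for a Ray task state record
    start_time_ms: Optional[int] = None

def latest_start_s(info: Union[FakeTaskInfo, list[FakeTaskInfo], None]) -> Optional[float]:
    if not info:
        return None
    if not isinstance(info, list):
        info = [info]  # wrap a single record so both shapes are handled the same way
    starts = [getattr(t, "start_time_ms", None) for t in info]
    best = max(t / 1000.0 if t is not None else -1 for t in starts)
    return None if best < 0 else best

print(latest_start_s(FakeTaskInfo(1_700_000_000_000)))                              # single record
print(latest_start_s([FakeTaskInfo(None), FakeTaskInfo(1_700_000_500_000)]))        # retried task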

src/agentlab/experiments/loop.py

Lines changed: 6 additions & 2 deletions

@@ -25,7 +25,11 @@
 from PIL import Image
 from tqdm import tqdm
 
-from agentlab.agents.tapeagent import TapeAgent, save_tape
+try:
+    from agentlab.agents.tapeagent import TapeAgent, save_tape
+except ImportError:
+    TapeAgent = None
+
 
 logger = logging.getLogger(__name__)
 

@@ -474,7 +478,7 @@ def run(self):
             err_msg = f"Exception uncaught by agent or environment in task {self.env_args.task_name}.\n{type(e).__name__}:\n{e}"
             logger.info("Saving experiment info.")
             self.save_summary_info(episode_info, Path(self.exp_dir), err_msg, stack_trace)
-            if isinstance(agent, TapeAgent):
+            if TapeAgent is not None and isinstance(agent, TapeAgent):
                 task = getattr(env, "task", {})
                 save_tape(self.exp_dir, episode_info, task, agent.final_tape)
         except Exception as e:

src/agentlab/llm/chat_api.py

Lines changed: 75 additions & 0 deletions

@@ -6,6 +6,7 @@
 from functools import partial
 from typing import Optional
 
+import anthropic
 import openai
 from huggingface_hub import InferenceClient
 from openai import AzureOpenAI, OpenAI

@@ -471,3 +472,77 @@ def __init__(
             client_args={"base_url": "http://0.0.0.0:8000/v1"},
             pricing_func=None,
         )
+
+
+class AnthropicChatModel(AbstractChatModel):
+    def __init__(
+        self,
+        model_name,
+        api_key=None,
+        temperature=0.5,
+        max_tokens=100,
+        max_retry=4,
+        log_probs=False,
+    ):
+        self.model_name = model_name
+        self.temperature = temperature
+        self.max_tokens = max_tokens
+        self.max_retry = max_retry
+        self.log_probs = log_probs
+
+        api_key = api_key or os.getenv("ANTHROPIC_API_KEY")
+        self.client = anthropic.Anthropic(api_key=api_key)
+
+    def __call__(self, messages: list[dict], n_samples: int = 1, temperature: float = None) -> dict:
+        # Convert OpenAI format to Anthropic format
+        system_message = None
+        anthropic_messages = []
+
+        for msg in messages:
+            if msg["role"] == "system":
+                system_message = msg["content"]
+            else:
+                anthropic_messages.append({"role": msg["role"], "content": msg["content"]})
+
+        temperature = temperature if temperature is not None else self.temperature
+
+        for attempt in range(self.max_retry):
+            try:
+                kwargs = {
+                    "model": self.model_name,
+                    "messages": anthropic_messages,
+                    "max_tokens": self.max_tokens,
+                    "temperature": temperature,
+                }
+
+                if system_message:
+                    kwargs["system"] = system_message
+
+                response = self.client.messages.create(**kwargs)
+
+                # Track usage if available
+                if hasattr(tracking.TRACKER, "instance"):
+                    tracking.TRACKER.instance(
+                        response.usage.input_tokens,
+                        response.usage.output_tokens,
+                        0,  # cost calculation would need pricing info
+                    )
+
+                return AIMessage(response.content[0].text)
+
+            except Exception as e:
+                if attempt == self.max_retry - 1:
+                    raise e
+                logging.warning(f"Anthropic API error (attempt {attempt + 1}): {e}")
+                time.sleep(60)  # Simple retry delay
+
+
+@dataclass
+class AnthropicModelArgs(BaseModelArgs):
+    def make_model(self):
+        return AnthropicChatModel(
+            model_name=self.model_name,
+            temperature=self.temperature,
+            max_tokens=self.max_new_tokens,
+            log_probs=self.log_probs,
+        )
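
A minimal usage sketch for the new AnthropicChatModel, assuming ANTHROPIC_API_KEY is set in the environment; the model name below is illustrative and not taken from this commit:

import os

from agentlab.llm.chat_api import AnthropicChatModel

assert os.getenv("ANTHROPIC_API_KEY"), "the class falls back to this env var for credentials"

# Illustrative model name; constructor arguments match the new class above.
llm = AnthropicChatModel(model_name="claude-3-5-sonnet-latest", temperature=0.1, max_tokens=256)

messages = [
    {"role": "system", "content": "You are a terse assistant."},
    {"role": "user", "content": "Reply with the single word: pong."},
]
print(llm(messages))  # AIMessage built from the first content block of the Anthropic response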

src/agentlab/llm/llm_configs.py

Lines changed: 14 additions & 0 deletions

@@ -17,6 +17,20 @@
 ]
 
 CHAT_MODEL_ARGS_DICT = {
+    "openai/gpt-4.1-mini-2025-04-14": OpenAIModelArgs(
+        model_name="gpt-4.1-mini-2025-04-14",
+        max_total_tokens=128_000,
+        max_input_tokens=128_000,
+        max_new_tokens=16_384,
+        vision_support=True,
+    ),
+    "openai/gpt-4.1-2025-04-14": OpenAIModelArgs(
+        model_name="gpt-4.1-2025-04-14",
+        max_total_tokens=128_000,
+        max_input_tokens=128_000,
+        max_new_tokens=16_384,
+        vision_support=True,
+    ),
     "openai/o3-mini-2025-01-31": OpenAIModelArgs(
         model_name="o3-mini-2025-01-31",
         max_total_tokens=200_000,
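
For completeness, a minimal sketch of pulling one of the new entries out of CHAT_MODEL_ARGS_DICT; it assumes OpenAI credentials are configured and that OpenAIModelArgs exposes the usual make_model() factory seen on the other *ModelArgs classes (not shown in this diff):

from agentlab.llm.llm_configs import CHAT_MODEL_ARGS_DICT

args = CHAT_MODEL_ARGS_DICT["openai/gpt-4.1-mini-2025-04-14"]
# make_model() is assumed from the BaseModelArgs convention visible in chat_api.py.
llm = args.make_model()
print(llm([{"role": "user", "content": "ping"}]))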
