
Commit 463e0cf

Merge branch 'main' into add-claude-3.7
2 parents 24f48f3 + faccb16 commit 463e0cf

6 files changed, +120 -88 lines changed

src/agentlab/agents/generic_agent/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -17,7 +17,9 @@
     AGENT_4o_MINI,
     AGENT_CLAUDE_SONNET_35,
     AGENT_37_SONNET,
+    AGENT_CLAUDE_SONNET_35_VISION,
     AGENT_4o_VISION,
+    AGENT_4o_MINI_VISION,
     AGENT_o3_MINI,
     AGENT_o1_MINI,
 )
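
For reference, the two newly re-exported vision configurations can be imported next to the existing agents. A minimal sketch, assuming the usual package layout (not part of this diff):

    from agentlab.agents.generic_agent import (
        AGENT_4o_MINI_VISION,
        AGENT_CLAUDE_SONNET_35_VISION,
    )

    # Both are AgentArgs configurations, usable anywhere the non-vision variants are.
    agents = [AGENT_4o_MINI_VISION, AGENT_CLAUDE_SONNET_35_VISION]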

src/agentlab/experiments/study.py

Lines changed: 82 additions & 80 deletions
@@ -1,12 +1,14 @@
-from concurrent.futures import ProcessPoolExecutor
 import gzip
 import logging
 import os
 import pickle
+import random
 import uuid
 from abc import ABC, abstractmethod
+from concurrent.futures import ProcessPoolExecutor
 from dataclasses import dataclass
 from datetime import datetime
+from multiprocessing import Manager, Pool, Queue
 from pathlib import Path
 
 import bgym
@@ -19,8 +21,6 @@
 from agentlab.experiments.exp_utils import RESULTS_DIR, add_dependencies
 from agentlab.experiments.launch_exp import find_incomplete, non_dummy_count, run_experiments
 from agentlab.experiments.multi_server import BaseServer, WebArenaInstanceVars
-from multiprocessing import Pool, Manager, Queue
-import random
 
 logger = logging.getLogger(__name__)
 
@@ -238,7 +238,7 @@ def __post_init__(self):
 
     def make_exp_args_list(self):
         """Generate the exp_args_list from the agent_args and the benchmark."""
-        self.exp_args_list = _agents_on_benchmark(
+        self.exp_args_list = self.agents_on_benchmark(
             self.agent_args,
             self.benchmark,
             logging_level=self.logging_level,
@@ -424,6 +424,84 @@ def load(dir: Path) -> "Study":
     def load_most_recent(root_dir: Path = None, contains=None) -> "Study":
         return Study.load(get_most_recent_study(root_dir, contains=contains))
 
+    def agents_on_benchmark(
+        self,
+        agents: list[AgentArgs] | AgentArgs,
+        benchmark: bgym.Benchmark,
+        demo_mode=False,
+        logging_level: int = logging.INFO,
+        logging_level_stdout: int = logging.INFO,
+        ignore_dependencies=False,
+    ):
+        """Run one or multiple agents on a benchmark.
+
+        Args:
+            agents: list[AgentArgs] | AgentArgs
+                The agent configuration(s) to run.
+            benchmark: bgym.Benchmark
+                The benchmark to run the agents on.
+            demo_mode: bool
+                If True, the experiments will be run in demo mode.
+            logging_level: int
+                The logging level for individual jobs.
+            logging_level_stdout: int
+                The logging level for the stdout.
+            ignore_dependencies: bool
+                If True, the dependencies will be ignored and all experiments can be run in parallel.
+
+        Returns:
+            list[ExpArgs]: The list of experiments to run.
+
+        Raises:
+            ValueError: If multiple agents are run on a benchmark that requires manual reset.
+        """
+
+        if not isinstance(agents, (list, tuple)):
+            agents = [agents]
+
+        if benchmark.name.startswith("visualwebarena") or benchmark.name.startswith("webarena"):
+            if len(agents) > 1:
+                raise ValueError(
+                    f"Only one agent can be run on {benchmark.name} since the instance requires manual reset after each evaluation."
+                )
+
+        for agent in agents:
+            agent.set_benchmark(
+                benchmark, demo_mode
+            )  # the agent can adapt (lightly?) to the benchmark
+
+        env_args_list = benchmark.env_args_list
+        if demo_mode:
+            set_demo_mode(env_args_list)
+
+        exp_args_list = []
+
+        for agent in agents:
+            for env_args in env_args_list:
+                exp_args = ExpArgs(
+                    agent_args=agent,
+                    env_args=env_args,
+                    logging_level=logging_level,
+                    logging_level_stdout=logging_level_stdout,
+                )
+                exp_args_list.append(exp_args)
+
+        for i, exp_args in enumerate(exp_args_list):
+            exp_args.order = i
+
+        # not required with ray, but keeping around if we would need it for visualwebareana on joblib
+        # _flag_sequential_exp(exp_args_list, benchmark)
+
+        if not ignore_dependencies:
+            # populate the depends_on field based on the task dependencies in the benchmark
+            exp_args_list = add_dependencies(exp_args_list, benchmark.dependency_graph_over_tasks())
+        else:
+            logger.warning(
+                f"Ignoring dependencies for benchmark {benchmark.name}. This could lead to different results."
+            )
+
+        return exp_args_list
+
 
 def _make_study_name(agent_names, benchmark_names, suffix=None):
     """Make a study name from the agent and benchmark names."""
@@ -634,82 +712,6 @@ def set_demo_mode(env_args_list: list[EnvArgs]):
         env_args.slow_mo = 1000
 
 
-def _agents_on_benchmark(
-    agents: list[AgentArgs] | AgentArgs,
-    benchmark: bgym.Benchmark,
-    demo_mode=False,
-    logging_level: int = logging.INFO,
-    logging_level_stdout: int = logging.INFO,
-    ignore_dependencies=False,
-):
-    """Run one or multiple agents on a benchmark.
-
-    Args:
-        agents: list[AgentArgs] | AgentArgs
-            The agent configuration(s) to run.
-        benchmark: bgym.Benchmark
-            The benchmark to run the agents on.
-        demo_mode: bool
-            If True, the experiments will be run in demo mode.
-        logging_level: int
-            The logging level for individual jobs.
-        logging_level_stdout: int
-            The logging level for the stdout.
-        ignore_dependencies: bool
-            If True, the dependencies will be ignored and all experiments can be run in parallel.
-
-    Returns:
-        list[ExpArgs]: The list of experiments to run.
-
-    Raises:
-        ValueError: If multiple agents are run on a benchmark that requires manual reset.
-    """
-
-    if not isinstance(agents, (list, tuple)):
-        agents = [agents]
-
-    if benchmark.name.startswith("visualwebarena") or benchmark.name.startswith("webarena"):
-        if len(agents) > 1:
-            raise ValueError(
-                f"Only one agent can be run on {benchmark.name} since the instance requires manual reset after each evaluation."
-            )
-
-    for agent in agents:
-        agent.set_benchmark(benchmark, demo_mode)  # the agent can adapt (lightly?) to the benchmark
-
-    env_args_list = benchmark.env_args_list
-    if demo_mode:
-        set_demo_mode(env_args_list)
-
-    exp_args_list = []
-
-    for agent in agents:
-        for env_args in env_args_list:
-            exp_args = ExpArgs(
-                agent_args=agent,
-                env_args=env_args,
-                logging_level=logging_level,
-                logging_level_stdout=logging_level_stdout,
-            )
-            exp_args_list.append(exp_args)
-
-    for i, exp_args in enumerate(exp_args_list):
-        exp_args.order = i
-
-    # not required with ray, but keeping around if we would need it for visualwebareana on joblib
-    # _flag_sequential_exp(exp_args_list, benchmark)
-
-    if not ignore_dependencies:
-        # populate the depends_on field based on the task dependencies in the benchmark
-        exp_args_list = add_dependencies(exp_args_list, benchmark.dependency_graph_over_tasks())
-    else:
-        logger.warning(
-            f"Ignoring dependencies for benchmark {benchmark.name}. This could lead to different results."
-        )
-
-    return exp_args_list
-
-
 # def _flag_sequential_exp(exp_args_list: list[ExpArgs], benchmark: Benchmark):
 #     if benchmark.name.startswith("visualwebarena"):
 #         sequential_subset = benchmark.subset_from_glob("requires_reset", "True")
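
Net effect of this file: the module-level _agents_on_benchmark helper is deleted and re-homed as Study.agents_on_benchmark, with make_exp_args_list now calling the bound method. A rough usage sketch follows; the benchmark name and keyword style are illustrative assumptions, not taken from this commit:

    from agentlab.agents.generic_agent import AGENT_4o_MINI
    from agentlab.experiments.study import Study

    # Assumption: Study resolves a benchmark name to a bgym.Benchmark during __post_init__.
    study = Study(agent_args=[AGENT_4o_MINI], benchmark="miniwob")

    # make_exp_args_list() now delegates to the bound method instead of the removed helper.
    study.make_exp_args_list()
    print(len(study.exp_args_list), "experiments, ordered and with task dependencies attached")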

src/agentlab/llm/base_api.py

Lines changed: 1 addition & 0 deletions
@@ -21,6 +21,7 @@ class BaseModelArgs(ABC):
     max_new_tokens: int = None
     temperature: float = 0.1
     vision_support: bool = False
+    log_probs: bool = False
 
     @abstractmethod
     def make_model(self) -> AbstractChatModel:
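
Because log_probs defaults to False on BaseModelArgs, existing configurations keep their behavior; concrete args classes only need to forward the flag in make_model, as the chat_api.py hunks below do. A hedged sketch, assuming an OpenAI-style args dataclass (OpenAIModelArgs) defined in chat_api.py:

    from agentlab.llm.chat_api import OpenAIModelArgs  # assumed class name

    # Illustrative values; only the log_probs flag is new in this commit.
    args = OpenAIModelArgs(model_name="gpt-4o-mini", max_new_tokens=256, log_probs=True)
    chat_model = args.make_model()  # forwards log_probs into the underlying ChatModel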

src/agentlab/llm/chat_api.py

Lines changed: 20 additions & 3 deletions
@@ -87,6 +87,7 @@ def make_model(self):
             model_name=self.model_name,
             temperature=self.temperature,
             max_tokens=self.max_new_tokens,
+            log_probs=self.log_probs,
         )
 
 
@@ -100,6 +101,7 @@ def make_model(self):
             model_name=self.model_name,
             temperature=self.temperature,
             max_tokens=self.max_new_tokens,
+            log_probs=self.log_probs,
         )
 
 
@@ -115,6 +117,7 @@ def make_model(self):
             temperature=self.temperature,
             max_tokens=self.max_new_tokens,
             deployment_name=self.deployment_name,
+            log_probs=self.log_probs,
         )
 
 
@@ -142,6 +145,7 @@ def make_model(self):
                 temperature=self.temperature,
                 max_new_tokens=self.max_new_tokens,
                 n_retry_server=self.n_retry_server,
+                log_probs=self.log_probs,
             )
         elif self.backend == "vllm":
             return VLLMChatModel(
@@ -232,6 +236,7 @@ def __init__(
         client_class=OpenAI,
         client_args=None,
         pricing_func=None,
+        log_probs=False,
     ):
         assert max_retry > 0, "max_retry should be greater than 0"
 
@@ -240,6 +245,7 @@ def __init__(
         self.max_tokens = max_tokens
         self.max_retry = max_retry
        self.min_retry_wait_time = min_retry_wait_time
+        self.log_probs = log_probs
 
         # Get the API key from the environment variable if not provided
         if api_key_env_var:
@@ -286,6 +292,7 @@ def __call__(self, messages: list[dict], n_samples: int = 1, temperature: float
             n=n_samples,
             temperature=temperature,
             max_tokens=self.max_tokens,
+            logprobs=self.log_probs,
         )
 
         if completion.usage is None:
@@ -315,7 +322,10 @@ def __call__(self, messages: list[dict], n_samples: int = 1, temperature: float
         tracking.TRACKER.instance(input_tokens, output_tokens, cost)
 
         if n_samples == 1:
-            return AIMessage(completion.choices[0].message.content)
+            res = AIMessage(completion.choices[0].message.content)
+            if self.log_probs:
+                res["log_probs"] = completion.choices[0].log_probs
+            return res
         else:
             return [AIMessage(c.message.content) for c in completion.choices]
 
@@ -335,6 +345,7 @@ def __init__(
         max_tokens=100,
         max_retry=4,
         min_retry_wait_time=60,
+        log_probs=False,
     ):
         super().__init__(
             model_name=model_name,
@@ -346,6 +357,7 @@ def __init__(
             api_key_env_var="OPENAI_API_KEY",
             client_class=OpenAI,
             pricing_func=tracking.get_pricing_openai,
+            log_probs=log_probs,
         )
 
 
@@ -358,6 +370,7 @@ def __init__(
         max_tokens=100,
         max_retry=4,
         min_retry_wait_time=60,
+        log_probs=False,
     ):
         client_args = {
             "base_url": "https://openrouter.ai/api/v1",
@@ -373,6 +386,7 @@ def __init__(
             client_class=OpenAI,
             client_args=client_args,
             pricing_func=tracking.get_pricing_openrouter,
+            log_probs=log_probs,
        )
 
 
@@ -386,6 +400,7 @@ def __init__(
         max_tokens=100,
         max_retry=4,
         min_retry_wait_time=60,
+        log_probs=False,
     ):
         api_key = api_key or os.getenv("AZURE_OPENAI_API_KEY")
         endpoint = os.getenv("AZURE_OPENAI_ENDPOINT")
@@ -406,6 +421,7 @@ def __init__(
             client_class=AzureOpenAI,
             client_args=client_args,
             pricing_func=tracking.get_pricing_openai,
+            log_probs=log_probs,
         )
 
 
@@ -419,8 +435,9 @@ def __init__(
         temperature: Optional[int] = 1e-1,
         max_new_tokens: Optional[int] = 512,
         n_retry_server: Optional[int] = 4,
+        log_probs: Optional[bool] = False,
     ):
-        super().__init__(model_name, base_model_name, n_retry_server)
+        super().__init__(model_name, base_model_name, n_retry_server, log_probs)
         if temperature < 1e-3:
             logging.warning("Models might behave weirdly when temperature is too low.")
         self.temperature = temperature
@@ -429,7 +446,7 @@ def __init__(
             token = os.environ["TGI_TOKEN"]
 
         client = InferenceClient(model=model_url, token=token)
-        self.llm = partial(client.text_generation, max_new_tokens=max_new_tokens)
+        self.llm = partial(client.text_generation, max_new_tokens=max_new_tokens, details=log_probs)
 
 
 class VLLMChatModel(ChatModel):
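
With the flag threaded through, a single-sample call returns an AIMessage that also carries the log probabilities; the n_samples > 1 branch is unchanged and does not attach them. A hedged sketch, assuming the OpenAI-backed ChatModel subclass shown above is named OpenAIChatModel and that OPENAI_API_KEY is set in the environment:

    from agentlab.llm.chat_api import OpenAIChatModel  # assumed class name

    llm = OpenAIChatModel(model_name="gpt-4o-mini", max_tokens=64, log_probs=True)
    answer = llm([{"role": "user", "content": "Say hello."}])

    print(answer["content"])        # generated text, as before
    if llm.log_probs:
        print(answer["log_probs"])  # populated from completion.choices[0].log_probs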

src/agentlab/llm/huggingface_utils.py

Lines changed: 7 additions & 2 deletions
@@ -40,9 +40,10 @@ class HFBaseChatModel(AbstractChatModel):
         description="The number of times to retry the server if it fails to respond",
     )
 
-    def __init__(self, model_name, base_model_name, n_retry_server):
+    def __init__(self, model_name, base_model_name, n_retry_server, log_probs):
         super().__init__()
         self.n_retry_server = n_retry_server
+        self.log_probs = log_probs
 
         if base_model_name is None:
             self.tokenizer = AutoTokenizer.from_pretrained(model_name)
@@ -100,7 +101,11 @@ def __call__(
         while True:
             try:
                 temperature = temperature if temperature is not None else self.temperature
-                response = AIMessage(self.llm(prompt, temperature=temperature))
+                answer = self.llm(prompt, temperature=temperature)
+                response = AIMessage(answer)
+                if self.log_probs:
+                    response["content"] = answer.generated_text
+                    response["log_probs"] = answer.details
                 responses.append(response)
                 break
             except Exception as e:
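
The details=log_probs switch on the TGI client (previous file) is what makes answer.generated_text and answer.details available in __call__ above. A minimal sketch of the underlying huggingface_hub call, with a hypothetical endpoint URL:

    from huggingface_hub import InferenceClient

    client = InferenceClient(model="http://localhost:8080")  # hypothetical TGI endpoint
    out = client.text_generation("Say hello.", max_new_tokens=16, details=True)

    print(out.generated_text)             # what ends up in response["content"]
    print(out.details.tokens[0].logprob)  # per-token log-probabilities behind response["log_probs"]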
