Skip to content

Commit 24ec7a7

Browse files
committed
Merge branch 'tlsdc/log_prob' of github.com:ServiceNow/AgentLab into tlsdc/log_prob
2 parents abb44c7 + f666adb commit 24ec7a7

File tree

7 files changed

+60
-14
lines changed

7 files changed

+60
-14
lines changed

add_study_to_repro_journal.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
"""Append already-completed studies to the reproducibility journal.

Each entry of ``exp_paths`` names a study directory under ``base_dir``.
Every listed study is loaded from disk and appended to the journal with
``strict_reproducibility=False``.
"""

import os
from pathlib import Path

from agentlab.experiments.study import Study

# Root directory that contains one sub-directory per study.
base_dir = "/home/toolkit/ui_copilot_results"

# Study directory names, relative to ``base_dir``.
exp_paths = [
    "2025-01-31_22-08-34_genericagent-o3-mini-2025-01-31-on-workarena-l1",
    # NOTE(review): the entry below is commented out and therefore skipped;
    # the reason is not recorded here.
    # '2025-02-02_01-53-45_genericagent-openai-o1-mini-2024-09-12-on-workarena-l1',
    "2025-02-02_01-55-04_genericagent-openai-o1-mini-2024-09-12-on-workarena-l1",
]

# Resolve every study name against the results root (kept as plain strings).
full_paths = []
for exp_path in exp_paths:
    full_paths.append(os.path.join(base_dir, exp_path))

# Load each study and record it in the journal.
for full_path in full_paths:
    study_dir = Path(full_path)
    study = Study.load(study_dir)
    study.append_to_journal(strict_reproducibility=False)

reproducibility_journal.csv

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,4 +64,12 @@ ThibaultLSDC,GenericAgent-gpt-4o-mini_vision,visualwebarena,0.13.3,2024-12-02_02
6464
ThibaultLSDC,GenericAgent-gpt-4o_vision,visualwebarena,0.13.3,2024-12-02_07-17-28,7fb7eac8-4bbd-4ebe-be32-15901a7678f2,0.267,0.015,65,910/910,None,Linux (#68-Ubuntu SMP Mon Oct 7 14:34:20 UTC 2024),3.12.7,1.39.0,0.3.1,df7bc706f3793f47a456d1bda0485b306b8cf612,,0.13.3,None,
6565
ThibaultLSDC,GenericAgent-anthropic_claude-3.5-sonnet:beta_vision,visualwebarena,0.13.3,2024-12-02_09-11-35,22f0611d-aeea-4ee9-a533-b45442b5e080,0.21,0.013,178,910/910,None,Linux (#68-Ubuntu SMP Mon Oct 7 14:34:20 UTC 2024),3.12.7,1.39.0,0.3.1,df7bc706f3793f47a456d1bda0485b306b8cf612,,0.13.3,None,
6666
ThibaultLSDC,GenericAgent-meta-llama_llama-3.1-70b-instruct,webarena,0.13.3,2024-12-02_23-18-38,fc5747bc-d998-4942-a0eb-e55a3ccc1cb3,0.184,0.014,213,811/812,None,Linux (#68-Ubuntu SMP Mon Oct 7 14:34:20 UTC 2024),3.12.7,1.39.0,0.3.1,df7bc706f3793f47a456d1bda0485b306b8cf612,,0.13.3,None,
67-
67+
Leo Boisvert,GenericAgent-o3-mini-2025-01-31,workarena_l1,0.4.1,2025-01-31_22-08-33,a74cc00f-f743-43a1-9cab-59af8bffa3a2,0.482,0.028,3,330/330,None,Linux (#68-Ubuntu SMP Mon Oct 7 14:34:20 UTC 2024),3.12.3,1.44.0,v0.3.2,73baabee6d7ac37a5b8677c80baf83914a4f4dc4," M: src/agentlab/agents/generic_agent/__init__.py
68+
M: src/agentlab/agents/generic_agent/agent_configs.py
69+
M: src/agentlab/analyze/agent_xray.py
70+
M: src/agentlab/llm/chat_api.py
71+
M: src/agentlab/llm/llm_configs.py",0.13.3,1d2d7160e5b7ec9954ecb48988f71eb56288dd29,"
72+
Leo Boisvert,GenericAgent-openai_o1-mini-2024-09-12,workarena_l1,0.4.1,2025-02-02_01-55-04,f3e1fcb8-5fc5-4115-9e00-27251508e2c7,0.518,0.028,5,330/330,None,Linux (#68-Ubuntu SMP Mon Oct 7 14:34:20 UTC 2024),3.12.3,1.44.0,v0.3.2,73baabee6d7ac37a5b8677c80baf83914a4f4dc4," M: src/agentlab/agents/generic_agent/__init__.py
73+
M: src/agentlab/agents/generic_agent/agent_configs.py
74+
M: src/agentlab/analyze/agent_xray.py
75+
M: src/agentlab/llm/llm_configs.py",0.13.3,1d2d7160e5b7ec9954ecb48988f71eb56288dd29,"

src/agentlab/agents/generic_agent/__init__.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,15 +17,17 @@
1717
AGENT_4o_MINI,
1818
AGENT_CLAUDE_SONNET_35,
1919
AGENT_4o_VISION,
20-
AGENT_4o_MINI_VISION,
21-
AGENT_CLAUDE_SONNET_35_VISION,
20+
AGENT_o3_MINI,
21+
AGENT_o1_MINI,
2222
)
2323

2424
__all__ = [
2525
"AGENT_3_5",
2626
"AGENT_4o",
2727
"AGENT_4o_MINI",
2828
"AGENT_4o_VISION",
29+
"AGENT_o3_MINI",
30+
"AGENT_o1_MINI",
2931
"AGENT_LLAMA3_70B",
3032
"AGENT_LLAMA31_70B",
3133
"AGENT_8B",

src/agentlab/agents/generic_agent/agent_configs.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,15 @@
265265
flags=FLAGS_GPT_4o,
266266
)
267267

268+
AGENT_o3_MINI = GenericAgentArgs(
269+
chat_model_args=CHAT_MODEL_ARGS_DICT["openai/o3-mini-2025-01-31"],
270+
flags=FLAGS_GPT_4o,
271+
)
272+
273+
AGENT_o1_MINI = GenericAgentArgs(
274+
chat_model_args=CHAT_MODEL_ARGS_DICT["openrouter/openai/o1-mini-2024-09-12"],
275+
flags=FLAGS_GPT_4o,
276+
)
268277
# GPT-4o vision default config
269278
FLAGS_GPT_4o_VISION = FLAGS_GPT_4o.copy()
270279
FLAGS_GPT_4o_VISION.obs.use_screenshot = True

src/agentlab/llm/chat_api.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,7 @@ def make_model(self):
145145
temperature=self.temperature,
146146
max_new_tokens=self.max_new_tokens,
147147
n_retry_server=self.n_retry_server,
148+
log_probs=self.log_probs
148149
)
149150
else:
150151
raise ValueError(f"Backend {self.backend} is not supported")
@@ -237,7 +238,7 @@ def __init__(
237238
self.max_tokens = max_tokens
238239
self.max_retry = max_retry
239240
self.min_retry_wait_time = min_retry_wait_time
240-
self.logprobs = log_probs
241+
self.log_probs = log_probs
241242

242243
# Get the API key from the environment variable if not provided
243244
if api_key_env_var:
@@ -284,7 +285,7 @@ def __call__(self, messages: list[dict], n_samples: int = 1, temperature: float
284285
n=n_samples,
285286
temperature=temperature,
286287
max_tokens=self.max_tokens,
287-
logprobs=self.logprobs,
288+
log_probs=self.log_probs,
288289
)
289290

290291
if completion.usage is None:
@@ -315,8 +316,8 @@ def __call__(self, messages: list[dict], n_samples: int = 1, temperature: float
315316

316317
if n_samples == 1:
317318
res = AIMessage(completion.choices[0].message.content)
318-
if self.logprobs:
319-
res["logprobs"] = completion.choices[0].logprobs
319+
if self.log_probs:
320+
res["log_probs"] = completion.choices[0].log_probs
320321
return res
321322
else:
322323
return [AIMessage(c.message.content) for c in completion.choices]
@@ -429,7 +430,7 @@ def __init__(
429430
n_retry_server: Optional[int] = 4,
430431
log_probs: Optional[bool] = False,
431432
):
432-
super().__init__(model_name, base_model_name, n_retry_server)
433+
super().__init__(model_name, base_model_name, n_retry_server, log_probs)
433434
if temperature < 1e-3:
434435
logging.warning("Models might behave weirdly when temperature is too low.")
435436
self.temperature = temperature

src/agentlab/llm/huggingface_utils.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,11 @@
22
import time
33
from typing import Any, List, Optional, Union
44

5-
from pydantic import Field
6-
from transformers import AutoTokenizer, GPT2TokenizerFast
7-
85
from agentlab.llm.base_api import AbstractChatModel
96
from agentlab.llm.llm_utils import AIMessage, Discussion
107
from agentlab.llm.prompt_templates import PromptTemplate, get_prompt_template
8+
from pydantic import Field
9+
from transformers import AutoTokenizer, GPT2TokenizerFast
1110

1211

1312
class HFBaseChatModel(AbstractChatModel):
@@ -40,9 +39,10 @@ class HFBaseChatModel(AbstractChatModel):
4039
description="The number of times to retry the server if it fails to respond",
4140
)
4241

43-
def __init__(self, model_name, base_model_name, n_retry_server):
42+
def __init__(self, model_name, base_model_name, n_retry_server, log_probs):
4443
super().__init__()
4544
self.n_retry_server = n_retry_server
45+
self.log_probs = log_probs
4646

4747
if base_model_name is None:
4848
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
@@ -102,8 +102,9 @@ def __call__(
102102
temperature = temperature if temperature is not None else self.temperature
103103
answer = self.llm(prompt, temperature=temperature)
104104
response = AIMessage(answer)
105-
if hasattr(answer, "details"):
106-
response["log_prob"] = answer.details.log_prob
105+
if self.log_probs:
106+
response["content"] = answer.generated_text
107+
response["log_prob"] = answer.details
107108
responses.append(response)
108109
break
109110
except Exception as e:

src/agentlab/llm/llm_configs.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,13 @@
6363
max_input_tokens=16_384,
6464
max_new_tokens=4096,
6565
),
66+
"openai/o1-mini": OpenAIModelArgs(
67+
model_name="openai/o1-mini",
68+
max_total_tokens=128_000,
69+
max_input_tokens=128_000,
70+
max_new_tokens=64_000,
71+
temperature=1e-1,
72+
),
6673
"azure/gpt-35-turbo/gpt-35-turbo": AzureModelArgs(
6774
model_name="gpt-35-turbo",
6875
deployment_name="gpt-35-turbo",

0 commit comments

Comments
 (0)