Skip to content

Commit b54baa6

Browse files
committed
Merge remote-tracking branch 'origin/corellm_evals' into wa_verified
2 parents 121453f + 2a025fe commit b54baa6

File tree

8 files changed

+872
-20
lines changed

8 files changed

+872
-20
lines changed

src/agentlab/agents/dynamic_prompting.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -573,9 +573,9 @@ class SystemPrompt(PromptElement):
573573
class ActionPrompt(PromptElement):
574574

575575
_concrete_ex = """
576-
<action>
576+
[BEGIN FINAL RESPONSE]
577577
click('a324')
578-
</action>
578+
[END FINAL RESPONSE]
579579
"""
580580

581581
def __init__(self, action_set: AbstractActionSet, action_flags: ActionFlags) -> None:
@@ -596,9 +596,9 @@ def __init__(self, action_set: AbstractActionSet, action_flags: ActionFlags) ->
596596
f"# Action space:\n{action_set_generic_info}{action_description}{MacNote().prompt}\n"
597597
)
598598
self._abstract_ex = f"""
599-
<action>
599+
[BEGIN FINAL RESPONSE]
600600
{self.action_set.example_action(abstract=True)}
601-
</action>
601+
[END FINAL RESPONSE]
602602
"""
603603

604604
# self._concrete_ex = f"""
@@ -789,7 +789,7 @@ def _prompt(self) -> str:
789789
prompt += f"\n<think>\n{self.thought}\n</think>\n"
790790

791791
if self.flags.use_action_history:
792-
prompt += f"\n<action>\n{self.action}\n</action>\n"
792+
prompt += f"\n[BEGIN FINAL RESPONSE]\n{self.action}\n[END FINAL RESPONSE]\n"
793793

794794
# prompt += f"{self.error.prompt}{self.html_diff.prompt}{self.ax_tree_diff.prompt}"
795795
prompt += f"{self.error.prompt}"

src/agentlab/agents/generic_agent/generic_agent_prompt.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -153,12 +153,12 @@ def shrink(self):
153153
self.history.shrink()
154154
self.obs.shrink()
155155

156-
def _parse_answer(self, text_answer):
156+
def _parse_answer(self, text_think, text_answer):
157157
ans_dict = {}
158-
ans_dict.update(self.think.parse_answer(text_answer))
159-
ans_dict.update(self.plan.parse_answer(text_answer))
160-
ans_dict.update(self.memory.parse_answer(text_answer))
161-
ans_dict.update(self.criticise.parse_answer(text_answer))
158+
ans_dict.update(self.think.parse_answer(text_think))
159+
ans_dict.update(self.plan.parse_answer(text_think))
160+
ans_dict.update(self.memory.parse_answer(text_think))
161+
ans_dict.update(self.criticise.parse_answer(text_think))
162162
ans_dict.update(self.action_prompt.parse_answer(text_answer))
163163
return ans_dict
164164

src/agentlab/llm/chat_api.py

Lines changed: 127 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -324,12 +324,86 @@ def __call__(self, messages: list[dict], n_samples: int = 1, temperature: float
324324
tracking.TRACKER.instance(input_tokens, output_tokens, cost)
325325

326326
if n_samples == 1:
327-
res = AIMessage(completion.choices[0].message.content)
327+
think, action = self._extract_thinking_content_from_response(completion)
328+
res_think = AIMessage(think or "")
329+
res_action = AIMessage(action or "")
328330
if self.log_probs:
329-
res["log_probs"] = completion.choices[0].log_probs
330-
return res
331+
res_think["log_probs"] = completion.choices[0].logprobs
332+
return res_think, res_action
331333
else:
332-
return [AIMessage(c.message.content) for c in completion.choices]
334+
return [
335+
self._build_think_action_pair(choice)
336+
for choice in completion.choices
337+
]
338+
339+
def _extract_thinking_content_from_response(self, response, wrap_tag="think") -> tuple[str, str]:
340+
"""Extract reasoning and action content from an API response.
341+
342+
Handles multiple formats:
343+
1. OpenAI/DeepSeek: reasoning in 'reasoning_content' or 'reasoning' field
344+
2. Apriel: reasoning before [BEGIN FINAL RESPONSE]...[END FINAL RESPONSE] tags
345+
3. Standard: content as-is
346+
347+
Args:
348+
response: The API response object.
349+
wrap_tag: Tag name to wrap reasoning content (default: "think").
350+
351+
Returns:
352+
tuple: (reasoning_wrapped, action_wrapped)
353+
"""
354+
message = response.choices[0].message
355+
msg_dict = message.to_dict() if hasattr(message, 'to_dict') else dict(message)
356+
357+
reasoning = msg_dict.get("reasoning_content") or msg_dict.get("reasoning")
358+
content = msg_dict.get("content", "") or msg_dict.get("text", "")
359+
360+
# Case 1: Explicit reasoning field from API
361+
if reasoning:
362+
reasoning_wrapped = f"<{wrap_tag}>{reasoning}</{wrap_tag}>\n"
363+
if "[BEGIN FINAL RESPONSE]" in content and "[END FINAL RESPONSE]" in content:
364+
action = self._extract_last_action_from_tags(content)
365+
action_wrapped = f"<action>\n{action}\n</action>"
366+
else:
367+
action_wrapped = content
368+
return reasoning_wrapped, action_wrapped
369+
370+
# Case 2: Apriel-style format in content
371+
if "[BEGIN FINAL RESPONSE]" in content:
372+
reasoning_text, action_text = self._parse_apriel_format(content)
373+
reasoning_wrapped = f"<{wrap_tag}>\n{reasoning_text}\n</{wrap_tag}>" if reasoning_text else ""
374+
action_wrapped = f"<action>\n{action_text}\n</action>" if action_text else ""
375+
return reasoning_wrapped, action_wrapped
376+
377+
# Case 3: No special format
378+
return "", content
379+
380+
def _extract_last_action_from_tags(self, content: str) -> str:
381+
"""Extract content from the LAST [BEGIN FINAL RESPONSE]...[END FINAL RESPONSE] block."""
382+
pattern = r'\[BEGIN FINAL RESPONSE\](.*?)\[END FINAL RESPONSE\]'
383+
matches = re.findall(pattern, content, re.DOTALL)
384+
return matches[-1].strip() if matches else ""
385+
386+
def _parse_apriel_format(self, content: str) -> tuple[str, str]:
387+
"""Parse Apriel format: reasoning before [BEGIN FINAL RESPONSE] tags."""
388+
last_begin = content.rfind("[BEGIN FINAL RESPONSE]")
389+
if last_begin == -1:
390+
return "", content
391+
392+
reasoning = content[:last_begin].strip()
393+
if reasoning.startswith("Here are my reasoning steps:"):
394+
reasoning = reasoning[len("Here are my reasoning steps:"):].strip()
395+
396+
action = self._extract_last_action_from_tags(content)
397+
return reasoning, action
398+
399+
def _build_think_action_pair(self, choice) -> tuple[AIMessage, AIMessage]:
    """Build (think, action) pair from a single choice."""

    # The extraction helper expects a response-like object exposing
    # ``.choices``; wrap the lone choice in a throwaway container.
    class _SingleChoiceResponse:
        pass

    wrapper = _SingleChoiceResponse()
    wrapper.choices = [choice]

    think, action = self._extract_thinking_content_from_response(wrapper)
    return AIMessage(think or ""), AIMessage(action or "")
333407

334408
def get_stats(self):
335409
return {
@@ -484,6 +558,55 @@ def __init__(
484558
)
485559

486560

561+
class AprielChatModel(ChatModel):
    """Chat model for Apriel models hosted on DGX Cloud.

    Endpoint and credentials fall back to the APRIEL_API_URL and
    APRIEL_API_KEY environment variables when not passed explicitly.
    """

    def __init__(
        self,
        model_name="Slam-15B",
        api_key=None,
        base_url=None,
        temperature=0.5,
        max_tokens=15000,
        max_retry=4,
        min_retry_wait_time=60,
    ):
        # NOTE(review): when APRIEL_API_URL is unset this defaults to "",
        # which will presumably fail at request time — confirm whether a
        # real default endpoint should be supplied here.
        base_url = base_url or os.getenv(
            "APRIEL_API_URL",
            ""
        )
        api_key = api_key or os.getenv("APRIEL_API_KEY")

        super().__init__(
            model_name=model_name,
            api_key=api_key,
            temperature=temperature,
            max_tokens=max_tokens,
            max_retry=max_retry,
            min_retry_wait_time=min_retry_wait_time,
            client_class=OpenAI,  # OpenAI-compatible client pointed at base_url
            client_args={"base_url": base_url},
            pricing_func=None,  # no per-token cost tracking for this endpoint
        )
591+
592+
593+
@dataclass
class AprielModelArgs(BaseModelArgs):
    """Serializable args for Apriel models."""

    # Endpoint and credentials; when None, AprielChatModel falls back to the
    # APRIEL_API_URL / APRIEL_API_KEY environment variables.
    base_url: str | None = None
    api_key: str | None = None

    def make_model(self):
        """Instantiate an AprielChatModel from these serialized args."""
        return AprielChatModel(
            model_name=self.model_name,
            base_url=self.base_url,
            api_key=self.api_key,
            temperature=self.temperature,
            max_tokens=self.max_new_tokens,
        )
608+
609+
487610
class AnthropicChatModel(AbstractChatModel):
488611
def __init__(
489612
self,

src/agentlab/llm/llm_configs.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
OpenAIModelArgs,
88
OpenRouterModelArgs,
99
SelfHostedModelArgs,
10+
AprielModelArgs
1011
)
1112

1213
default_oss_llms_args = {
@@ -375,4 +376,13 @@
375376
max_new_tokens=4_000,
376377
temperature=1e-1,
377378
),
379+
380+
"apriel/slam-15b": AprielModelArgs(
381+
model_name="openai/Slam-15B",
382+
base_url="",
383+
api_key="",
384+
max_total_tokens=40_000,
385+
max_new_tokens=15_000,
386+
temperature=0.6,
387+
),
378388
}

src/agentlab/llm/llm_utils.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -83,17 +83,19 @@ def retry(
8383
"""
8484
tries = 0
8585
while tries < n_retry:
86-
answer = chat(messages)
86+
think, action = chat(messages)
87+
think_content, action_content = think["content"], action["content"]
88+
8789
# TODO: could we change this to not use inplace modifications ?
88-
messages.append(answer)
90+
messages.append({"role": "assistant", "content": think_content + action_content})
91+
8992
try:
90-
return parser(answer["content"])
93+
return parser(think_content, action_content)
9194
except ParseError as parsing_error:
9295
tries += 1
9396
if log:
94-
msg = f"Query failed. Retrying {tries}/{n_retry}.\n[LLM]:\n{answer['content']}\n[User]:\n{str(parsing_error)}"
95-
logging.info(msg)
96-
messages.append(dict(role="user", content=str(parsing_error)))
97+
logging.info(f"Query failed. Retrying {tries}/{n_retry}.\n[LLM]:\n{action_content}\n[User]:\n{parsing_error}")
98+
messages.append({"role": "user", "content": str(parsing_error)})
9799

98100
raise ParseError(f"Could not parse a valid value after {n_retry} retries.")
99101

src/agentlab/llm/logging_config.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
import logging
import sys
from functools import lru_cache


@lru_cache(maxsize=None)
def setup_logging(level=logging.INFO):
    """Configure root logging once and cache the result.

    Using lru_cache ensures this only runs once per process (per distinct
    ``level``), even if imported and called multiple times.

    Args:
        level: Logging level for the root logger (default: logging.INFO).

    Returns:
        The configured root logger.
    """
    root = logging.getLogger()

    # Remove any existing handlers to avoid duplicates. Iterate over a COPY:
    # removing from the live list while iterating it skips every other
    # handler, leaving stale handlers attached.
    for handler in list(root.handlers):
        root.removeHandler(handler)

    # Single stdout handler with a compact format.
    formatter = logging.Formatter("%(name)s - %(levelname)s - %(message)s")
    console_handler = logging.StreamHandler(sys.stdout)
    console_handler.setFormatter(formatter)

    # Set up root logger
    root.addHandler(console_handler)
    root.setLevel(level)

    return root


# Call it once when module is imported
logger = setup_logging()

0 commit comments

Comments
 (0)