Skip to content

Commit 16af31b

Browse files
Enhance ChatModel to support reasoning extraction in responses
1 parent 5bd690e commit 16af31b

File tree

1 file changed

+18
-2
lines changed

1 file changed

+18
-2
lines changed

src/agentlab/llm/chat_api.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -322,19 +322,35 @@ def __call__(self, messages: list[dict], n_samples: int = 1, temperature: float
322322
tracking.TRACKER.instance(input_tokens, output_tokens, cost)
323323

324324
if n_samples == 1:
325-
res = AIMessage(completion.choices[0].message.content)
325+
res = AIMessage(self.extract_content_with_reasoning(completion.choices[0].message))
326326
if self.log_probs:
327327
res["log_probs"] = completion.choices[0].log_probs
328328
return res
329329
else:
330-
return [AIMessage(c.message.content) for c in completion.choices]
330+
return [
331+
AIMessage(self.extract_content_with_reasoning(c.message))
332+
for c in completion.choices
333+
]
331334

332335
def get_stats(self):
    """Return retry statistics accumulated by this chat model."""
    stats = {"n_retry_llm": self.retries}
    # "busted_retry_llm": int(not self.success), # not logged if it occurs anyways
    return stats
337340

341+
# Support for models that return reasoning.
def extract_content_with_reasoning(self, message, wrap_tag="think"):
    """Extract the text content from a completion message, prepending reasoning.

    Some providers return the model's chain-of-thought in a separate
    ``reasoning`` attribute on the message object. For backward compatibility
    with consumers that expect inline ``<think>...</think>`` blocks, the
    reasoning (if present and non-empty) is wrapped in ``<wrap_tag>`` tags and
    prepended to the message content.

    Args:
        message: A completion message object exposing ``.content`` and,
            optionally, ``.reasoning`` (absent or None when the model does
            not emit reasoning).
        wrap_tag: Tag name used to wrap the reasoning text. Defaults to
            ``"think"``.

    Returns:
        str: The reasoning block (possibly empty) followed by the message
        content. A ``None`` content is treated as an empty string so the
        literal text "None" is never emitted.
    """
    reasoning_content = getattr(message, "reasoning", None)
    if reasoning_content:
        # Wrap reasoning in <think> tags with newlines for clarity
        reasoning_content = f"<{wrap_tag}>\n{reasoning_content}\n</{wrap_tag}>\n"
    else:
        reasoning_content = ""
    # message.content may be None (e.g. reasoning-only or tool-call responses);
    # without the guard the f-string would render the string "None".
    content = message.content or ""
    return f"{reasoning_content}{content}"
353+
338354

339355
class OpenAIChatModel(ChatModel):
340356
def __init__(

0 commit comments

Comments
 (0)