Skip to content

Commit 462038e

Browse files
committed
Add a max observation size limit and a function that prepares per-turn (input, output) data pairs for RL training
1 parent 1befd83 commit 462038e

File tree

1 file changed

+14
-4
lines changed

1 file changed

+14
-4
lines changed

src/agentlab/agents/react_toolcall_agent.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ class AgentConfig:
4444
use_axtree: bool = False
4545
use_screenshot: bool = True
4646
max_actions: int = 10
47+
max_obs_chars: int = 100000 # truncate long observations to N chars
4748
max_history_tokens: int = 120000
4849
system_prompt: str = """
4950
You are an expert AI Agent trained to assist users with complex web tasks.
@@ -113,7 +114,7 @@ def obs_to_messages(self, obs: dict) -> list[dict]:
113114
goal = goal_obj[0]["text"]
114115
messages.append(user_message(f"Goal: {goal}"))
115116

116-
text = "\n\n".join([f"## {k}\n{v}" for k, v in texts.items()])
117+
text = "\n\n".join([f"## {k}\n{v}" for k, v in texts.items()])[:self.config.max_obs_chars]
117118
if self.last_tool_call_id:
118119
message = {
119120
"role": "tool",
@@ -182,6 +183,7 @@ def action_from_message(self, message: Message) -> ToolCall:
182183
logger.warning("Multiple tool calls found in LLM response, using the first one.")
183184
tool_call = message.tool_calls[0]
184185
name = tool_call.function.name
186+
assert name, "Tool call must have a name."
185187
args = json.loads(tool_call.function.arguments)
186188
action = ToolCall(id=tool_call.id, name=name, arguments=args)
187189
self.last_tool_call_id = action.id
@@ -213,7 +215,7 @@ def compact_history(self):
213215
]
214216

215217
try:
216-
response = self.llm(messages=messages, tool_choice="none")
218+
response = self.llm(messages=messages)
217219
summary = response.choices[0].message.content # type: ignore
218220
except Exception as e:
219221
logger.exception(f"Error compacting history: {e}")
@@ -224,11 +226,19 @@ def compact_history(self):
224226
summary_message = {"role": "user", "content": f"## Previous Interaction :\n{summary}"}
225227
self.history = [system_msg, summary_message, *rest[midpoint:]]
226228

229+
def get_training_pairs(self) -> list[tuple[list[dict | Message], Message]]:
    """Build (context, target) pairs from the agent history.

    Walks ``self.history`` in order. Every entry that is a ``Message``
    instance becomes a target, paired with the entries that came before it.
    (Presumably ``Message`` entries are LLM responses and plain dicts are
    prompts/observations -- confirm against where history is populated.)

    Returns:
        A list of ``(context, target)`` tuples. ``context`` is an
        independent list holding the history entries that precede
        ``target``; it never contains ``target`` itself or later entries.
    """
    input_output_pairs = []
    prev_history = []
    for msg in self.history:
        if isinstance(msg, Message):
            # Snapshot the prefix. Storing ``prev_history`` itself would make
            # every pair alias one growing list, so after the loop each
            # context would hold the full history, targets included.
            input_output_pairs.append((list(prev_history), msg))
        prev_history.append(msg)
    return input_output_pairs
227237

228238
@dataclass
229239
class ReactToolCallAgentArgs(AgentArgs):
230-
llm_args: LLMArgs | None = None
231-
config: AgentConfig | None = None
240+
llm_args: LLMArgs = None # type: ignore
241+
config: AgentConfig = None # type: ignore
232242

233243
def make_agent(self, actions: list[ToolSpec]) -> ReactToolCallAgent:
234244
llm = self.llm_args.make_model()

0 commit comments

Comments
 (0)