Skip to content

Commit 3ab66f2

Browse files
add ability to save with same format as agentlab-xray
1 parent 200a8bd commit 3ab66f2

File tree

2 files changed

+100
-41
lines changed

2 files changed

+100
-41
lines changed

src/agentlab/analyze/agent_controller.py

Lines changed: 87 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,19 @@
44
import json
55
import logging
66
import os
7+
import pickle
78
from collections import Counter
89
from datetime import datetime
910
from io import BytesIO
11+
from pathlib import Path
1012

1113
import numpy as np
1214
import PIL.Image
1315
import requests
1416
import streamlit as st
1517
from agentlab.agents.generic_agent import __all__ as ALL_AGENTS
1618
from agentlab.experiments.exp_utils import RESULTS_DIR
19+
from agentlab.experiments.loop import ExpArgs, StepInfo, save_package_versions
1720
from agentlab.llm.llm_utils import Discussion
1821
from bgym import DEFAULT_BENCHMARKS
1922
from dotenv import load_dotenv
@@ -133,6 +136,12 @@ def reset_env_history():
133136
st.session_state.screenshot_history = []
134137
st.session_state.axtree_history = []
135138

139+
# related to env info
140+
st.session_state.reward_history = []
141+
st.session_state.terminated_history = []
142+
st.session_state.truncated_history = []
143+
st.session_state.env_info_history = []
144+
136145

137146
def reset_agent_history():
138147
logger.info("Resetting agent history")
@@ -150,13 +159,19 @@ def reset_agent_state():
150159
st.session_state.agent.reset()
151160

152161

153-
def step_env_history(obs):
162+
def step_env_history(obs, response_json):
154163
logger.info("Stepping env history")
155164
st.session_state.last_obs = copy.deepcopy(obs)
156165
st.session_state.obs_history.append(obs)
157166
st.session_state.screenshot_history.append(obs[Constants.SCREENSHOT])
158167
st.session_state.axtree_history.append(obs[Constants.AXTREE_TXT])
159168

169+
# other relevant info found in response_json
170+
st.session_state.reward_history.append(response_json["reward"])
171+
st.session_state.terminated_history.append(response_json["terminated"])
172+
st.session_state.truncated_history.append(response_json["truncated"])
173+
st.session_state.env_info_history.append(response_json["info"])
174+
160175

161176
def step_agent_history(action, action_info):
162177
logger.info("Stepping agent history")
@@ -185,6 +200,12 @@ def revert_env_history():
185200
st.session_state.screenshot_history.pop()
186201
st.session_state.axtree_history.pop()
187202

203+
# related to env info
204+
st.session_state.reward_history.pop()
205+
st.session_state.terminated_history.pop()
206+
st.session_state.truncated_history.pop()
207+
st.session_state.env_info_history.pop()
208+
188209

189210
def revert_agent_history():
190211
logger.info("Reverting agent history")
@@ -209,6 +230,12 @@ def restore_env_history(step: int):
209230
st.session_state.screenshot_history = copy.deepcopy(st.session_state.screenshot_history[:step])
210231
st.session_state.axtree_history = copy.deepcopy(st.session_state.axtree_history[:step])
211232

233+
# related to env info
234+
st.session_state.reward_history = copy.deepcopy(st.session_state.reward_history[:step])
235+
st.session_state.terminated_history = copy.deepcopy(st.session_state.terminated_history[:step])
236+
st.session_state.truncated_history = copy.deepcopy(st.session_state.truncated_history[:step])
237+
st.session_state.env_info_history = copy.deepcopy(st.session_state.env_info_history[:step])
238+
212239

213240
def restore_agent_history(step: int):
214241
logger.info(f"Restoring agent history to step {step}")
@@ -262,6 +289,8 @@ def set_session_state():
262289
st.session_state.task = None
263290
if "subtask" not in st.session_state:
264291
st.session_state.subtask = None
292+
if "env_args" not in st.session_state:
293+
st.session_state.env_args = None
265294

266295
# current state
267296
if "agent" not in st.session_state:
@@ -290,6 +319,14 @@ def set_session_state():
290319
st.session_state.action_info_history = None
291320
if "obs_history" not in st.session_state:
292321
st.session_state.obs_history = None
322+
if "reward_history" not in st.session_state:
323+
st.session_state.reward_history = None
324+
if "terminated_history" not in st.session_state:
325+
st.session_state.terminated_history = None
326+
if "truncated_history" not in st.session_state:
327+
st.session_state.truncated_history = None
328+
if "env_info_history" not in st.session_state:
329+
st.session_state.env_info_history = None
293330

294331
if "has_clicked_prev" not in st.session_state:
295332
st.session_state.has_clicked_prev = False
@@ -362,6 +399,13 @@ def set_task_selector():
362399
st.session_state.task = selected_task_str
363400
st.session_state.subtask = selected_subtask_str
364401

402+
st.session_state.env_args = [
403+
elem
404+
for elem in selected_benchmark.env_args_list
405+
if elem.task_name == selected_task_str
406+
and str(elem.task_seed) == str(selected_subtask_str)
407+
][0]
408+
365409
reset_env_history()
366410
reset_agent_history()
367411

@@ -423,11 +467,12 @@ def reset_environment():
423467
logger.error(resp.json()[Constants.STATUS])
424468
logger.error(resp.json()[Constants.MESSAGE])
425469
response_json = resp.json()
470+
print(response_json.keys())
426471
response_json = deserialize_response(response_json)
427472
obs = response_json[Constants.OBS]
428473
if st.session_state.agent.obs_preprocessor:
429474
obs = st.session_state.agent.obs_preprocessor(obs)
430-
step_env_history(obs)
475+
step_env_history(obs, response_json)
431476
st.session_state.action = None
432477
st.session_state.action_info = None
433478

@@ -447,7 +492,7 @@ def reload_task():
447492
obs = response_json[Constants.OBS]
448493
if st.session_state.agent.obs_preprocessor:
449494
obs = st.session_state.agent.obs_preprocessor(obs)
450-
step_env_history(obs)
495+
step_env_history(obs, response_json)
451496
st.session_state.action = None
452497
st.session_state.action_info = None
453498

@@ -468,7 +513,7 @@ def step_environment(action):
468513
obs = response_json[Constants.OBS]
469514
if st.session_state.agent.obs_preprocessor:
470515
obs = st.session_state.agent.obs_preprocessor(obs)
471-
step_env_history(obs)
516+
step_env_history(obs, response_json)
472517
st.session_state.action = None
473518
st.session_state.action_info = None
474519

@@ -880,44 +925,45 @@ def set_save_tab():
880925
save_dir = st.text_input("Save Directory", value="~/Downloads")
881926
save_dir = os.path.expanduser(save_dir)
882927
if st.button("Save Session State for Current Run"):
883-
now_str = datetime.now().strftime("%Y_%m_%d__%H_%M_%S")
884-
filename = f"agentlab_controller_state_{now_str}.json"
885-
886-
# prepare payload for saving
887-
payload = {}
888-
payload["timestamp"] = now_str
889-
payload["benchmark"] = st.session_state.benchmark
890-
payload["task"] = st.session_state.task
891-
payload["subtask"] = st.session_state.subtask
892-
payload["agent_args"] = {
893-
k: v for k, v in vars(st.session_state.agent_args).items() if is_json_serializable(v)
894-
}
895-
payload["agent_flags"] = {
896-
k: v for k, v in vars(st.session_state.agent.flags).items() if is_json_serializable(v)
897-
}
898-
payload["agent_flags"]["obs"] = {
899-
k: v
900-
for k, v in vars(st.session_state.agent.flags.obs).items()
901-
if is_json_serializable(v)
902-
}
903-
payload["agent_flags"]["action"] = {
904-
k: v
905-
for k, v in vars(st.session_state.agent.flags.action).items()
906-
if is_json_serializable(v)
907-
}
908-
payload["goal"] = st.session_state.last_obs["goal"]
909-
payload["steps"] = []
928+
# save everything from the session in a way that is consistent
929+
# with how experiments are saved with AgentLab
930+
931+
# dir name has this format: 2025-07-14_16-46-47_tooluse-gpt-4-1-on-workarena-l1-task-name-sort
932+
exp_dir = (
933+
Path(save_dir)
934+
/ f"{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}_genericagent_{st.session_state.agent_args.agent_name}_on_{st.session_state.benchmark}_{st.session_state.env_args.task_name}_{st.session_state.env_args.task_seed}"
935+
)
936+
exp_dir.mkdir(parents=True, exist_ok=True)
937+
938+
# save package versions
939+
save_package_versions(exp_dir)
940+
941+
# create ExpArgs object
942+
exp_args = ExpArgs(
943+
agent_args=st.session_state.agent_args, env_args=st.session_state.env_args
944+
)
945+
with open(exp_dir / "exp_args.pkl", "wb") as f:
946+
pickle.dump(exp_args, f)
947+
948+
# create StepInfo object for each step
910949
for i in range(len(st.session_state.action_history)):
911-
step = {}
912-
step["action"] = st.session_state.action_history[i]
913-
step["thought"] = st.session_state.thought_history[i]
914-
step["prompt"] = st.session_state.prompt_history[i]
915-
step["screenshot"] = get_base64_serialized_image(st.session_state.screenshot_history[i])
916-
step["axtree"] = st.session_state.axtree_history[i]
917-
payload["steps"].append(step)
918-
919-
with open(os.path.join(save_dir, filename), "w") as f:
920-
json.dump(payload, f)
950+
step_info = StepInfo()
951+
step_info.step = i
952+
step_info.obs = st.session_state.obs_history[i]
953+
step_info.reward = st.session_state.reward_history[i]
954+
step_info.terminated = st.session_state.terminated_history[i]
955+
step_info.truncated = st.session_state.truncated_history[i]
956+
step_info.action = st.session_state.action_history[i]
957+
step_info.agent_info = st.session_state.action_info_history[i]
958+
step_info.make_stats()
959+
# TODO: set profiling stats
960+
step_info.task_info = st.session_state.env_info_history[i].get("task_info", None)
961+
step_info.raw_reward = st.session_state.env_info_history[i].get(
962+
"RAW_REWARD_GLOBAL", None
963+
)
964+
step_info.save_step_info(exp_dir, save_screenshot=True, save_som=True)
965+
966+
st.success(f"Saved session state at {exp_dir}")
921967

922968

923969
def set_info_tabs():

src/agentlab/analyze/server.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,10 @@ def status(self) -> dict:
234234
{
235235
"status": "success",
236236
"message": "Environment status retrieved successfully.",
237+
"obs": self.last_obs,
238+
"reward": 0,
239+
"terminated": False,
240+
"truncated": False,
237241
"info_set": self.info_set,
238242
"env_created": self.env is not None,
239243
}
@@ -318,6 +322,9 @@ def reload_task(self) -> dict:
318322
"status": "success",
319323
"message": "Task reloaded successfully.",
320324
"obs": self.last_obs,
325+
"reward": 0,
326+
"terminated": False,
327+
"truncated": False,
321328
"info": self.last_info,
322329
}
323330
)
@@ -356,6 +363,9 @@ def reset(self) -> dict:
356363
"status": "success",
357364
"message": "Environment reset successfully",
358365
"obs": self.last_obs,
366+
"reward": 0,
367+
"terminated": False,
368+
"truncated": False,
359369
"info": self.last_info,
360370
}
361371
)
@@ -413,6 +423,9 @@ def get_obs(self) -> dict:
413423
"status": "success",
414424
"message": "Observation retrieved successfully.",
415425
"obs": self.last_obs,
426+
"reward": 0,
427+
"terminated": False,
428+
"truncated": False,
416429
"info": self.last_info,
417430
}
418431
)

0 commit comments

Comments
 (0)