2323from browsergym .experiments .utils import count_messages_token , count_tokens
2424from dataclasses_json import DataClassJsonMixin
2525from PIL import Image
26- from tapeagents .core import StepMetadata , Tape
26+ from tapeagents .core import Step , StepMetadata , Tape
2727from tapeagents .dialog_tape import AssistantStep , AssistantThought
28+ from tapeagents .io import save_json_tape , save_tape_images
2829from tqdm import tqdm
2930
3031from agentlab .agents .tapeagent .agent import DictObservation , TapeAgent
@@ -312,8 +313,9 @@ def run(self):
312313 err_msg = f"Exception uncaught by agent or environment in task { self .env_args .task_name } .\n { type (e ).__name__ } :\n { e } "
313314 logger .info ("Saving experiment info." )
314315 _save_summary_info (episode_info , self .exp_dir , err_msg , stack_trace )
315- tape = agent .final_tape if isinstance (agent , TapeAgent ) else as_tape (episode_info )
316- self .save_tape (tape )
316+ if isinstance (agent , TapeAgent ):
317+ save_json_tape (agent .final_tape , self .exp_dir , "tape.json" )
318+ save_tape_images (agent .final_tape , self .exp_dir / "tape_attachments" )
317319 except Exception as e :
318320 logger .exception (f"Error while saving experiment info: { e } " )
319321 try :
@@ -326,13 +328,6 @@ def run(self):
326328 except Exception as e :
327329 logger .exception (f"Error while unsetting the logger: { e } " )
328330
329- def save_tape (self , tape : Tape , filename : str = "tape.json" ):
330- tape_path = Path (self .exp_dir ) / filename
331- if tape_path .exists ():
332- raise FileExistsError (f"{ tape_path } already exists" )
333- with open (tape_path , "w" ) as f :
334- json .dump (tape .model_dump (), f , indent = 2 , ensure_ascii = False )
335-
336331 def _set_logger (self ):
337332 # output logging traces to a log file
338333 file_handler = logging .FileHandler (self .exp_dir / "experiment.log" )
@@ -934,23 +929,28 @@ def as_tape(steps_info: list[StepInfo]) -> Tape:
934929 Returns:
935930 Tape: a Tape object containing the steps and metadata.
936931 """
937- tape : Tape = []
932+ steps : list [ Step ] = []
938933 for step_info in steps_info :
939- step_metadata = StepMetadata (
940- other = dict (
941- reward = step_info .reward ,
942- raw_reward = step_info .raw_reward ,
943- terminated = step_info .terminated ,
944- truncated = step_info .truncated ,
945- agent_info = step_info .agent_info ,
946- stats = step_info .stats ,
947- )
948- )
949934 if step_info .obs is not None :
950- steps = [DictObservation (content = step_info .obs )]
935+ try :
936+ obs_json = json .dumps (step_info .obs , cls = DataclassJSONEncoder )
937+ except Exception as e :
938+ logger .warning (f"Error while converting observation to JSON: { e } " )
939+ logger .warning (f"Observation: { step_info .obs } " )
940+ raise e
941+ steps .append (DictObservation (content = obs_json ))
951942 if thought := step_info .agent_info .get ("think" ):
952943 steps .append (AssistantThought (content = thought ))
953944 if step_info .action is not None :
945+ step_metadata = StepMetadata (
946+ other = dict (
947+ reward = step_info .reward ,
948+ raw_reward = step_info .raw_reward ,
949+ terminated = step_info .terminated ,
950+ truncated = step_info .truncated ,
951+ agent_info = step_info .agent_info ,
952+ stats = step_info .stats ,
953+ )
954+ )
954955 steps .append (AssistantStep (content = step_info .action , metadata = step_metadata ))
955- tape += steps
956- return tape
956+ return Tape (steps = steps )
0 commit comments