2323from browsergym .experiments .utils import count_messages_token , count_tokens
2424from dataclasses_json import DataClassJsonMixin
2525from PIL import Image
26- from tapeagents .core import (
27- StepMetadata ,
28- Tape ,
29- )
26+ from tapeagents .core import StepMetadata , Tape
3027from tapeagents .dialog_tape import AssistantStep , AssistantThought
3128from tqdm import tqdm
3229
@@ -315,9 +312,8 @@ def run(self):
315312 err_msg = f"Exception uncaught by agent or environment in task { self .env_args .task_name } .\n { type (e ).__name__ } :\n { e } "
316313 logger .info ("Saving experiment info." )
317314 _save_summary_info (episode_info , self .exp_dir , err_msg , stack_trace )
318- self .save_tape (
319- agent .final_tape if isinstance (agent , TapeAgent ) else self .as_tape (episode_info )
320- )
315+ tape = agent .final_tape if isinstance (agent , TapeAgent ) else as_tape (episode_info )
316+ self .save_tape (tape )
321317 except Exception as e :
322318 logger .exception (f"Error while saving experiment info: { e } " )
323319 try :
@@ -330,36 +326,11 @@ def run(self):
330326 except Exception as e :
331327 logger .exception (f"Error while unsetting the logger: { e } " )
332328
333- def as_tape (self , steps_info : list ["StepInfo" ]) -> Tape :
334- """
335- Create a Tape object from the steps info.
336-
337- Returns:
338- Tape: a Tape object containing the steps and metadata.
339- """
340- tape : Tape = []
341- for step_info in steps_info :
342- step_metadata = StepMetadata (
343- result = dict (
344- reward = step_info .reward ,
345- raw_reward = step_info .raw_reward ,
346- terminated = step_info .terminated ,
347- truncated = step_info .truncated ,
348- agent_info = step_info .agent_info ,
349- stats = step_info .stats ,
350- )
351- )
352- steps = [DictObservation (content = step_info .obs )]
353- if thought := step_info .agent_info .get ("think" ):
354- steps .append (AssistantThought (content = thought ))
355- steps .append (AssistantStep (content = step_info .action , metadata = step_metadata ))
356- tape += steps
357- return tape
358-
359329 def save_tape (self , tape : Tape , filename : str = "tape.json" ):
360- if os .path .exists (self .exp_dir / filename ):
361- raise FileExistsError (f"{ filename } already exists in { self .exp_dir } " )
362- with open (self .exp_dir / filename , "w" ) as f :
330+ tape_path = Path (self .exp_dir ) / filename
331+ if tape_path .exists ():
332+ raise FileExistsError (f"{ tape_path } already exists" )
333+ with open (tape_path , "w" ) as f :
363334 json .dump (tape .model_dump (), f , indent = 2 , ensure_ascii = False )
364335
365336 def _set_logger (self ):
@@ -951,3 +922,31 @@ def _flatten_dict(d, parent_key="", sep="."):
951922 else :
952923 items .append ((new_key , v ))
953924 return dict (items )
925+
926+
927+ def as_tape (steps_info : list ) -> Tape :
928+ """
929+ Create a Tape object from the steps info.
930+
931+ Returns:
932+ Tape: a Tape object containing the steps and metadata.
933+ """
934+ tape : Tape = []
935+ for step_info in steps_info :
936+ step_metadata = StepMetadata (
937+ other = dict (
938+ reward = step_info .reward ,
939+ raw_reward = step_info .raw_reward ,
940+ terminated = step_info .terminated ,
941+ truncated = step_info .truncated ,
942+ agent_info = step_info .agent_info ,
943+ stats = step_info .stats ,
944+ )
945+ )
946+ steps = [DictObservation (content = step_info .obs )]
947+ if thought := step_info .agent_info .get ("think" ):
948+ steps .append (AssistantThought (content = thought ))
949+ if step_info .action is not None :
950+ steps .append (AssistantStep (content = step_info .action , metadata = step_metadata ))
951+ tape += steps
952+ return tape
0 commit comments