 import copy
+import os
+import json
+from datetime import datetime
 from typing import Any, Dict, List, Optional, Union

 import gym
@@ -28,7 +31,10 @@ class JerichoEnv(BaseEnv):
         - remove_stuck_actions (:obj:`bool`): Whether to remove actions that do not change the observation.
         - add_location_and_inventory (:obj:`bool`): Whether to include player location and inventory in the observation.
         - for_unizero (:obj:`bool`): If True, specify additional keys for unizero compatibility.
-
+        - save_replay (:obj:`bool`): If True, the interaction log of the entire episode will be saved.
+        - save_replay_path (:obj:`str`): Path where interaction logs are saved.
+        - collect_policy_mode (:obj:`str`): The policy used to generate actions in ``collect_episode_data``; one of "human", "random", or "expert".
+        - env_type (:obj:`str`): Name of the game environment (e.g. "zork1"); used when naming saved replay files.
     Attributes:
         - tokenizer (Optional[AutoTokenizer]): The tokenizer loaded from the pretrained model.
     """
@@ -43,6 +49,10 @@ class JerichoEnv(BaseEnv):
         'remove_stuck_actions': False,
         'add_location_and_inventory': False,
         'for_unizero': False,
+        'save_replay': False,
+        'save_replay_path': None,
+        'env_type': "zork1",
+        'collect_policy_mode': "random"
     }

     def __init__(self, cfg: Dict[str, Any]) -> None:
@@ -61,6 +71,10 @@ def __init__(self, cfg: Dict[str, Any]) -> None:
         self.game_path: str = self.cfg['game_path']
         self.max_action_num: int = self.cfg['max_action_num']
         self.max_seq_len: int = self.cfg['max_seq_len']
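+        # Replay-saving and scripted data-collection settings (used by save_episode_data and collect_episode_data below).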
+        self.save_replay: bool = self.cfg['save_replay']
+        self.save_replay_path: Optional[str] = self.cfg['save_replay_path']
+        self.collect_policy_mode: str = self.cfg['collect_policy_mode']
+        self.env_type: str = self.cfg['env_type']

         # Record the last observation and action for detecting stuck actions.
         self.last_observation: Optional[str] = None
@@ -75,7 +89,6 @@ def __init__(self, cfg: Dict[str, Any]) -> None:
         self.remove_stuck_actions: bool = self.cfg['remove_stuck_actions']
         self.add_location_and_inventory: bool = self.cfg['add_location_and_inventory']
         self.for_unizero: bool = self.cfg['for_unizero']
-
         # Initialize the tokenizer once (only in rank 0 process if distributed)
         if JerichoEnv.tokenizer is None:
             if self.rank == 0:
@@ -94,6 +107,9 @@ def __init__(self, cfg: Dict[str, Any]) -> None:
         self.episode_return: float = 0.0
         self.env_step: int = 0
         self.timestep: int = 0
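+        # Episode interaction history and walkthrough action list; both are (re)initialized in reset().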
+        self.episode_history: Optional[List[Dict[str, Any]]] = None
+        self.walkthrough_actions: Optional[List[str]] = None
+

         # Define observation, action, and reward spaces.
         self.observation_space: gym.spaces.Dict = gym.spaces.Dict()
@@ -183,6 +199,8 @@ def reset(self, return_str: bool = False) -> Dict[str, Any]:
         self.episode_return = 0.0
         self.env_step = 0
         self.timestep = 0
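+        # Start a fresh interaction history and cache the game walkthrough for expert data collection.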
+        self.episode_history = []
+        self.walkthrough_actions = self._env.get_walkthrough()

         if self.remove_stuck_actions:
             self.last_observation = initial_observation
@@ -192,7 +210,17 @@ def reset(self, return_str: bool = False) -> Dict[str, Any]:
         self.world_size = get_world_size()
         self.rank = get_rank()

-        return self.prepare_obs(initial_observation, return_str)
+        processed_obs = self.prepare_obs(initial_observation, return_str)
+
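+        # Record the initial observation as the first entry of the episode history.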
+        self.episode_history.append({
+            'timestep': 0,
+            'obs': processed_obs['observation'],
+            'act': None,
+            'done': False,
+            'info': info
+        })
+
+        return processed_obs

     def seed(self, seed: int, dynamic_seed: bool = True) -> None:
         """
@@ -287,12 +315,25 @@ def step(self, action: Union[int, np.ndarray, str], return_str: bool = False) ->
         if self.env_step >= self.max_steps:
             done = True

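+        # When replay saving is enabled, append this transition to the episode history.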
+        if self.save_replay:
+            self.episode_history.append({
+                'timestep': self.timestep,
+                'obs': processed_obs['observation'],
+                'act': action_str,
+                'reward': reward.item() if isinstance(reward, np.ndarray) else reward,
+                'done': done,
+                'info': info
+            })
+
         if done:
             print('=' * 20)
             print(f'rank {self.rank} one episode done!')
             self.finished = True
             info['eval_episode_return'] = self.episode_return

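+            # Dump the full episode history to disk at the end of the episode.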
+            if self.save_replay:
+                self.save_episode_data()
+
         return BaseEnvTimestep(processed_obs, reward, done, info)

     @staticmethod
@@ -330,7 +371,91 @@ def create_evaluator_env_cfg(cfg: Dict[str, Any]) -> List[Dict[str, Any]]:
         cfg['is_collect'] = False
         return [cfg for _ in range(evaluator_env_num)]

+    def save_episode_data(self) -> None:
+        """
+        Overview:
+            Save the full episode interaction history (``self.episode_history``) to a JSON file.
+        """
+        if self.save_replay_path is None:
+            self.save_replay_path = './log'
+        os.makedirs(self.save_replay_path, exist_ok=True)
+
+        timestamp = datetime.now().strftime("%m%d_%H%M")
+        filename = os.path.join(self.save_replay_path, f"episode_record_{self.env_type}_{self.collect_policy_mode}_{timestamp}.json")
+
+        # Convert a possible numpy scalar in the final info dict so the history is JSON-serializable.
+        info = self.episode_history[-1]['info']
+        if 'eval_episode_return' in info and isinstance(info['eval_episode_return'], np.ndarray):
+            info['eval_episode_return'] = info['eval_episode_return'].item()
+
+        with open(filename, mode="w", encoding="utf-8") as f:
+            json.dump(self.episode_history, f, ensure_ascii=False)
+    def human_step(self, observation: str) -> int:
+        """
+        Overview:
+            Interactively receive an action from a human player via command-line input.
+        Arguments:
+            - observation (:obj:`str`): The current observation shown to the human.
+        Returns:
+            - (:obj:`int`): The action index entered by the user, converted to int.
+        """
+        print(f"[OBS]\n{observation}")
+        while True:
+            try:
+                action_id = int(input('Please input the action id (the id starts from zero): '))
+                return action_id
+            except ValueError:
+                print("Invalid input. Please enter an integer action id.")
+    def random_step(self) -> str:
+        """
+        Overview:
+            Randomly select a valid action from the current valid action list.
+        Returns:
+            - (:obj:`str`): A randomly selected action string from the available actions. If no actions are available, returns 'go' as a fallback.
+        """
+        if self._action_list is not None and len(self._action_list) > 0:
+            return np.random.choice(self._action_list)
+        else:
+            print(f"rank {self.rank}, available actions list empty. Using default action 'go'.")
+            return 'go'

+    def collect_episode_data(self) -> None:
+        """
+        Overview:
+            Run a single episode with the policy specified by ``collect_policy_mode`` ("human", "random" or "expert")
+            and store the trajectory in ``self.episode_history``.
+        """
+        obs = self.reset(return_str=True)
+
+        done = False
+        expert_step_count = 0
+
+        while not done:
+            # Select the next action according to the configured collection policy.
+            if self.collect_policy_mode == 'human':
+                action = self.human_step(obs['observation'])
+            elif self.collect_policy_mode == 'random':
+                action = self.random_step()
+            elif self.collect_policy_mode == 'expert':
+                action = self.walkthrough_actions[expert_step_count]
+                expert_step_count += 1
+            else:
+                raise ValueError(f"Invalid collect_policy_mode: {self.collect_policy_mode}")
+
+            obs, reward, done, info = self.step(action, return_str=True)
+
+            # In expert mode, stop once the walkthrough is exhausted.
+            if self.collect_policy_mode == 'expert' and expert_step_count >= len(self.walkthrough_actions):
+                done = True
+
+            if done:
+                info['eval_episode_return'] = self.episode_return
+                break
+
 if __name__ == '__main__':
     from easydict import EasyDict

@@ -347,19 +472,13 @@ def create_evaluator_env_cfg(cfg: Dict[str, Any]) -> List[Dict[str, Any]]:
             for_unizero=False,
             collector_env_num=1,
             evaluator_env_num=1,
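+            # Options for replay saving and scripted data collection used in this demo run.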
+            save_replay=True,
+            save_replay_path=None,
+            env_type='zork1',  # zork1, acorncourt, detective, omniquest
+            collect_policy_mode='expert'  # random, human, expert
         )
     )
     env = JerichoEnv(env_cfg)
-    obs = env.reset(return_str=True)
-    print(f'[OBS]:\n{obs["observation"]}')
-    while True:
-        try:
-            action_id = int(input('Please input the action id: '))
-        except ValueError:
-            print("Invalid input. Please enter an integer action id.")
-            continue
-        obs, reward, done, info = env.step(action_id, return_str=True)
-        print(f'[OBS]:\n{obs["observation"]}')
-        if done:
-            user_choice = input('Would you like to RESTART, RESTORE a saved game, give the FULL score for that game or QUIT? ')
-            break
+    # Collect data for one episode according to collect_policy_mode.
+    env.collect_episode_data()
+    del env