
Commit 456679d

feature(xjy): add save_replay and collect_episode_data option in jericho (#333)
* add save complete episode option and collect episode data option in jericho
* commit only the changes to the jericho_env file
* modify to standard format and comments

---------

Co-authored-by: Listerrrr <[email protected]>
1 parent 370506f commit 456679d

File tree

1 file changed: +135 −16 lines

zoo/jericho/envs/jericho_env.py

Lines changed: 135 additions & 16 deletions
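Before the diff itself, the four new configuration keys in one place. The key names and defaults below are taken from the diff; the non-default values are only illustrative:

    save_replay=True                # save the full episode interaction log as a JSON file
    save_replay_path=None           # output directory; falls back to './log' when None
    env_type='zork1'                # game name, used in the output filename
    collect_policy_mode='random'    # policy used by collect_episode_data: "human", "random" or "expert"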
@@ -1,4 +1,7 @@
 import copy
+import os
+import json
+from datetime import datetime
 from typing import Any, Dict, List, Optional, Union
 
 import gym
@@ -28,7 +31,10 @@ class JerichoEnv(BaseEnv):
         - remove_stuck_actions (:obj:`bool`): Whether to remove actions that do not change the observation.
         - add_location_and_inventory (:obj:`bool`): Whether to include player location and inventory in the observation.
         - for_unizero (:obj:`bool`): If True, specify additional keys for unizero compatibility.
-
+        - save_replay (:obj:`bool`): If True, the interaction log of the entire episode will be saved.
+        - save_replay_path (:obj:`str`): Path where interaction logs are saved.
+        - collect_policy_mode (:obj:`str`): The strategy pattern used in data collection in the collect_episode_data method, including "human", "random" and "expert".
+        - env_type (:obj:`str`): Type of environment.
     Attributes:
         - tokenizer (Optional[AutoTokenizer]): The tokenizer loaded from the pretrained model.
     """
@@ -43,6 +49,10 @@ class JerichoEnv(BaseEnv):
         'remove_stuck_actions': False,
         'add_location_and_inventory': False,
         'for_unizero': False,
+        'save_replay': False,
+        'save_replay_path': None,
+        'env_type': "zork1",
+        'collect_policy_mode': "random"
     }
 
     def __init__(self, cfg: Dict[str, Any]) -> None:
@@ -61,6 +71,10 @@ def __init__(self, cfg: Dict[str, Any]) -> None:
         self.game_path: str = self.cfg['game_path']
         self.max_action_num: int = self.cfg['max_action_num']
         self.max_seq_len: int = self.cfg['max_seq_len']
+        self.save_replay: bool = self.cfg['save_replay']
+        self.save_replay_path: str = self.cfg['save_replay_path']
+        self.collect_policy_mode: str = self.cfg['collect_policy_mode']
+        self.env_type: str = self.cfg['env_type']
 
         # Record the last observation and action for detecting stuck actions.
         self.last_observation: Optional[str] = None
@@ -75,7 +89,6 @@ def __init__(self, cfg: Dict[str, Any]) -> None:
         self.remove_stuck_actions: bool = self.cfg['remove_stuck_actions']
         self.add_location_and_inventory: bool = self.cfg['add_location_and_inventory']
         self.for_unizero: bool = self.cfg['for_unizero']
-
         # Initialize the tokenizer once (only in rank 0 process if distributed)
         if JerichoEnv.tokenizer is None:
             if self.rank == 0:
@@ -94,6 +107,9 @@ def __init__(self, cfg: Dict[str, Any]) -> None:
         self.episode_return: float = 0.0
         self.env_step: int = 0
         self.timestep: int = 0
+        self.episode_history: Optional[List[Dict[str, Any]]] = None
+        self.walkthrough_actions: Optional[List[str]] = None
+
 
         # Define observation, action, and reward spaces.
         self.observation_space: gym.spaces.Dict = gym.spaces.Dict()
@@ -183,6 +199,8 @@ def reset(self, return_str: bool = False) -> Dict[str, Any]:
         self.episode_return = 0.0
         self.env_step = 0
         self.timestep = 0
+        self.episode_history = []
+        self.walkthrough_actions = self._env.get_walkthrough()
 
         if self.remove_stuck_actions:
             self.last_observation = initial_observation
@@ -192,7 +210,17 @@ def reset(self, return_str: bool = False) -> Dict[str, Any]:
         self.world_size = get_world_size()
         self.rank = get_rank()
 
-        return self.prepare_obs(initial_observation, return_str)
+        processed_obs = self.prepare_obs(initial_observation, return_str)
+
+        self.episode_history.append({
+            'timestep': 0,
+            'obs': processed_obs['observation'],
+            'act': None,
+            'done': False,
+            'info': info
+        })
+
+        return processed_obs
 
     def seed(self, seed: int, dynamic_seed: bool = True) -> None:
         """
@@ -287,12 +315,25 @@ def step(self, action: Union[int, np.ndarray, str], return_str: bool = False) ->
         if self.env_step >= self.max_steps:
             done = True
 
+        if self.save_replay:
+            self.episode_history.append({
+                'timestep': self.timestep,
+                'obs': processed_obs['observation'],
+                'act': action_str,
+                'reward': reward.item() if isinstance(reward, np.ndarray) else reward,
+                'done': done,
+                'info': info
+            })
+
         if done:
             print('=' * 20)
             print(f'rank {self.rank} one episode done!')
             self.finished = True
             info['eval_episode_return'] = self.episode_return
 
+            if self.save_replay:
+                self.save_episode_data()
+
         return BaseEnvTimestep(processed_obs, reward, done, info)
 
     @staticmethod
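For orientation, each record appended to self.episode_history above is a small dict that later lands in the JSON replay. A representative, made-up step entry (with return_str=True, so 'obs' is the raw text) might look like:

    {'timestep': 3, 'obs': 'You are in the living room...', 'act': 'open trap door', 'reward': 0, 'done': False, 'info': {}}

The record written in reset() has the same shape but uses 'act': None and carries no 'reward' key.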
@@ -330,7 +371,91 @@ def create_evaluator_env_cfg(cfg: Dict[str, Any]) -> List[Dict[str, Any]]:
         cfg['is_collect'] = False
         return [cfg for _ in range(evaluator_env_num)]
 
+    def save_episode_data(self):
+        """
+        Overview:
+            Save the full episode interaction history (self.episode_history) to a JSON file.
+        """
+        if self.save_replay_path is None:
+            self.save_replay_path = './log'
+        os.makedirs(self.save_replay_path, exist_ok=True)
+
+        timestamp = datetime.now().strftime("%m%d_%H%M")
+        filename = os.path.join(self.save_replay_path, f"episode_record_{self.env_type}_{self.collect_policy_mode}_{timestamp}.json")
+
+        info = self.episode_history[-1]['info']
+        if 'eval_episode_return' in info and isinstance(info['eval_episode_return'], np.ndarray):
+            info['eval_episode_return'] = info['eval_episode_return'].item()
+
+        with open(filename, mode="w", encoding="utf-8") as f:
+            json.dump(self.episode_history, f, ensure_ascii=False)
+
+    def human_step(self, observation: str) -> int:
+        """
+        Overview:
+            Interactively receive an action from a human player via command line input.
+
+        Arguments:
+            - observation (:obj:`str`): The current observation shown to the human.
+
+        Returns:
+            - (:obj:`int`): The action index input by the user, converted to int.
+        """
+        print(f"[OBS]\n{observation}")
+        while True:
+            try:
+                action_id = int(input('Please input the action id (the id starts from zero): '))
+                return action_id
+            except ValueError:
+                print("Invalid input. Please enter an integer action id.")
+
+    def random_step(self) -> str:
+        """
+        Overview:
+            Randomly select a valid action from the current valid action list.
+
+        Returns:
+            - (:obj:`str`): A randomly selected action string from the available actions. If no actions are available, returns 'go' as a fallback.
+        """
+        if self._action_list is not None and len(self._action_list) > 0:
+            return np.random.choice(self._action_list)
+        else:
+            print(
+                f"rank {self.rank}, available actions list empty. Using default action 'go'."
+            )
+            return 'go'
 
+    def collect_episode_data(self):
+        """
+        Overview:
+            Run a single episode using the specified policy mode, and store the trajectory in self.episode_history.
+        """
+
+        obs = self.reset(return_str=True)
+
+        done = False
+        expert_step_count = 0
+
+        while not done:
+            if self.collect_policy_mode == 'human':
+                action = self.human_step(obs['observation'])
+            elif self.collect_policy_mode == 'random':
+                action = self.random_step()
+            elif self.collect_policy_mode == 'expert':
+                action = self.walkthrough_actions[expert_step_count]
+                expert_step_count += 1
+            else:
+                raise ValueError(f"Invalid collect_policy_mode: {self.collect_policy_mode}")
+
+            obs, reward, done, info = self.step(action, return_str=True)
+
+            if self.collect_policy_mode == 'expert' and expert_step_count >= len(self.walkthrough_actions):
+                done = True
+
+            if done:
+                info['eval_episode_return'] = self.episode_return
+                break
+
 if __name__ == '__main__':
     from easydict import EasyDict
 
@@ -347,19 +472,13 @@ def create_evaluator_env_cfg(cfg: Dict[str, Any]) -> List[Dict[str, Any]]:
             for_unizero=False,
             collector_env_num=1,
             evaluator_env_num=1,
+            save_replay=True,
+            save_replay_path=None,
+            env_type='zork1',  # zork1, acorncourt, detective, omniquest
+            collect_policy_mode='expert'  # random, human, expert
         )
     )
     env = JerichoEnv(env_cfg)
-    obs = env.reset(return_str=True)
-    print(f'[OBS]:\n{obs["observation"]}')
-    while True:
-        try:
-            action_id = int(input('Please input the action id: '))
-        except ValueError:
-            print("Invalid input. Please enter an integer action id.")
-            continue
-        obs, reward, done, info = env.step(action_id, return_str=True)
-        print(f'[OBS]:\n{obs["observation"]}')
-        if done:
-            user_choice = input('Would you like to RESTART, RESTORE a saved game, give the FULL score for that game or QUIT? ')
-            break
+    # Collect data for an episode according to collect_policy_mode
+    env.collect_episode_data()
+    del env
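Since save_episode_data dumps self.episode_history as a plain JSON list, a saved replay can be read back with the standard library alone. A minimal reader-side sketch, not part of this commit; the actual filename is timestamped (episode_record_<env_type>_<collect_policy_mode>_<MMDD_HHMM>.json under save_replay_path, which defaults to ./log):

    import json

    # Illustrative path; substitute the file actually written by save_episode_data.
    with open('./log/episode_record_zork1_expert_0101_1200.json', encoding='utf-8') as f:
        episode = json.load(f)

    # Records mirror the dicts built in reset()/step(): 'timestep', 'obs', 'act',
    # 'done', 'info', plus 'reward' for step records.
    for record in episode:
        print(record['timestep'], record['act'], record.get('reward'))

Running this module directly as a script with the config above should collect one walkthrough ("expert") episode and write such a record under ./log, assuming the game file and tokenizer referenced elsewhere in env_cfg are available locally.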
