Commit cdde886

feature(pu): add jericho ddp config (#337)
puyuan1996 and puyuan authored

* feature(pu): add jericho ddp config
* polish(pu): polish jericho ddp config
* polish(pu): polish jericho ddp configs

---------

Co-authored-by: puyuan <[email protected]>
1 parent 28d5505 commit cdde886

File tree

4 files changed: +246 -5 lines changed

zoo/jericho/configs/jericho_unizero_config.py

Lines changed: 10 additions & 5 deletions
@@ -5,7 +5,7 @@
 from easydict import EasyDict
 
 
-def main(env_id: str = 'detective.z5', seed: int = 0) -> None:
+def main(env_id: str = 'detective.z5', seed: int = 0, max_env_step: int = int(1e5)) -> None:
     """
     Main entry point for setting up environment configurations and launching training.
 
@@ -40,7 +40,6 @@ def main(env_id: str = 'detective.z5', seed: int = 0) -> None:
     # ------------------------------------------------------------------
     evaluator_env_num: int = 2  # Number of evaluator environments
     num_simulations: int = 50  # Number of simulations
-    max_env_step: int = int(1e6)  # Maximum environment steps
 
     # Project training parameters
     collector_env_num: int = 4  # Number of collector environments
@@ -137,8 +136,8 @@ def main(env_id: str = 'detective.z5', seed: int = 0) -> None:
             batch_size=batch_size,
             learning_rate=0.0001,
             cos_lr_scheduler=True,
-            manual_temperature_decay=True,
-            threshold_training_steps_for_final_temperature=int(2.5e4),
+            fixed_temperature_value=0.25,
+            manual_temperature_decay=False,
             num_simulations=num_simulations,
             n_episode=n_episode,
             train_start_after_envsteps=0,  # TODO: Adjust training start trigger if needed.
@@ -216,4 +215,10 @@ def main(env_id: str = 'detective.z5', seed: int = 0) -> None:
     os.environ['TOKENIZERS_PARALLELISM'] = 'false'
 
     # Start the main process with the provided arguments
-    main(args.env, args.seed)
+    main(args.env, args.seed)
+
+    # ====== the following is only for cprofile ======
+    # def run(max_env_step: int):
+    #     main(args.env, args.seed, max_env_step=max_env_step)
+    # import cProfile
+    # cProfile.run(f"run({10000})", filename="./zoo/jericho/detective_unizero_cprofile_10k_envstep", sort="cumulative")
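Note on this change: the config swaps the manually decayed MCTS sampling temperature (manual_temperature_decay=True with threshold_training_steps_for_final_temperature) for a fixed temperature of 0.25, and max_env_step becomes a parameter of main() instead of a local constant. As a rough, self-contained sketch of what the temperature controls (not LightZero's internal implementation; all names and values below are illustrative only), MCTS visit counts are sharpened by 1/temperature before an action is sampled:

    import numpy as np

    def visit_counts_to_probs(visit_counts: np.ndarray, temperature: float) -> np.ndarray:
        # Raise visit counts to 1/T and normalize: lower T -> sharper, more greedy sampling.
        scaled = visit_counts ** (1.0 / temperature)
        return scaled / scaled.sum()

    counts = np.array([10.0, 30.0, 60.0])
    print(visit_counts_to_probs(counts, temperature=1.0))   # roughly proportional to the counts
    print(visit_counts_to_probs(counts, temperature=0.25))  # close to one-hot on the most-visited action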

zoo/jericho/configs/jericho_unizero_ddp_config.py

Lines changed: 227 additions & 0 deletions
@@ -0,0 +1,227 @@
+import os
+import argparse
+from typing import Any, Dict
+
+from easydict import EasyDict
+
+
+def main(env_id: str = 'detective.z5', seed: int = 0, max_env_step: int = int(1e5)) -> None:
+    """
+    Main entry point for setting up environment configurations and launching training.
+
+    Args:
+        env_id (str): Identifier of the environment, e.g., 'detective.z5'.
+        seed (int): Random seed used for reproducibility.
+
+    Returns:
+        None
+    """
+    gpu_num = 4
+    collector_env_num: int = 4  # Number of collector environments
+    n_episode = int(collector_env_num*gpu_num)
+    batch_size = int(64*gpu_num)
+
+    # ------------------------------------------------------------------
+    # Base environment parameters (Note: these values might be adjusted for different env_id)
+    # ------------------------------------------------------------------
+    # Define environment configurations
+    env_configurations = {
+        'detective.z5': (10, 50),
+        'omniquest.z5': (10, 100),
+        'acorncourt.z5': (10, 50),
+        'zork1.z5': (10, 400),
+    }
+
+    # env_id = 'detective.z5'
+    # env_id = 'omniquest.z5'
+    # env_id = 'acorncourt.z5'
+    # env_id = 'zork1.z5'
+
+    # Set action_space_size and max_steps based on env_id
+    action_space_size, max_steps = env_configurations.get(env_id, (10, 50))  # Default values if env_id not found
+
+    # ------------------------------------------------------------------
+    # User frequently modified configurations
+    # ------------------------------------------------------------------
+    evaluator_env_num: int = 2  # Number of evaluator environments
+    num_simulations: int = 50  # Number of simulations
+
+    # Project training parameters
+    num_unroll_steps: int = 10  # Number of unroll steps (for rollout sequence expansion)
+    infer_context_length: int = 4  # Inference context length
+    num_layers: int = 2  # Number of layers in the model
+    replay_ratio: float = 0.25  # Replay ratio for experience replay
+    embed_dim: int = 768  # Embedding dimension
+
+    # Reanalysis (reanalyze) parameters:
+    # buffer_reanalyze_freq: Frequency of reanalysis (e.g., 1 means reanalyze once per epoch)
+    buffer_reanalyze_freq: float = 1 / 100000
+    # reanalyze_batch_size: Number of sequences to reanalyze per reanalysis process
+    reanalyze_batch_size: int = 160
+    # reanalyze_partition: Partition ratio from the replay buffer to use during reanalysis
+    reanalyze_partition: float = 0.75
+
+    # Model name or path - configurable according to the predefined model paths or names
+    model_name: str = 'BAAI/bge-base-en-v1.5'
+
+    # ------------------------------------------------------------------
+    # TODO: Debug configuration - override some parameters for debugging purposes
+    # ------------------------------------------------------------------
+    # max_env_step = int(2e5)
+    # batch_size = 10
+    # num_simulations = 2
+    # num_unroll_steps = 5
+    # infer_context_length = 2
+    # max_steps = 10
+    # num_layers = 1
+    # replay_ratio = 0.05
+    # ------------------------------------------------------------------
+    # Configuration dictionary for the Jericho Unizero environment and policy
+    # ------------------------------------------------------------------
+    jericho_unizero_config: Dict[str, Any] = dict(
+        env=dict(
+            stop_value=int(1e6),
+            observation_shape=512,
+            max_steps=max_steps,
+            max_action_num=action_space_size,
+            tokenizer_path=model_name,
+            max_seq_len=512,
+            game_path=f"./zoo/jericho/envs/z-machine-games-master/jericho-game-suite/{env_id}",
+            for_unizero=True,
+            collector_env_num=collector_env_num,
+            evaluator_env_num=evaluator_env_num,
+            n_evaluator_episode=evaluator_env_num,
+            manager=dict(shared_memory=False),
+        ),
+        policy=dict(
+            multi_gpu=True,  # Important for distributed data parallel (DDP)
+            use_wandb=False,
+            learn=dict(
+                learner=dict(
+                    hook=dict(
+                        save_ckpt_after_iter=1000000,
+                    ),
+                ),
+            ),
+            accumulation_steps=1,  # TODO: Accumulated gradient steps (currently default)
+            model=dict(
+                observation_shape=512,
+                action_space_size=action_space_size,
+                encoder_url=model_name,
+                model_type="mlp",
+                continuous_action_space=False,
+                world_model_cfg=dict(
+                    policy_entropy_weight=5e-2,
+                    continuous_action_space=False,
+                    max_blocks=num_unroll_steps,
+                    # Note: Each timestep contains 2 tokens: observation and action.
+                    max_tokens=2 * num_unroll_steps,
+                    context_length=2 * infer_context_length,
+                    device="cuda",
+                    action_space_size=action_space_size,
+                    num_layers=num_layers,
+                    num_heads=24,
+                    embed_dim=embed_dim,
+                    obs_type="text",  # TODO: Modify as needed.
+                    env_num=max(collector_env_num, evaluator_env_num),
+                ),
+            ),
+            update_per_collect=int(collector_env_num*max_steps*replay_ratio),  # Important for DDP
+            action_type="varied_action_space",
+            model_path=None,
+            num_unroll_steps=num_unroll_steps,
+            reanalyze_ratio=0,
+            replay_ratio=replay_ratio,
+            batch_size=batch_size,
+            learning_rate=0.0001,
+            cos_lr_scheduler=True,
+            fixed_temperature_value=0.25,
+            manual_temperature_decay=False,
+            num_simulations=num_simulations,
+            n_episode=n_episode,
+            train_start_after_envsteps=0,  # TODO: Adjust training start trigger if needed.
+            replay_buffer_size=int(5e5),
+            eval_freq=int(1e4),
+            collector_env_num=collector_env_num,
+            evaluator_env_num=evaluator_env_num,
+            # Reanalysis key parameters:
+            buffer_reanalyze_freq=buffer_reanalyze_freq,
+            reanalyze_batch_size=reanalyze_batch_size,
+            reanalyze_partition=reanalyze_partition,
+        ),
+    )
+    jericho_unizero_config = EasyDict(jericho_unizero_config)
+
+    # ------------------------------------------------------------------
+    # Create configuration for importing environment and policy modules
+    # ------------------------------------------------------------------
+    jericho_unizero_create_config: Dict[str, Any] = dict(
+        env=dict(
+            type="jericho",
+            import_names=["zoo.jericho.envs.jericho_env"],
+        ),
+        # Use base env manager to avoid bugs present in subprocess env manager.
+        env_manager=dict(type="base"),
+        # If necessary, switch to subprocess env manager by uncommenting the following line:
+        # env_manager=dict(type="subprocess"),
+        policy=dict(
+            type="unizero",
+            import_names=["lzero.policy.unizero"],
+        ),
+    )
+    jericho_unizero_create_config = EasyDict(jericho_unizero_create_config)
+
+    # ------------------------------------------------------------------
+    # Combine configuration dictionaries and construct an experiment name
+    # ------------------------------------------------------------------
+    main_config: EasyDict = jericho_unizero_config
+    create_config: EasyDict = jericho_unizero_create_config
+
+    from ding.utils import DDPContext
+    from lzero.config.utils import lz_to_ddp_config
+    with DDPContext():
+        main_config = lz_to_ddp_config(main_config)
+        # Construct experiment name containing key parameters
+        main_config.exp_name = (
+            f"data_lz/data_unizero_jericho/bge-base-en-v1.5/uz_ddp-{gpu_num}gpu_cen{collector_env_num}_rr{replay_ratio}_ftemp025_{env_id[:8]}_ms{max_steps}_ass-{action_space_size}_"
+            f"nlayer{num_layers}_embed{embed_dim}_Htrain{num_unroll_steps}-"
+            f"Hinfer{infer_context_length}_bs{batch_size}_seed{seed}"
+        )
+        from lzero.entry import train_unizero
+        # Launch the training process
+        train_unizero(
+            [main_config, create_config],
+            seed=seed,
+            model_path=main_config.policy.model_path,
+            max_env_step=max_env_step,
+        )
+
+
+if __name__ == "__main__":
+    """
+    Overview:
+        This script should be executed with <nproc_per_node> GPUs.
+        Run the following command to launch the script:
+            torchrun --nproc_per_node=4 ./zoo/jericho/configs/jericho_unizero_ddp_config.py
+    """
+
+    parser = argparse.ArgumentParser(description='Process environment configuration and launch training.')
+    parser.add_argument(
+        '--env',
+        type=str,
+        help='Identifier of the environment, e.g., detective.z5 or zork1.z5',
+        default='detective.z5'
+    )
+    parser.add_argument(
+        '--seed',
+        type=int,
+        help='Random seed for reproducibility',
+        default=0
+    )
+    args = parser.parse_args()
+
+    # Disable tokenizer parallelism to prevent multi-process conflicts
+    os.environ['TOKENIZERS_PARALLELISM'] = 'false'
+
+    # Start the main process with the provided arguments
+    main(args.env, args.seed)
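A quick way to read the DDP-specific parts of this new config: each of the gpu_num processes runs its own collector_env_num environments, so n_episode and batch_size are scaled by gpu_num, while update_per_collect is derived from the per-process collection volume and the replay ratio. The snippet below only reproduces that arithmetic in plain Python for the default detective.z5 setting (a sanity check with no LightZero imports, not part of the config itself):

    gpu_num = 4
    collector_env_num = 4
    max_steps = 50      # detective.z5 entry in env_configurations
    replay_ratio = 0.25

    n_episode = int(collector_env_num * gpu_num)                            # 16 episodes per collection across all ranks
    batch_size = int(64 * gpu_num)                                          # 256, as set in the config above
    update_per_collect = int(collector_env_num * max_steps * replay_ratio)  # 50 gradient updates per collection
    print(n_episode, batch_size, update_per_collect)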

zoo/jericho/configs/jericho_unizero_segment_config.py

Lines changed: 2 additions & 0 deletions
@@ -110,6 +110,8 @@ def main(env_id: str = 'detective.z5', seed: int = 0) -> None:
             replay_ratio=replay_ratio,
             batch_size=batch_size,
             learning_rate=0.0001,
+            fixed_temperature_value=0.25,
+            manual_temperature_decay=False,
             num_simulations=num_simulations,
             num_segments=num_segments,
             train_start_after_envsteps=0,

zoo/jericho/envs/jericho_env.py

Lines changed: 7 additions & 0 deletions
@@ -1,3 +1,4 @@
+import logging
 import copy
 import os
 import json
@@ -89,6 +90,7 @@ def __init__(self, cfg: Dict[str, Any]) -> None:
         self.remove_stuck_actions: bool = self.cfg['remove_stuck_actions']
         self.add_location_and_inventory: bool = self.cfg['add_location_and_inventory']
         self.for_unizero: bool = self.cfg['for_unizero']
+
         # Initialize the tokenizer once (only in rank 0 process if distributed)
         if JerichoEnv.tokenizer is None:
             if self.rank == 0:
@@ -389,6 +391,11 @@ def save_episode_data(self):
 
         with open(filename, mode="w", encoding="utf-8") as f:
             json.dump(self.episode_history, f, ensure_ascii=False)
+        logging.info(
+            f"Episode data successfully saved to '{filename}'. "
+            f"Episode length: {len(self.episode_history)} interactions, "
+            f"Environment type: {self.env_type}, Policy mode: {self.collect_policy_mode}."
+        )
 
     def human_step(self, observation:str) -> str:
         """
