import os
import argparse
from typing import Any, Dict

from easydict import EasyDict


def main(env_id: str = 'detective.z5', seed: int = 0, max_env_step: int = int(1e5)) -> None:
    """
    Main entry point for setting up environment configurations and launching training.

    Args:
        env_id (str): Identifier of the environment, e.g., 'detective.z5'.
        seed (int): Random seed used for reproducibility.
        max_env_step (int): Maximum number of environment steps to collect before training stops.

    Returns:
        None
    """
    gpu_num: int = 4
    collector_env_num: int = 4  # Number of collector environments
    n_episode = int(collector_env_num * gpu_num)
    batch_size = int(64 * gpu_num)
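    # With gpu_num = 4, this yields a global n_episode of 16 (4 collector envs x 4 GPUs)
    # and a global batch size of 256 (64 x 4 GPUs).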

    # ------------------------------------------------------------------
    # Base environment parameters (note: these values may need adjusting for other env_ids)
    # ------------------------------------------------------------------
    # Define environment configurations
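    # Each entry maps env_id -> (action_space_size, max_steps).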
    env_configurations = {
        'detective.z5': (10, 50),
        'omniquest.z5': (10, 100),
        'acorncourt.z5': (10, 50),
        'zork1.z5': (10, 400),
    }

    # env_id = 'detective.z5'
    # env_id = 'omniquest.z5'
    # env_id = 'acorncourt.z5'
    # env_id = 'zork1.z5'

    # Set action_space_size and max_steps based on env_id
    action_space_size, max_steps = env_configurations.get(env_id, (10, 50))  # Default values if env_id is not found

    # ------------------------------------------------------------------
    # Frequently modified configurations
    # ------------------------------------------------------------------
    evaluator_env_num: int = 2  # Number of evaluator environments
    num_simulations: int = 50  # Number of MCTS simulations per search

    # Training parameters
    num_unroll_steps: int = 10  # Number of unroll steps (length of the training rollout sequence)
    infer_context_length: int = 4  # Inference context length
    num_layers: int = 2  # Number of transformer layers in the world model
    replay_ratio: float = 0.25  # Replay ratio for experience replay
    embed_dim: int = 768  # Embedding dimension

    # Reanalysis (reanalyze) parameters:
    # buffer_reanalyze_freq: frequency of reanalysis (e.g., 1 means reanalyze once per training epoch)
    buffer_reanalyze_freq: float = 1 / 100000
    # reanalyze_batch_size: number of sequences to reanalyze per reanalysis run
    reanalyze_batch_size: int = 160
    # reanalyze_partition: fraction of the replay buffer used during reanalysis
    reanalyze_partition: float = 0.75
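    # Given the '1 = once per epoch' convention above, 1/100000 makes buffer reanalysis
    # extremely infrequent (effectively disabled for typical run lengths).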

    # Model name or path - configurable according to the predefined model paths or names
    model_name: str = 'BAAI/bge-base-en-v1.5'

    # ------------------------------------------------------------------
    # TODO: Debug configuration - override some parameters for debugging purposes
    # ------------------------------------------------------------------
    # max_env_step = int(2e5)
    # batch_size = 10
    # num_simulations = 2
    # num_unroll_steps = 5
    # infer_context_length = 2
    # max_steps = 10
    # num_layers = 1
    # replay_ratio = 0.05

    # ------------------------------------------------------------------
    # Configuration dictionary for the Jericho UniZero environment and policy
    # ------------------------------------------------------------------
    jericho_unizero_config: Dict[str, Any] = dict(
        env=dict(
            stop_value=int(1e6),
            observation_shape=512,
            max_steps=max_steps,
            max_action_num=action_space_size,
            tokenizer_path=model_name,
            max_seq_len=512,
            game_path=f"./zoo/jericho/envs/z-machine-games-master/jericho-game-suite/{env_id}",
            for_unizero=True,
            collector_env_num=collector_env_num,
            evaluator_env_num=evaluator_env_num,
            n_evaluator_episode=evaluator_env_num,
            manager=dict(shared_memory=False),
        ),
        policy=dict(
            multi_gpu=True,  # Important for distributed data parallel (DDP)
            use_wandb=False,
            learn=dict(
                learner=dict(
                    hook=dict(
                        save_ckpt_after_iter=1000000,
                    ),
                ),
            ),
            accumulation_steps=1,  # TODO: Accumulated gradient steps (currently the default)
            model=dict(
                observation_shape=512,
                action_space_size=action_space_size,
                encoder_url=model_name,
                model_type="mlp",
                continuous_action_space=False,
                world_model_cfg=dict(
                    policy_entropy_weight=5e-2,
                    continuous_action_space=False,
                    max_blocks=num_unroll_steps,
                    # Note: each timestep contains 2 tokens: observation and action.
                    max_tokens=2 * num_unroll_steps,
                    context_length=2 * infer_context_length,
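                    # With the defaults above, max_tokens = 20 (10 timesteps) and the
                    # inference context covers 8 tokens (4 timesteps).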
                    device="cuda",
                    action_space_size=action_space_size,
                    num_layers=num_layers,
                    num_heads=24,
                    embed_dim=embed_dim,
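                    # embed_dim = 768 with num_heads = 24 gives a per-head dimension of 32
                    # (assuming the usual even split of the embedding across heads).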
                    obs_type="text",  # TODO: Modify as needed.
                    env_num=max(collector_env_num, evaluator_env_num),
                ),
            ),
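            # With the detective.z5 defaults (4 collector envs, max_steps=50, replay_ratio=0.25),
            # this works out to 4 * 50 * 0.25 = 50 gradient updates per collection cycle.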
            update_per_collect=int(collector_env_num * max_steps * replay_ratio),  # Important for DDP
            action_type="varied_action_space",
            model_path=None,
            num_unroll_steps=num_unroll_steps,
            reanalyze_ratio=0,
            replay_ratio=replay_ratio,
            batch_size=batch_size,
            learning_rate=0.0001,
            cos_lr_scheduler=True,
            fixed_temperature_value=0.25,
            manual_temperature_decay=False,
            num_simulations=num_simulations,
            n_episode=n_episode,
            train_start_after_envsteps=0,  # TODO: Adjust training start trigger if needed.
            replay_buffer_size=int(5e5),
            eval_freq=int(1e4),
            collector_env_num=collector_env_num,
            evaluator_env_num=evaluator_env_num,
            # Reanalysis key parameters:
            buffer_reanalyze_freq=buffer_reanalyze_freq,
            reanalyze_batch_size=reanalyze_batch_size,
            reanalyze_partition=reanalyze_partition,
        ),
    )
    jericho_unizero_config = EasyDict(jericho_unizero_config)

    # ------------------------------------------------------------------
    # Create configuration for importing environment and policy modules
    # ------------------------------------------------------------------
    jericho_unizero_create_config: Dict[str, Any] = dict(
        env=dict(
            type="jericho",
            import_names=["zoo.jericho.envs.jericho_env"],
        ),
        # Use the base env manager to avoid bugs present in the subprocess env manager.
        env_manager=dict(type="base"),
        # If necessary, switch to the subprocess env manager by uncommenting the following line:
        # env_manager=dict(type="subprocess"),
        policy=dict(
            type="unizero",
            import_names=["lzero.policy.unizero"],
        ),
    )
    jericho_unizero_create_config = EasyDict(jericho_unizero_create_config)

    # ------------------------------------------------------------------
    # Combine configuration dictionaries and construct an experiment name
    # ------------------------------------------------------------------
    main_config: EasyDict = jericho_unizero_config
    create_config: EasyDict = jericho_unizero_create_config

    from ding.utils import DDPContext
    from lzero.config.utils import lz_to_ddp_config
    with DDPContext():
        main_config = lz_to_ddp_config(main_config)
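        # Note: lz_to_ddp_config is assumed to rescale batch-related fields (e.g., batch_size,
        # n_episode) for the number of DDP ranks; check its implementation if these values matter.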
        # Construct an experiment name containing the key parameters
        main_config.exp_name = (
            f"data_lz/data_unizero_jericho/bge-base-en-v1.5/uz_ddp-{gpu_num}gpu_cen{collector_env_num}_rr{replay_ratio}_ftemp025_{env_id[:8]}_ms{max_steps}_ass-{action_space_size}_"
            f"nlayer{num_layers}_embed{embed_dim}_Htrain{num_unroll_steps}-"
            f"Hinfer{infer_context_length}_bs{batch_size}_seed{seed}"
        )
        from lzero.entry import train_unizero
        # Launch the training process
        train_unizero(
            [main_config, create_config],
            seed=seed,
            model_path=main_config.policy.model_path,
            max_env_step=max_env_step,
        )


if __name__ == "__main__":
    """
    Overview:
        This script should be executed with <nproc_per_node> GPUs.
        Run the following command to launch the script:
            torchrun --nproc_per_node=4 ./zoo/jericho/configs/jericho_unizero_ddp_config.py
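        Extra command-line arguments are forwarded to this script by torchrun, e.g.:
            torchrun --nproc_per_node=4 ./zoo/jericho/configs/jericho_unizero_ddp_config.py --env zork1.z5 --seed 0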
    """

    parser = argparse.ArgumentParser(description='Process environment configuration and launch training.')
    parser.add_argument(
        '--env',
        type=str,
        help='Identifier of the environment, e.g., detective.z5 or zork1.z5',
        default='detective.z5'
    )
    parser.add_argument(
        '--seed',
        type=int,
        help='Random seed for reproducibility',
        default=0
    )
    args = parser.parse_args()

    # Disable tokenizer parallelism to prevent multi-process conflicts
    os.environ['TOKENIZERS_PARALLELISM'] = 'false'

    # Start the main process with the provided arguments
    main(args.env, args.seed)