from lzero.worker import MuZeroEvaluator as Evaluator
from lzero.worker import MuZeroCollector as Collector
from .utils import random_collect, calculate_update_per_collect
+import torch.distributed as dist
+from ding.utils import set_pkg_seed, get_rank, get_world_size


def train_unizero(
@@ -33,167 +35,201 @@ def train_unizero(
) -> 'Policy':
    """
    Overview:
-        The train entry for UniZero, proposed in our paper UniZero: Generalized and Efficient Planning with Scalable Latent World Models.
+        This function serves as the training entry point for UniZero, as proposed in our paper "UniZero: Generalized and Efficient Planning with Scalable Latent World Models".
        UniZero aims to enhance the planning capabilities of reinforcement learning agents by addressing the limitations found in MuZero-style algorithms,
-        particularly in environments requiring the capture of long-term dependencies. More details can be found in https://arxiv.org/abs/2406.10667.
+        particularly in environments that require capturing long-term dependencies. More details can be found in https://arxiv.org/abs/2406.10667.
+
    Arguments:
-        - input_cfg (:obj:`Tuple[dict, dict]`): Config in dict type.
-            ``Tuple[dict, dict]`` type means [user_config, create_cfg].
-        - seed (:obj:`int`): Random seed.
-        - model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module.
-        - model_path (:obj:`Optional[str]`): The pretrained model path, which should
-            point to the ckpt file of the pretrained model, and an absolute path is recommended.
-            In LightZero, the path is usually something like ``exp_name/ckpt/ckpt_best.pth.tar``.
-        - max_train_iter (:obj:`Optional[int]`): Maximum policy update iterations in training.
-        - max_env_step (:obj:`Optional[int]`): Maximum collected environment interaction steps.
+        - input_cfg (:obj:`Tuple[dict, dict]`): Configuration in dictionary format.
+            ``Tuple[dict, dict]`` indicates [user_config, create_cfg].
+        - seed (:obj:`int`): Random seed for reproducibility.
+        - model (:obj:`Optional[torch.nn.Module]`): Instance of a PyTorch model.
+        - model_path (:obj:`Optional[str]`): Path to the pretrained model, which should
+            point to the checkpoint file of the pretrained model. An absolute path is recommended.
+            In LightZero, the path typically resembles ``exp_name/ckpt/ckpt_best.pth.tar``.
+        - max_train_iter (:obj:`Optional[int]`): Maximum number of policy update iterations during training.
+        - max_env_step (:obj:`Optional[int]`): Maximum number of environment interaction steps to collect.
+
    Returns:
-        - policy (:obj:`Policy`): Converged policy.
+        - policy (:obj:`Policy`): The converged policy after training.
    """

    cfg, create_cfg = input_cfg

    # Ensure the specified policy type is supported
-    assert create_cfg.policy.type in ['unizero', 'sampled_unizero'], "train_unizero entry now only supports the following algo.: 'unizero', 'sampled_unizero'"
+    assert create_cfg.policy.type in ['unizero', 'sampled_unizero'], "train_unizero only supports the following algorithms: 'unizero', 'sampled_unizero'"
+    logging.info(f"Using policy type: {create_cfg.policy.type}")

-    # Import the correct GameBuffer class based on the policy type
+    # Import the appropriate GameBuffer class based on the policy type
    game_buffer_classes = {'unizero': 'UniZeroGameBuffer', 'sampled_unizero': 'SampledUniZeroGameBuffer'}
-
    GameBuffer = getattr(__import__('lzero.mcts', fromlist=[game_buffer_classes[create_cfg.policy.type]]),
                         game_buffer_classes[create_cfg.policy.type])

-    # Set device based on CUDA availability
+    # Check for GPU availability and set the device accordingly
    cfg.policy.device = cfg.policy.model.world_model_cfg.device if torch.cuda.is_available() else 'cpu'
-    logging.info(f'cfg.policy.device: {cfg.policy.device}')
+    logging.info(f"Device set to: {cfg.policy.device}")

-    # Compile the configuration
+    # Compile the configuration file
    cfg = compile_config(cfg, seed=seed, env=None, auto=True, create_cfg=create_cfg, save_cfg=True)

-    # Create main components: env, policy
+    # Create environment manager
    env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env)
    collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg])
    evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg])

+    # Initialize environment and random seed
    collector_env.seed(cfg.seed)
    evaluator_env.seed(cfg.seed, dynamic_seed=False)
    set_pkg_seed(cfg.seed, use_cuda=torch.cuda.is_available())

+    # Initialize wandb if specified
    if cfg.policy.use_wandb:
-        # Initialize wandb
+        logging.info("Initializing wandb...")
        wandb.init(
            project="LightZero",
            config=cfg,
            sync_tensorboard=False,
            monitor_gym=False,
            save_code=True,
        )
+        logging.info("wandb initialization completed!")

+    # Create policy
+    logging.info("Creating policy...")
    policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval'])
+    logging.info("Policy created successfully!")

    # Load pretrained model if specified
    if model_path is not None:
-        logging.info(f'Loading model from {model_path} begin...')
+        logging.info(f"Loading pretrained model from {model_path}...")
        policy.learn_mode.load_state_dict(torch.load(model_path, map_location=cfg.policy.device))
-        logging.info(f'Loading model from {model_path} end!')
+        logging.info("Pretrained model loaded successfully!")

-    # Create worker components: learner, collector, evaluator, replay buffer, commander
+    # Create core components for training
    tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial')) if get_rank() == 0 else None
    learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
-
-    # MCTS+RL algorithms related core code
-    policy_config = cfg.policy
-    replay_buffer = GameBuffer(policy_config)
+    replay_buffer = GameBuffer(cfg.policy)
    collector = Collector(env=collector_env, policy=policy.collect_mode, tb_logger=tb_logger, exp_name=cfg.exp_name,
-                          policy_config=policy_config)
+                          policy_config=cfg.policy)
    evaluator = Evaluator(eval_freq=cfg.policy.eval_freq, n_evaluator_episode=cfg.env.n_evaluator_episode,
                          stop_value=cfg.env.stop_value, env=evaluator_env, policy=policy.eval_mode,
-                          tb_logger=tb_logger, exp_name=cfg.exp_name, policy_config=policy_config)
+                          tb_logger=tb_logger, exp_name=cfg.exp_name, policy_config=cfg.policy)

-    # Learner's before_run hook
+    # Execute the learner's before_run hook
    learner.call_hook('before_run')
-    if policy_config.use_wandb:
+
+    if cfg.policy.use_wandb:
        policy.set_train_iter_env_step(learner.train_iter, collector.envstep)

-    # Collect random data before training
+    # Randomly collect data if specified
    if cfg.policy.random_collect_episode_num > 0:
+        logging.info("Collecting random data...")
        random_collect(cfg.policy, policy, LightZeroRandomPolicy, collector, collector_env, replay_buffer)
+        logging.info("Random data collection completed!")

    batch_size = policy._cfg.batch_size

+    if cfg.policy.multi_gpu:
+        # Get current world size and rank
+        world_size = get_world_size()
+        rank = get_rank()
+    else:
+        world_size = 1
+        rank = 0
+
    while True:
-        # Log buffer memory usage
+        # Log memory usage of the replay buffer
        log_buffer_memory_usage(learner.train_iter, replay_buffer, tb_logger)

-        # Set temperature for visit count distributions
+        # Set temperature parameter for data collection
        collect_kwargs = {
            'temperature': visit_count_temperature(
-                policy_config.manual_temperature_decay,
-                policy_config.fixed_temperature_value,
-                policy_config.threshold_training_steps_for_final_temperature,
+                cfg.policy.manual_temperature_decay,
+                cfg.policy.fixed_temperature_value,
+                cfg.policy.threshold_training_steps_for_final_temperature,
                trained_steps=learner.train_iter
            ),
            'epsilon': 0.0  # Default epsilon value
        }

-        # Configure epsilon for epsilon-greedy exploration
-        if policy_config.eps.eps_greedy_exploration_in_collect:
+        # Configure epsilon-greedy exploration
+        if cfg.policy.eps.eps_greedy_exploration_in_collect:
            epsilon_greedy_fn = get_epsilon_greedy_fn(
-                start=policy_config.eps.start,
-                end=policy_config.eps.end,
-                decay=policy_config.eps.decay,
-                type_=policy_config.eps.type
+                start=cfg.policy.eps.start,
+                end=cfg.policy.eps.end,
+                decay=cfg.policy.eps.decay,
+                type_=cfg.policy.eps.type
            )
            collect_kwargs['epsilon'] = epsilon_greedy_fn(collector.envstep)

        # Evaluate policy performance
        if evaluator.should_eval(learner.train_iter):
+            logging.info(f"Training iteration {learner.train_iter}: Starting evaluation...")
            stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
+            logging.info(f"Training iteration {learner.train_iter}: Evaluation completed, stop condition: {stop}, current reward: {reward}")
            if stop:
+                logging.info("Stopping condition met, training ends!")
                break

        # Collect new data
        new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs)
+        logging.info(f"Rank {rank}, Training iteration {learner.train_iter}: New data collection completed!")

        # Determine updates per collection
-        update_per_collect = calculate_update_per_collect(cfg, new_data)
+        update_per_collect = cfg.policy.update_per_collect
+        if update_per_collect is None:
+            update_per_collect = calculate_update_per_collect(cfg, new_data, world_size)

        # Update replay buffer
        replay_buffer.push_game_segments(new_data)
        replay_buffer.remove_oldest_data_to_fit()

-        # Train the policy if sufficient data is available
+        if world_size > 1:
+            # Synchronize all ranks before training
+            try:
+                dist.barrier()
+            except Exception as e:
+                logging.error(f'Rank {rank}: Synchronization barrier failed, error: {e}')
+                break
+
+        # Check if there is sufficient data for training
        if collector.envstep > cfg.policy.train_start_after_envsteps:
            if cfg.policy.sample_type == 'episode':
                data_sufficient = replay_buffer.get_num_of_game_segments() > batch_size
            else:
                data_sufficient = replay_buffer.get_num_of_transitions() > batch_size
+
            if not data_sufficient:
                logging.warning(
-                    f'The data in replay_buffer is not sufficient to sample a mini-batch: '
+                    f'Rank {rank}: The data in replay_buffer is not sufficient to sample a mini-batch: '
                    f'batch_size: {batch_size}, replay_buffer: {replay_buffer}. Continue to collect now ....'
                )
                continue

+            # Execute multiple training rounds
            for i in range(update_per_collect):
                train_data = replay_buffer.sample(batch_size, policy)
                if cfg.policy.reanalyze_ratio > 0 and i % 20 == 0:
-                    # Clear caches and precompute positional embedding matrices
-                    policy.recompute_pos_emb_diff_and_clear_cache()  # TODO
-
-                if policy_config.use_wandb:
+                    policy.recompute_pos_emb_diff_and_clear_cache()
+
+                if cfg.policy.use_wandb:
                    policy.set_train_iter_env_step(learner.train_iter, collector.envstep)

-                train_data.append({'train_which_component': 'transformer'})
-                log_vars = learner.train(train_data, collector.envstep)
+                train_data.append(learner.train_iter)

+                log_vars = learner.train(train_data, collector.envstep)
                if cfg.policy.use_priority:
                    replay_buffer.update_priority(train_data, log_vars[0]['value_priority_orig'])

                policy.recompute_pos_emb_diff_and_clear_cache()

        # Check stopping criteria
        if collector.envstep >= max_env_step or learner.train_iter >= max_train_iter:
+            logging.info("Stopping condition met, training ends!")
            break

    learner.call_hook('after_run')
-    wandb.finish()
-    return policy
+    if cfg.policy.use_wandb:
+        wandb.finish()
+    logging.info("===== Training Completed =====")
+    return policy
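
For orientation, below is a minimal usage sketch showing how this entry point is typically invoked from a LightZero config script. The config module path and the multi-GPU launch pattern are illustrative assumptions, not part of this commit; any config file that defines a [main_config, create_config] pair is used the same way.

# Usage sketch (assumption: the example config module path is illustrative).
from lzero.entry import train_unizero
# Hypothetical example config providing the [user_config, create_cfg] pair.
from zoo.atari.config.atari_unizero_config import main_config, create_config

if __name__ == "__main__":
    # Single-process run: collect, evaluate, and train until a stop condition is met.
    policy = train_unizero([main_config, create_config], seed=0, max_env_step=int(1e6))

    # For the multi-GPU path added in this diff (cfg.policy.multi_gpu = True), the same call
    # is typically wrapped in DI-engine's DDPContext so that get_rank()/get_world_size()
    # and dist.barrier() operate on an initialized process group, e.g.:
    #     from ding.utils import DDPContext
    #     with DDPContext():
    #         train_unizero([main_config, create_config], seed=0, max_env_step=int(1e6))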