 import torch
 import torch.nn as nn
 import torch.optim as optim
-import warnings
+from itertools import chain

 from rsl_rl.modules import ActorCritic
 from rsl_rl.modules.rnd import RandomNetworkDistillation
@@ -43,13 +43,19 @@ def __init__(
         rnd_cfg: dict | None = None,
         # Symmetry parameters
         symmetry_cfg: dict | None = None,
+        # Distributed training parameters
+        multi_gpu_cfg: dict | None = None,
     ):
+        # device-related parameters
         self.device = device
-
-        self.desired_kl = desired_kl
-        self.schedule = schedule
-        self.learning_rate = learning_rate
-        self.normalize_advantage_per_mini_batch = normalize_advantage_per_mini_batch
+        self.is_multi_gpu = multi_gpu_cfg is not None
+        # Multi-GPU parameters
+        if multi_gpu_cfg is not None:
+            self.gpu_global_rank = multi_gpu_cfg["global_rank"]
+            self.gpu_world_size = multi_gpu_cfg["world_size"]
+        else:
+            self.gpu_global_rank = 0
+            self.gpu_world_size = 1

         # RND components
         if rnd_cfg is not None:
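
For context (not part of this commit): the new multi_gpu_cfg dictionary only needs the two keys read above, "global_rank" and "world_size". Below is a minimal sketch of how it could be assembled under a torchrun-style launch, where rank and world size come from the standard environment variables; the PPO(...) call at the end is hypothetical, since the actual construction site lives in the runner and is not part of this diff.

import os

import torch.distributed as dist

# Assumption: the training script was launched with torchrun, which sets RANK and WORLD_SIZE.
if int(os.environ.get("WORLD_SIZE", "1")) > 1:
    if not dist.is_initialized():
        dist.init_process_group(backend="nccl")
    multi_gpu_cfg = {
        "global_rank": int(os.environ["RANK"]),  # rank of this process across all nodes
        "world_size": int(os.environ["WORLD_SIZE"]),  # total number of processes
    }
else:
    multi_gpu_cfg = None  # single-process training: falls back to rank 0 / world size 1

# Hypothetical construction call (the real one is in the runner, outside this diff):
# alg = PPO(policy, device=device, multi_gpu_cfg=multi_gpu_cfg, ...)
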
@@ -68,7 +74,7 @@ def __init__(
             use_symmetry = symmetry_cfg["use_data_augmentation"] or symmetry_cfg["use_mirror_loss"]
             # Print that we are not using symmetry
             if not use_symmetry:
-                warnings.warn("Symmetry not used for learning. We will use it for logging instead.")
+                print("Symmetry not used for learning. We will use it for logging instead.")
             # If function is a string then resolve it to a function
             if isinstance(symmetry_cfg["data_augmentation_func"], str):
                 symmetry_cfg["data_augmentation_func"] = string_to_callable(symmetry_cfg["data_augmentation_func"])
@@ -102,6 +108,10 @@ def __init__(
         self.lam = lam
         self.max_grad_norm = max_grad_norm
         self.use_clipped_value_loss = use_clipped_value_loss
+        self.desired_kl = desired_kl
+        self.schedule = schedule
+        self.learning_rate = learning_rate
+        self.normalize_advantage_per_mini_batch = normalize_advantage_per_mini_batch

     def init_storage(
         self, training_type, num_envs, num_transitions_per_env, actor_obs_shape, critic_obs_shape, actions_shape
@@ -267,11 +277,28 @@ def update(self): # noqa: C901
                     )
                     kl_mean = torch.mean(kl)

-                    if kl_mean > self.desired_kl * 2.0:
-                        self.learning_rate = max(1e-5, self.learning_rate / 1.5)
-                    elif kl_mean < self.desired_kl / 2.0 and kl_mean > 0.0:
-                        self.learning_rate = min(1e-2, self.learning_rate * 1.5)
-
+                    # Reduce the KL divergence across all GPUs
+                    if self.is_multi_gpu:
+                        torch.distributed.all_reduce(kl_mean, op=torch.distributed.ReduceOp.SUM)
+                        kl_mean /= self.gpu_world_size
+
+                    # Update the learning rate
+                    # Perform this adaptation only on the main process
+                    # TODO: Is this needed? If KL-divergence is the "same" across all GPUs,
+                    #   then the learning rate should be the same across all GPUs.
+                    if self.gpu_global_rank == 0:
+                        if kl_mean > self.desired_kl * 2.0:
+                            self.learning_rate = max(1e-5, self.learning_rate / 1.5)
+                        elif kl_mean < self.desired_kl / 2.0 and kl_mean > 0.0:
+                            self.learning_rate = min(1e-2, self.learning_rate * 1.5)
+
+                    # Update the learning rate for all GPUs
+                    if self.is_multi_gpu:
+                        lr_tensor = torch.tensor(self.learning_rate, device=self.device)
+                        torch.distributed.broadcast(lr_tensor, src=0)
+                        self.learning_rate = lr_tensor.item()
+
+                    # Update the learning rate for all parameter groups
                     for param_group in self.optimizer.param_groups:
                         param_group["lr"] = self.learning_rate

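The hunk above chains two collectives: an all-reduce that averages the per-rank KL estimate, and a broadcast so every rank applies the learning rate chosen on rank 0. Below is a standalone sketch of the same pattern (not from this commit), assuming the default process group has already been initialized:

import torch
import torch.distributed as dist

def sync_adaptive_lr(kl_mean: torch.Tensor, lr: float, desired_kl: float, device: str) -> float:
    # Average the KL statistic so every rank sees the same value.
    dist.all_reduce(kl_mean, op=dist.ReduceOp.SUM)
    kl_mean /= dist.get_world_size()
    # Rank 0 applies the adaptive schedule ...
    if dist.get_rank() == 0:
        if kl_mean > desired_kl * 2.0:
            lr = max(1e-5, lr / 1.5)
        elif kl_mean < desired_kl / 2.0 and kl_mean > 0.0:
            lr = min(1e-2, lr * 1.5)
    # ... and broadcasts its decision so all ranks step with the same learning rate.
    lr_tensor = torch.tensor(lr, device=device)
    dist.broadcast(lr_tensor, src=0)
    return lr_tensor.item()
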
@@ -335,21 +362,30 @@ def update(self): # noqa: C901
             if self.rnd:
                 # predict the embedding and the target
                 predicted_embedding = self.rnd.predictor(rnd_state_batch)
-                target_embedding = self.rnd.target(rnd_state_batch)
+                target_embedding = self.rnd.target(rnd_state_batch).detach()
                 # compute the loss as the mean squared error
                 mseloss = torch.nn.MSELoss()
-                rnd_loss = mseloss(predicted_embedding, target_embedding.detach())
+                rnd_loss = mseloss(predicted_embedding, target_embedding)

-            # Gradient step
+            # Compute the gradients
             # -- For PPO
             self.optimizer.zero_grad()
             loss.backward()
+            # -- For RND
+            if self.rnd:
+                self.rnd_optimizer.zero_grad()  # type: ignore
+                rnd_loss.backward()
+
+            # Collect gradients from all GPUs
+            if self.is_multi_gpu:
+                self.reduce_parameters()
+
+            # Apply the gradients
+            # -- For PPO
             nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
             self.optimizer.step()
             # -- For RND
             if self.rnd_optimizer:
-                self.rnd_optimizer.zero_grad()
-                rnd_loss.backward()
                 self.rnd_optimizer.step()

             # Store the losses
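
Worth noting about the reordering in this hunk: both backward passes (PPO and RND) now run before any optimizer step, so a single reduce_parameters() call can average the policy and RND-predictor gradients together in one flattened all-reduce instead of issuing a separate collective per loss; gradient clipping and the optimizer steps then operate on gradients that are already synchronized across ranks.
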
@@ -389,3 +425,50 @@ def update(self): # noqa: C901
             loss_dict["symmetry"] = mean_symmetry_loss

         return loss_dict
+
+    """
+    Helper functions
+    """
+
+    def broadcast_parameters(self):
+        """Broadcast model parameters to all GPUs."""
+        # obtain the model parameters on current GPU
+        model_params = [self.policy.state_dict()]
+        if self.rnd:
+            model_params.append(self.rnd.predictor.state_dict())
+        # broadcast the model parameters
+        torch.distributed.broadcast_object_list(model_params, src=0)
+        # load the model parameters on all GPUs from source GPU
+        self.policy.load_state_dict(model_params[0])
+        if self.rnd:
+            self.rnd.predictor.load_state_dict(model_params[1])
+
+    def reduce_parameters(self):
+        """Collect gradients from all GPUs and average them.
+
+        This function is called after the backward pass to synchronize the gradients across all GPUs.
+        """
+        # Create a tensor to store the gradients
+        grads = [param.grad.view(-1) for param in self.policy.parameters() if param.grad is not None]
+        if self.rnd:
+            grads += [param.grad.view(-1) for param in self.rnd.parameters() if param.grad is not None]
+        all_grads = torch.cat(grads)
+
+        # Average the gradients across all GPUs
+        torch.distributed.all_reduce(all_grads, op=torch.distributed.ReduceOp.SUM)
+        all_grads /= self.gpu_world_size
+
+        # Get all parameters
+        all_params = self.policy.parameters()
+        if self.rnd:
+            all_params = chain(all_params, self.rnd.parameters())
+
+        # Update the gradients for all parameters with the reduced gradients
+        offset = 0
+        for param in all_params:
+            if param.grad is not None:
+                numel = param.numel()
+                # copy data back from shared buffer
+                param.grad.data.copy_(all_grads[offset : offset + numel].view_as(param.grad.data))
+                # update the offset for the next parameter
+                offset += numel
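
How these helpers are driven is not shown in this diff. Below is a hypothetical runner-side sketch, reusing the multi_gpu_cfg setup from the earlier sketch; the variable name alg and the surrounding runner are assumptions, not part of this commit.

# Assumption: alg is the PPO instance constructed with multi_gpu_cfg (see the sketch
# after the __init__ hunk); the runner itself is outside this diff.
if alg.is_multi_gpu:
    # Broadcast rank 0's policy and RND-predictor weights so that every rank starts
    # training from identical parameters.
    alg.broadcast_parameters()

# No extra gradient handling is needed in the runner: alg.update() already calls
# self.reduce_parameters() between the backward passes and the optimizer steps.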