@@ -31,9 +31,10 @@ def __init__(
         # Distributed training parameters
         multi_gpu_cfg: dict | None = None,
     ) -> None:
-        # device-related parameters
+        # Device-related parameters
         self.device = device
         self.is_multi_gpu = multi_gpu_cfg is not None
+
         # Multi-GPU parameters
         if multi_gpu_cfg is not None:
             self.gpu_global_rank = multi_gpu_cfg["global_rank"]
@@ -42,25 +43,25 @@ def __init__(
             self.gpu_global_rank = 0
             self.gpu_world_size = 1
 
-        # distillation components
+        # Distillation components
         self.policy = policy
         self.policy.to(self.device)
-        self.storage = None  # initialized later
+        self.storage = None  # Initialized later
 
-        # initialize the optimizer
+        # Initialize the optimizer
         self.optimizer = resolve_optimizer(optimizer)(self.policy.parameters(), lr=learning_rate)
 
-        # initialize the transition
+        # Initialize the transition
         self.transition = RolloutStorage.Transition()
         self.last_hidden_states = None
 
-        # distillation parameters
+        # Distillation parameters
         self.num_learning_epochs = num_learning_epochs
         self.gradient_length = gradient_length
         self.learning_rate = learning_rate
         self.max_grad_norm = max_grad_norm
 
-        # initialize the loss function
+        # Initialize the loss function
         loss_fn_dict = {
             "mse": nn.functional.mse_loss,
             "huber": nn.functional.huber_loss,
@@ -80,7 +81,7 @@ def init_storage(
         obs: TensorDict,
         actions_shape: tuple[int],
     ) -> None:
-        # create rollout storage
+        # Create rollout storage
         self.storage = RolloutStorage(
             training_type,
             num_envs,
@@ -91,23 +92,23 @@ def init_storage(
         )
 
     def act(self, obs: TensorDict) -> torch.Tensor:
-        # compute the actions
+        # Compute the actions
         self.transition.actions = self.policy.act(obs).detach()
         self.transition.privileged_actions = self.policy.evaluate(obs).detach()
-        # record the observations
+        # Record the observations
         self.transition.observations = obs
         return self.transition.actions
 
     def process_env_step(
         self, obs: TensorDict, rewards: torch.Tensor, dones: torch.Tensor, extras: dict[str, torch.Tensor]
     ) -> None:
-        # update the normalizers
+        # Update the normalizers
         self.policy.update_normalization(obs)
 
-        # record the rewards and dones
+        # Record the rewards and dones
         self.transition.rewards = rewards
         self.transition.dones = dones
-        # record the transition
+        # Record the transition
         self.storage.add_transitions(self.transition)
         self.transition.clear()
         self.policy.reset(dones)
@@ -122,18 +123,18 @@ def update(self) -> dict[str, float]:
             self.policy.reset(hidden_states=self.last_hidden_states)
             self.policy.detach_hidden_states()
             for obs, _, privileged_actions, dones in self.storage.generator():
-                # inference the student for gradient computation
+                # Inference of the student for gradient computation
                 actions = self.policy.act_inference(obs)
 
-                # behavior cloning loss
+                # Behavior cloning loss
                 behavior_loss = self.loss_fn(actions, privileged_actions)
 
-                # total loss
+                # Total loss
                 loss = loss + behavior_loss
                 mean_behavior_loss += behavior_loss.item()
                 cnt += 1
 
-                # gradient step
+                # Gradient step
                 if cnt % self.gradient_length == 0:
                     self.optimizer.zero_grad()
                     loss.backward()
@@ -145,7 +146,7 @@ def update(self) -> dict[str, float]:
                     self.policy.detach_hidden_states()
                     loss = 0
 
-                # reset dones
+                # Reset dones
                 self.policy.reset(dones.view(-1))
                 self.policy.detach_hidden_states(dones.view(-1))
 
@@ -154,22 +155,18 @@ def update(self) -> dict[str, float]:
         self.last_hidden_states = self.policy.get_hidden_states()
         self.policy.detach_hidden_states()
 
-        # construct the loss dictionary
+        # Construct the loss dictionary
         loss_dict = {"behavior": mean_behavior_loss}
 
         return loss_dict
 
-    """
-    Helper functions
-    """
-
     def broadcast_parameters(self) -> None:
         """Broadcast model parameters to all GPUs."""
-        # obtain the model parameters on current GPU
+        # Obtain the model parameters on current GPU
         model_params = [self.policy.state_dict()]
-        # broadcast the model parameters
+        # Broadcast the model parameters
         torch.distributed.broadcast_object_list(model_params, src=0)
-        # load the model parameters on all GPUs from source GPU
+        # Load the model parameters on all GPUs from source GPU
         self.policy.load_state_dict(model_params[0])
 
     def reduce_parameters(self) -> None:
@@ -188,7 +185,7 @@ def reduce_parameters(self) -> None:
         for param in self.policy.parameters():
             if param.grad is not None:
                 numel = param.numel()
-                # copy data back from shared buffer
+                # Copy data back from shared buffer
                 param.grad.data.copy_(all_grads[offset : offset + numel].view_as(param.grad.data))
-                # update the offset for the next parameter
+                # Update the offset for the next parameter
                 offset += numel
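
For context, below is a minimal, self-contained sketch of the gradient-accumulation pattern used in `update()` above: behavior-cloning losses are summed over `gradient_length` minibatches before a single clipped optimizer step. The policy, data shapes, and loop are illustrative stand-ins, not this repository's API, and hidden-state handling for recurrent policies is omitted.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins for the student policy and the privileged (teacher) actions.
policy = nn.Sequential(nn.Linear(8, 32), nn.ELU(), nn.Linear(32, 4))
optimizer = torch.optim.Adam(policy.parameters(), lr=1e-3)
loss_fn = nn.functional.mse_loss  # corresponds to the "mse" entry of loss_fn_dict above

gradient_length = 15  # accumulate this many minibatch losses before stepping
max_grad_norm = 1.0

loss = 0.0
cnt = 0
for step in range(60):  # stand-in for iterating over storage.generator()
    obs = torch.randn(128, 8)
    privileged_actions = torch.randn(128, 4)  # targets from the teacher policy

    actions = policy(obs)  # student inference with gradients enabled
    behavior_loss = loss_fn(actions, privileged_actions)
    loss = loss + behavior_loss  # accumulate the loss (and its graph)
    cnt += 1

    if cnt % gradient_length == 0:  # periodic gradient step
        optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(policy.parameters(), max_grad_norm)
        optimizer.step()
        loss = 0.0  # restart accumulation for the next window
```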