
Commit c6834de

alessandroassirelli98 and alessandro.assirelli authored
Adds gradient cap for teacher student distillation (#91)
This PR adds a gradient cap to the teacher-student distillation setup. The goal is to prevent excessively large gradients from destabilizing training.

📌 Changes
- Introduced a clipping mechanism that caps the gradient norm during backpropagation in the distillation update.
- Helps improve training stability, especially in early iterations.

Co-authored-by: alessandro.assirelli <[email protected]>
1 parent d38a378 commit c6834de
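The change follows the standard PyTorch pattern for gradient-norm clipping: bound the global norm of the gradients between loss.backward() and optimizer.step(). The sketch below illustrates that pattern in isolation; the student and teacher modules, the optimizer, and the max_grad_norm value are placeholders for illustration, not the repository's actual classes.

# Minimal sketch of gradient-norm clipping in a distillation-style update (illustrative, not rsl_rl code).
import torch
import torch.nn as nn

student = nn.Linear(8, 2)   # stand-in for the student policy
teacher = nn.Linear(8, 2)   # stand-in for the frozen teacher policy
optimizer = torch.optim.Adam(student.parameters(), lr=1e-3)
loss_fn = nn.MSELoss()
max_grad_norm = 1.0         # set to None to disable clipping

obs = torch.randn(32, 8)
with torch.no_grad():
    target_actions = teacher(obs)   # teacher supplies the supervision signal

loss = loss_fn(student(obs), target_actions)
optimizer.zero_grad()
loss.backward()
if max_grad_norm:
    # Rescale all student gradients so their global L2 norm is at most max_grad_norm.
    nn.utils.clip_grad_norm_(student.parameters(), max_grad_norm)
optimizer.step()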

File tree

2 files changed: +24 -9 lines changed


rsl_rl/algorithms/distillation.py

Lines changed: 4 additions & 0 deletions
@@ -25,6 +25,7 @@ def __init__(
         num_learning_epochs=1,
         gradient_length=15,
         learning_rate=1e-3,
+        max_grad_norm=None,
         loss_type="mse",
         device="cpu",
         # Distributed training parameters
@@ -55,6 +56,7 @@ def __init__(
         self.num_learning_epochs = num_learning_epochs
         self.gradient_length = gradient_length
         self.learning_rate = learning_rate
+        self.max_grad_norm = max_grad_norm
 
         # initialize the loss function
         if loss_type == "mse":
@@ -127,6 +129,8 @@ def update(self):
                 loss.backward()
                 if self.is_multi_gpu:
                     self.reduce_parameters()
+                if self.max_grad_norm:
+                    nn.utils.clip_grad_norm_(self.policy.student.parameters(), self.max_grad_norm)
                 self.optimizer.step()
                 self.policy.detach_hidden_states()
                 loss = 0
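With this change, clipping is opt-in: max_grad_norm defaults to None, and any truthy value triggers nn.utils.clip_grad_norm_ on the student parameters after the optional multi-GPU gradient reduction and before the optimizer step. A rough sketch of enabling it through the algorithm configuration follows; the surrounding keys mirror the __init__ parameters shown above, but the exact config layout used by a given training script may differ.

# Illustrative config sketch; "max_grad_norm" is the only key added by this commit.
algorithm_cfg = {
    "class_name": "Distillation",
    "num_learning_epochs": 1,
    "gradient_length": 15,
    "learning_rate": 1e-3,
    "loss_type": "mse",
    "max_grad_norm": 1.0,   # cap on the student's global gradient norm (None disables clipping)
}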

rsl_rl/runners/on_policy_runner.py

Lines changed: 20 additions & 9 deletions
@@ -93,7 +93,9 @@ def __init__(self, env: VecEnv, train_cfg: dict, log_dir: str | None = None, dev
 
         # initialize algorithm
         alg_class = eval(self.alg_cfg.pop("class_name"))
-        self.alg: PPO | Distillation = alg_class(policy, device=self.device, **self.alg_cfg, multi_gpu_cfg=self.multi_gpu_cfg)
+        self.alg: PPO | Distillation = alg_class(
+            policy, device=self.device, **self.alg_cfg, multi_gpu_cfg=self.multi_gpu_cfg
+        )
 
         # store training configuration
         self.num_steps_per_env = self.cfg["num_steps_per_env"]
@@ -387,8 +389,13 @@ def log(self, locs: dict, width: int = 80, pad: int = 35):
             f"""{'Total timesteps:':>{pad}} {self.tot_timesteps}\n"""
             f"""{'Iteration time:':>{pad}} {iteration_time:.2f}s\n"""
             f"""{'Time elapsed:':>{pad}} {time.strftime("%H:%M:%S", time.gmtime(self.tot_time))}\n"""
-            f"""{'ETA:':>{pad}} {time.strftime("%H:%M:%S", time.gmtime(self.tot_time / (locs['it'] - locs['start_iter'] + 1) * (
-                locs['start_iter'] + locs['num_learning_iterations'] - locs['it'])))}\n"""
+            f"""{'ETA:':>{pad}} {time.strftime(
+                "%H:%M:%S",
+                time.gmtime(
+                    self.tot_time / (locs['it'] - locs['start_iter'] + 1)
+                    * (locs['start_iter'] + locs['num_learning_iterations'] - locs['it'])
+                )
+            )}\n"""
         )
         print(log_string)
 
@@ -513,16 +520,20 @@ def _configure_multi_gpu(self):
 
         # check if user has device specified for local rank
         if self.device != f"cuda:{self.gpu_local_rank}":
-            raise ValueError(f"Device '{self.device}' does not match expected device for local rank '{self.gpu_local_rank}'.")
+            raise ValueError(
+                f"Device '{self.device}' does not match expected device for local rank '{self.gpu_local_rank}'."
+            )
         # validate multi-gpu configuration
         if self.gpu_local_rank >= self.gpu_world_size:
-            raise ValueError(f"Local rank '{self.gpu_local_rank}' is greater than or equal to world size '{self.gpu_world_size}'.")
+            raise ValueError(
+                f"Local rank '{self.gpu_local_rank}' is greater than or equal to world size '{self.gpu_world_size}'."
+            )
         if self.gpu_global_rank >= self.gpu_world_size:
-            raise ValueError(f"Global rank '{self.gpu_global_rank}' is greater than or equal to world size '{self.gpu_world_size}'.")
+            raise ValueError(
+                f"Global rank '{self.gpu_global_rank}' is greater than or equal to world size '{self.gpu_world_size}'."
+            )
 
         # initialize torch distributed
-        torch.distributed.init_process_group(
-            backend="nccl", rank=self.gpu_global_rank, world_size=self.gpu_world_size
-        )
+        torch.distributed.init_process_group(backend="nccl", rank=self.gpu_global_rank, world_size=self.gpu_world_size)
         # set device to the local rank
         torch.cuda.set_device(self.gpu_local_rank)
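The runner changes above are formatting-only; behavior is unchanged. For context, the multi-GPU validation assumes a torchrun-style launcher in which each process reads its ranks from the standard torch.distributed environment variables and must pin itself to the GPU matching its local rank. A rough, self-contained sketch of that convention (not code from this repository):

# Sketch of the rank/device convention that _configure_multi_gpu validates.
# Launch with, e.g.:  torchrun --nproc_per_node=2 train.py
import os
import torch

local_rank = int(os.environ.get("LOCAL_RANK", 0))   # rank within this machine
global_rank = int(os.environ.get("RANK", 0))        # rank across all machines
world_size = int(os.environ.get("WORLD_SIZE", 1))   # total number of processes

device = f"cuda:{local_rank}"                       # must match the local rank
assert local_rank < world_size and global_rank < world_size

if world_size > 1:
    torch.distributed.init_process_group(backend="nccl", rank=global_rank, world_size=world_size)
    torch.cuda.set_device(local_rank)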
