
Commit bbce4ef

Adds Student-Teacher Distillation
1 parent f80d475 commit bbce4ef

File tree: 10 files changed, +518 −186 lines

config/dummy_config.yaml

Lines changed: 1 addition & 2 deletions
@@ -51,7 +51,7 @@ algorithm:
   #
   # @torch.no_grad()
   # def get_symmetric_states(
-  #    obs: Optional[torch.Tensor] = None, actions: Optional[torch.Tensor] = None, cfg: "BaseEnvCfg" = None, is_critic: bool = False,
+  #    obs: Optional[torch.Tensor] = None, actions: Optional[torch.Tensor] = None, cfg: "BaseEnvCfg" = None, obs_type: str = "policy"
   # ) -> Tuple[torch.Tensor, torch.Tensor]:
   #
   data_augmentation_func: null
@@ -87,7 +87,6 @@ runner:
   neptune_project: legged_gym
   wandb_project: legged_gym
   # -- load and resuming
-  resume: false
   load_run: -1  # -1 means load latest run
   resume_path: null  # updated from load_run and checkpoint
   checkpoint: -1  # -1 means load latest checkpoint
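
For reference, a minimal sketch of a data_augmentation_func matching the commented template above, assuming a toy environment whose observations and actions mirror via fixed sign flips. The flip vectors, the dimensions, and the env keyword (the call sites in ppo.py pass env=, while the template names the argument cfg) are illustrative assumptions, not part of this commit.

# Hypothetical symmetry-augmentation function conforming to the new obs_type signature.
from typing import Optional, Tuple

import torch

# assumed per-channel mirroring transforms (sign flips); real envs would use their own mappings
_OBS_FLIP = {
    "policy": torch.tensor([1.0, -1.0, 1.0, -1.0]),       # assumed 4-dim policy observation
    "critic": torch.tensor([1.0, -1.0, 1.0, -1.0, 1.0]),  # assumed 5-dim critic observation
}
_ACT_FLIP = torch.tensor([1.0, -1.0])                      # assumed 2-dim action space


@torch.no_grad()
def get_symmetric_states(
    obs: Optional[torch.Tensor] = None,
    actions: Optional[torch.Tensor] = None,
    env=None,  # ppo.py passes env=...; the commented template above names this argument cfg
    obs_type: str = "policy",
) -> Tuple[torch.Tensor, torch.Tensor]:
    # return [original; mirrored] stacked along the batch dimension, for whichever inputs were given
    obs_aug = None if obs is None else torch.cat([obs, obs * _OBS_FLIP[obs_type].to(obs.device)], dim=0)
    actions_aug = None if actions is None else torch.cat([actions, actions * _ACT_FLIP.to(actions.device)], dim=0)
    return obs_aug, actions_aug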

rsl_rl/algorithms/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -5,6 +5,7 @@

 """Implementation of different RL agents."""

+from .distillation import Distillation
 from .ppo import PPO

-__all__ = ["PPO"]
+__all__ = ["PPO", "Distillation"]

rsl_rl/algorithms/distillation.py

Lines changed: 118 additions & 0 deletions
@@ -0,0 +1,118 @@
# Copyright (c) 2021-2025, ETH Zurich and NVIDIA CORPORATION
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause

# torch
import torch.nn as nn
import torch.optim as optim

# rsl-rl
from rsl_rl.modules import StudentTeacher
from rsl_rl.storage import RolloutStorage


class Distillation:
    """Distillation algorithm for training a student model to mimic a teacher model."""

    policy: StudentTeacher
    """The student-teacher model."""

    def __init__(
        self,
        policy,
        num_learning_epochs=1,
        gradient_length=15,
        learning_rate=1e-3,
        device="cpu",
    ):
        self.device = device
        self.learning_rate = learning_rate

        self.rnd = None  # TODO: remove when runner has a proper base class

        # distillation components
        self.policy = policy
        self.policy.to(self.device)
        self.storage = None  # initialized later
        self.optimizer = optim.Adam(self.policy.student.parameters(), lr=self.learning_rate)
        self.transition = RolloutStorage.Transition()

        # distillation parameters
        self.num_learning_epochs = num_learning_epochs
        self.gradient_length = gradient_length

        self.num_updates = 0

    def init_storage(
        self, training_type, num_envs, num_transitions_per_env, student_obs_shape, teacher_obs_shape, actions_shape
    ):
        # create rollout storage
        self.storage = RolloutStorage(
            training_type,
            num_envs,
            num_transitions_per_env,
            student_obs_shape,
            teacher_obs_shape,
            actions_shape,
            None,
            self.device,
        )

    def act(self, obs, teacher_obs):
        # compute the actions
        self.transition.actions = self.policy.act(obs).detach()
        self.transition.privileged_actions = self.policy.evaluate(teacher_obs).detach()
        # record the observations
        self.transition.observations = obs
        self.transition.privileged_observations = teacher_obs
        return self.transition.actions

    def process_env_step(self, rewards, dones, infos):
        # record the rewards and dones
        self.transition.rewards = rewards
        self.transition.dones = dones
        # record the transition
        self.storage.add_transitions(self.transition)
        self.transition.clear()
        self.policy.reset(dones)

    def update(self):
        self.num_updates += 1
        mean_behaviour_loss = 0
        loss = 0
        cnt = 0

        for epoch in range(self.num_learning_epochs):  # TODO: unify num_steps_per_env and gradient_length
            self.policy.reset()
            self.policy.detach_hidden_states()
            for obs, _, _, privileged_actions in self.storage.generator():

                # run the student with gradients enabled
                actions = self.policy.act_inference(obs)

                # behaviour cloning loss
                behaviour_loss = nn.functional.mse_loss(actions, privileged_actions)

                # total loss
                loss = loss + behaviour_loss

                mean_behaviour_loss += behaviour_loss.item()
                cnt += 1

                # gradient step
                if cnt % self.gradient_length == 0:
                    self.optimizer.zero_grad()
                    loss.backward()
                    self.optimizer.step()
                    self.policy.detach_hidden_states()
                    loss = 0

        mean_behaviour_loss /= cnt
        self.storage.clear()
        self.policy.reset()  # TODO: needed?

        # construct the loss dictionary
        loss_dict = {"behaviour": mean_behaviour_loss}

        return loss_dict
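
To make the update loop above concrete, here is a self-contained toy version of the same behaviour-cloning step, with plain nn.Sequential stand-ins for the student and teacher. The dimensions and module shapes are illustrative, not taken from this commit.

# Toy illustration of the behaviour-cloning update (illustrative only, not the commit's code).
import torch
import torch.nn as nn
import torch.optim as optim

obs_dim, act_dim, gradient_length = 48, 12, 15
student = nn.Sequential(nn.Linear(obs_dim, 64), nn.ELU(), nn.Linear(64, act_dim))
teacher = nn.Sequential(nn.Linear(obs_dim, 64), nn.ELU(), nn.Linear(64, act_dim)).requires_grad_(False)
optimizer = optim.Adam(student.parameters(), lr=1e-3)

loss, cnt = 0, 0
for obs in torch.randn(60, 256, obs_dim):          # 60 stored transitions of 256 envs each
    with torch.no_grad():
        privileged_actions = teacher(obs)           # frozen teacher provides the targets
    behaviour_loss = nn.functional.mse_loss(student(obs), privileged_actions)
    loss = loss + behaviour_loss                    # accumulate the loss over several steps
    cnt += 1
    if cnt % gradient_length == 0:                  # then take one optimizer step, as in update()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        loss = 0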

rsl_rl/algorithms/ppo.py

Lines changed: 48 additions & 42 deletions
@@ -19,12 +19,12 @@
 class PPO:
     """Proximal Policy Optimization algorithm (https://arxiv.org/abs/1707.06347)."""

-    actor_critic: ActorCritic
+    policy: ActorCritic
     """The actor critic module."""

     def __init__(
         self,
-        actor_critic,
+        policy,
         num_learning_epochs=1,
         num_mini_batches=1,
         clip_param=0.2,
@@ -84,10 +84,10 @@ def __init__(
             self.symmetry = None

         # PPO components
-        self.actor_critic = actor_critic
-        self.actor_critic.to(self.device)
+        self.policy = policy
+        self.policy.to(self.device)
         # Create optimizer
-        self.optimizer = optim.Adam(self.actor_critic.parameters(), lr=learning_rate)
+        self.optimizer = optim.Adam(self.policy.parameters(), lr=learning_rate)
         # Create rollout storage
         self.storage: RolloutStorage = None  # type: ignore
         self.transition = RolloutStorage.Transition()
@@ -103,41 +103,38 @@ def __init__(
         self.max_grad_norm = max_grad_norm
         self.use_clipped_value_loss = use_clipped_value_loss

-    def init_storage(self, num_envs, num_transitions_per_env, actor_obs_shape, critic_obs_shape, action_shape):
+    def init_storage(
+        self, training_type, num_envs, num_transitions_per_env, actor_obs_shape, critic_obs_shape, actions_shape
+    ):
         # create memory for RND as well :)
         if self.rnd:
             rnd_state_shape = [self.rnd.num_states]
         else:
             rnd_state_shape = None
         # create rollout storage
         self.storage = RolloutStorage(
+            training_type,
             num_envs,
             num_transitions_per_env,
             actor_obs_shape,
             critic_obs_shape,
-            action_shape,
+            actions_shape,
             rnd_state_shape,
             self.device,
         )

-    def test_mode(self):
-        self.actor_critic.test()
-
-    def train_mode(self):
-        self.actor_critic.train()
-
     def act(self, obs, critic_obs):
-        if self.actor_critic.is_recurrent:
-            self.transition.hidden_states = self.actor_critic.get_hidden_states()
-        # Compute the actions and values
-        self.transition.actions = self.actor_critic.act(obs).detach()
-        self.transition.values = self.actor_critic.evaluate(critic_obs).detach()
-        self.transition.actions_log_prob = self.actor_critic.get_actions_log_prob(self.transition.actions).detach()
-        self.transition.action_mean = self.actor_critic.action_mean.detach()
-        self.transition.action_sigma = self.actor_critic.action_std.detach()
+        if self.policy.is_recurrent:
+            self.transition.hidden_states = self.policy.get_hidden_states()
+        # compute the actions and values
+        self.transition.actions = self.policy.act(obs).detach()
+        self.transition.values = self.policy.evaluate(critic_obs).detach()
+        self.transition.actions_log_prob = self.policy.get_actions_log_prob(self.transition.actions).detach()
+        self.transition.action_mean = self.policy.action_mean.detach()
+        self.transition.action_sigma = self.policy.action_std.detach()
         # need to record obs and critic_obs before env.step()
         self.transition.observations = obs
-        self.transition.critic_observations = critic_obs
+        self.transition.privileged_observations = critic_obs
         return self.transition.actions

     def process_env_step(self, rewards, dones, infos):
@@ -164,14 +161,14 @@ def process_env_step(self, rewards, dones, infos):
                 self.transition.values * infos["time_outs"].unsqueeze(1).to(self.device), 1
             )

-        # Record the transition
+        # record the transition
         self.storage.add_transitions(self.transition)
         self.transition.clear()
-        self.actor_critic.reset(dones)
+        self.policy.reset(dones)

     def compute_returns(self, last_critic_obs):
         # compute value for the last step
-        last_values = self.actor_critic.evaluate(last_critic_obs).detach()
+        last_values = self.policy.evaluate(last_critic_obs).detach()
         self.storage.compute_returns(
             last_values, self.gamma, self.lam, normalize_advantage=not self.normalize_advantage_per_mini_batch
         )
@@ -192,7 +189,7 @@ def update(self):  # noqa: C901
         mean_symmetry_loss = None

         # generator for mini batches
-        if self.actor_critic.is_recurrent:
+        if self.policy.is_recurrent:
             generator = self.storage.recurrent_mini_batch_generator(self.num_mini_batches, self.num_learning_epochs)
         else:
             generator = self.storage.mini_batch_generator(self.num_mini_batches, self.num_learning_epochs)
@@ -230,10 +227,10 @@ def update(self):  # noqa: C901
                 data_augmentation_func = self.symmetry["data_augmentation_func"]
                 # returned shape: [batch_size * num_aug, ...]
                 obs_batch, actions_batch = data_augmentation_func(
-                    obs=obs_batch, actions=actions_batch, env=self.symmetry["_env"], is_critic=False
+                    obs=obs_batch, actions=actions_batch, env=self.symmetry["_env"], obs_type="policy"
                 )
                 critic_obs_batch, _ = data_augmentation_func(
-                    obs=critic_obs_batch, actions=None, env=self.symmetry["_env"], is_critic=True
+                    obs=critic_obs_batch, actions=None, env=self.symmetry["_env"], obs_type="critic"
                 )
                 # compute number of augmentations per sample
                 num_aug = int(obs_batch.shape[0] / original_batch_size)
@@ -246,19 +243,17 @@ def update(self):  # noqa: C901
                 returns_batch = returns_batch.repeat(num_aug, 1)

             # Recompute actions log prob and entropy for current batch of transitions
-            # Note: we need to do this because we updated the actor_critic with the new parameters
+            # Note: we need to do this because we updated the policy with the new parameters
             # -- actor
-            self.actor_critic.act(obs_batch, masks=masks_batch, hidden_states=hid_states_batch[0])
-            actions_log_prob_batch = self.actor_critic.get_actions_log_prob(actions_batch)
+            self.policy.act(obs_batch, masks=masks_batch, hidden_states=hid_states_batch[0])
+            actions_log_prob_batch = self.policy.get_actions_log_prob(actions_batch)
             # -- critic
-            value_batch = self.actor_critic.evaluate(
-                critic_obs_batch, masks=masks_batch, hidden_states=hid_states_batch[1]
-            )
+            value_batch = self.policy.evaluate(critic_obs_batch, masks=masks_batch, hidden_states=hid_states_batch[1])
             # -- entropy
             # we only keep the entropy of the first augmentation (the original one)
-            mu_batch = self.actor_critic.action_mean[:original_batch_size]
-            sigma_batch = self.actor_critic.action_std[:original_batch_size]
-            entropy_batch = self.actor_critic.entropy[:original_batch_size]
+            mu_batch = self.policy.action_mean[:original_batch_size]
+            sigma_batch = self.policy.action_std[:original_batch_size]
+            entropy_batch = self.policy.entropy[:original_batch_size]

             # KL
             if self.desired_kl is not None and self.schedule == "adaptive":
@@ -308,21 +303,21 @@ def update(self):  # noqa: C901
                 if not self.symmetry["use_data_augmentation"]:
                     data_augmentation_func = self.symmetry["data_augmentation_func"]
                     obs_batch, _ = data_augmentation_func(
-                        obs=obs_batch, actions=None, env=self.symmetry["_env"], is_critic=False
+                        obs=obs_batch, actions=None, env=self.symmetry["_env"], obs_type="policy"
                     )
                     # compute number of augmentations per sample
                     num_aug = int(obs_batch.shape[0] / original_batch_size)

                 # actions predicted by the actor for symmetrically-augmented observations
-                mean_actions_batch = self.actor_critic.act_inference(obs_batch.detach().clone())
+                mean_actions_batch = self.policy.act_inference(obs_batch.detach().clone())

                 # compute the symmetrically augmented actions
                 # note: we are assuming the first augmentation is the original one.
                 # We do not use the action_batch from earlier since that action was sampled from the distribution.
                 # However, the symmetry loss is computed using the mean of the distribution.
                 action_mean_orig = mean_actions_batch[:original_batch_size]
                 _, actions_mean_symm_batch = data_augmentation_func(
-                    obs=None, actions=action_mean_orig, env=self.symmetry["_env"], is_critic=False
+                    obs=None, actions=action_mean_orig, env=self.symmetry["_env"], obs_type="policy"
                 )

                 # compute the loss (we skip the first augmentation as it is the original one)
@@ -349,7 +344,7 @@ def update(self):  # noqa: C901
             # -- For PPO
             self.optimizer.zero_grad()
             loss.backward()
-            nn.utils.clip_grad_norm_(self.actor_critic.parameters(), self.max_grad_norm)
+            nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
             self.optimizer.step()
             # -- For RND
             if self.rnd_optimizer:
@@ -382,4 +377,15 @@ def update(self):  # noqa: C901
         # -- Clear the storage
         self.storage.clear()

-        return mean_value_loss, mean_surrogate_loss, mean_entropy, mean_rnd_loss, mean_symmetry_loss
+        # construct the loss dictionary
+        loss_dict = {
+            "value_function": mean_value_loss,
+            "surrogate": mean_surrogate_loss,
+            "entropy": mean_entropy,
+        }
+        if self.rnd:
+            loss_dict["rnd"] = mean_rnd_loss
+        if self.symmetry:
+            loss_dict["symmetry"] = mean_symmetry_loss
+
+        return loss_dict
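
Since update() now returns a dictionary rather than a fixed five-element tuple, callers that unpacked the old return value need a small adjustment. A hedged sketch of the new consumption pattern follows; the sample values and logging code are illustrative, not the commit's actual runner.

# Illustrative only: consuming the dict-based return of update() for logging.
loss_dict = {"value_function": 0.012, "surrogate": -0.003, "entropy": 1.45}  # e.g. loss_dict = alg.update()
for name, value in loss_dict.items():
    print(f"Loss/{name}: {value:.4f}")  # a TensorBoard writer could reuse the same keys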

rsl_rl/modules/__init__.py

Lines changed: 8 additions & 1 deletion
@@ -9,5 +9,12 @@
 from .actor_critic_recurrent import ActorCriticRecurrent
 from .normalizer import EmpiricalNormalization
 from .rnd import RandomNetworkDistillation
+from .student_teacher import StudentTeacher

-__all__ = ["ActorCritic", "ActorCriticRecurrent", "EmpiricalNormalization", "RandomNetworkDistillation"]
+__all__ = [
+    "ActorCritic",
+    "ActorCriticRecurrent",
+    "EmpiricalNormalization",
+    "RandomNetworkDistillation",
+    "StudentTeacher",
+]

rsl_rl/modules/actor_critic.py

Lines changed: 16 additions & 4 deletions
@@ -78,10 +78,6 @@ def __init__(
         # disable args validation for speedup
         Normal.set_default_validate_args(False)

-        # seems that we get better performance without init
-        # self.init_memory_weights(self.memory_a, 0.001, 0.)
-        # self.init_memory_weights(self.memory_c, 0.001, 0.)
-
     @staticmethod
     # not used at the moment
     def init_weights(sequential, scales):
@@ -135,3 +131,19 @@ def act_inference(self, observations):
     def evaluate(self, critic_observations, **kwargs):
         value = self.critic(critic_observations)
         return value
+
+    def load_state_dict(self, state_dict, strict=True):
+        """Load the parameters of the actor-critic model.
+
+        Args:
+            state_dict (dict): State dictionary of the model.
+            strict (bool): Whether to strictly enforce that the keys in state_dict match the keys returned by this
+                module's state_dict() function.
+
+        Returns:
+            bool: Whether this training resumes a previous training. This flag is used by the `load()` function of
+                `OnPolicyRunner` to determine how to load further parameters (relevant for, e.g., distillation).
+        """
+
+        super().load_state_dict(state_dict, strict=strict)
+        return True
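
A rough sketch of how a runner-style load() might use the boolean returned by load_state_dict(), as hinted at in the docstring. The checkpoint keys and the StudentTeacher behaviour mentioned in the comments are assumptions for illustration, not code from this commit.

# Illustrative only: using the bool returned by policy.load_state_dict() in a runner-style load().
import torch

def load(alg, path, current_iteration=0, load_optimizer=True):
    loaded_dict = torch.load(path, map_location="cpu", weights_only=False)  # assumed checkpoint layout
    # ActorCritic.load_state_dict() returns True (resume PPO training); a StudentTeacher
    # counterpart might return False when the checkpoint only contains teacher weights.
    resumed = alg.policy.load_state_dict(loaded_dict["model_state_dict"])
    if resumed and load_optimizer:
        alg.optimizer.load_state_dict(loaded_dict["optimizer_state_dict"])
    if resumed:
        current_iteration = loaded_dict["iter"]
    return current_iteration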
