Skip to content

Commit d354d59

Browse files
Adds NaN check to avoid ambiguous std >= 0.0 error (#185)
1 parent 1e851f4 commit d354d59

File tree

6 files changed

+44
-15
lines changed

6 files changed

+44
-15
lines changed

docs/guide/configuration.rst

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -52,11 +52,7 @@ Currently, RSL-RL implements two runner classes:
5252
* - ``obs_groups``
5353
- dict[str, list[str]]
5454
- required
55-
- Mapping from observation sets to observation tensors coming from the environment.
56-
* - ``run_name``
57-
- str
58-
- missing
59-
- Optional run label shown in the console output.
55+
- Mapping from observation sets to observation groups coming from the environment. See :ref:`here <observation-configuration>` for more details.
6056
* - ``save_interval``
6157
- int
6258
- required
@@ -73,6 +69,14 @@ Currently, RSL-RL implements two runner classes:
7369
- str
7470
- required for Neptune
7571
- Neptune project name used by the Neptune writer.
72+
* - ``run_name``
73+
- str
74+
- missing
75+
- Optional run label shown in the console output.
76+
* - ``check_for_nan``
77+
- bool
78+
- ``True``
79+
- Whether to check for NaN values coming from the environment.
7680
* - ``algorithm``
7781
- dict
7882
- required

rsl_rl/algorithms/ppo.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,7 @@ def update(self) -> dict[str, float]:
231231
# Check if we should normalize advantages per mini batch
232232
if self.normalize_advantage_per_mini_batch:
233233
with torch.no_grad():
234-
batch.advantages = (batch.advantages - batch.advantages.mean()) / (batch.advantages.std() + 1e-8)
234+
batch.advantages = (batch.advantages - batch.advantages.mean()) / (batch.advantages.std() + 1e-8) # type: ignore
235235

236236
# Perform symmetric augmentation
237237
if self.symmetry and self.symmetry["use_data_augmentation"]:
@@ -259,7 +259,7 @@ def update(self) -> dict[str, float]:
259259
hidden_state=batch.hidden_states[0],
260260
stochastic_output=True,
261261
)
262-
actions_log_prob = self.actor.get_output_log_prob(batch.actions)
262+
actions_log_prob = self.actor.get_output_log_prob(batch.actions) # type: ignore
263263
values = self.critic(batch.observations, masks=batch.masks, hidden_state=batch.hidden_states[1])
264264
# Note: We only keep the distribution parameters and entropy of the first augmentation (the original one)
265265
distribution_params = tuple(p[:original_batch_size] for p in self.actor.output_distribution_params)
@@ -268,7 +268,7 @@ def update(self) -> dict[str, float]:
268268
# Compute KL divergence and adapt the learning rate
269269
if self.desired_kl is not None and self.schedule == "adaptive":
270270
with torch.inference_mode():
271-
kl = self.actor.get_kl_divergence(batch.old_distribution_params, distribution_params)
271+
kl = self.actor.get_kl_divergence(batch.old_distribution_params, distribution_params) # type: ignore
272272
kl_mean = torch.mean(kl)
273273

274274
# Reduce the KL divergence across all GPUs
@@ -294,9 +294,9 @@ def update(self) -> dict[str, float]:
294294
param_group["lr"] = self.learning_rate
295295

296296
# Surrogate loss
297-
ratio = torch.exp(actions_log_prob - torch.squeeze(batch.old_actions_log_prob))
298-
surrogate = -torch.squeeze(batch.advantages) * ratio
299-
surrogate_clipped = -torch.squeeze(batch.advantages) * torch.clamp(
297+
ratio = torch.exp(actions_log_prob - torch.squeeze(batch.old_actions_log_prob)) # type: ignore
298+
surrogate = -torch.squeeze(batch.advantages) * ratio # type: ignore
299+
surrogate_clipped = -torch.squeeze(batch.advantages) * torch.clamp( # type: ignore
300300
ratio, 1.0 - self.clip_param, 1.0 + self.clip_param
301301
)
302302
surrogate_loss = torch.max(surrogate, surrogate_clipped).mean()

rsl_rl/runners/on_policy_runner.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from rsl_rl.algorithms import PPO
1414
from rsl_rl.env import VecEnv
1515
from rsl_rl.models import MLPModel
16-
from rsl_rl.utils import resolve_callable
16+
from rsl_rl.utils import check_nan, resolve_callable
1717
from rsl_rl.utils.logger import Logger
1818

1919

@@ -85,6 +85,9 @@ def learn(self, num_learning_iterations: int, init_at_random_ep_len: bool = Fals
8585
actions = self.alg.act(obs)
8686
# Step the environment
8787
obs, rewards, dones, extras = self.env.step(actions.to(self.env.device))
88+
# Check for NaN values from the environment
89+
if self.cfg.get("check_for_nan", True):
90+
check_nan(obs, rewards, dones)
8891
# Move to device
8992
obs, rewards, dones = (obs.to(self.device), rewards.to(self.device), dones.to(self.device))
9093
# Process the step

rsl_rl/storage/rollout_storage.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,7 @@ def generator(self) -> Generator[Batch, None, None]:
213213

214214
for i in range(self.num_transitions_per_env):
215215
yield RolloutStorage.Batch(
216-
observations=self.observations[i],
216+
observations=self.observations[i], # type: ignore
217217
privileged_actions=self.privileged_actions[i],
218218
dones=self.dones[i],
219219
)
@@ -312,14 +312,14 @@ def recurrent_mini_batch_generator(
312312

313313
# Yield the mini-batch
314314
yield RolloutStorage.Batch(
315-
observations=padded_obs_trajectories[:, first_traj:last_traj],
315+
observations=padded_obs_trajectories[:, first_traj:last_traj], # type: ignore
316316
actions=self.actions[:, start:stop],
317317
values=self.values[:, start:stop],
318318
advantages=self.advantages[:, start:stop],
319319
returns=self.returns[:, start:stop],
320320
old_actions_log_prob=self.actions_log_prob[:, start:stop],
321321
old_distribution_params=tuple(p[:, start:stop] for p in self.distribution_params), # type: ignore
322-
hidden_states=(hidden_state_a_batch, hidden_state_c_batch),
322+
hidden_states=(hidden_state_a_batch, hidden_state_c_batch), # type: ignore
323323
masks=trajectory_masks[:, first_traj:last_traj],
324324
)
325325

rsl_rl/utils/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
"""Helper functions."""
77

88
from .utils import (
9+
check_nan,
910
get_param,
1011
resolve_callable,
1112
resolve_nn_activation,
@@ -16,6 +17,7 @@
1617
)
1718

1819
__all__ = [
20+
"check_nan",
1921
"get_param",
2022
"resolve_callable",
2123
"resolve_nn_activation",

rsl_rl/utils/utils.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,26 @@ def resolve_obs_groups(
272272
return obs_groups
273273

274274

275+
def check_nan(obs: "TensorDict", rewards: torch.Tensor, dones: torch.Tensor) -> None:
    """Raise ``ValueError`` if any environment output contains NaN.

    Args:
        obs: Mapping of observation-group names to tensors returned by the environment.
        rewards: Reward tensor returned by the environment.
        dones: Done/termination tensor returned by the environment.

    Raises:
        ValueError: If any observation group, the rewards, or the dones contain NaN values.
    """
    for key, tensor in obs.items():
        # NaN only exists for floating-point dtypes; ``torch.isnan`` is not
        # implemented for ``torch.bool``, so guard to avoid a RuntimeError on
        # non-float observation groups.
        if tensor.is_floating_point() and torch.isnan(tensor).any():
            raise ValueError(
                f"The observation group '{key}' returned by the environment contains NaN values. This usually indicates"
                " a bug in the environment's step() or reset() function."
            )
    if rewards.is_floating_point() and torch.isnan(rewards).any():
        raise ValueError(
            "The rewards returned by the environment contain NaN values. This usually indicates a bug in the"
            " environment's reward computation."
        )
    # Dones are frequently bool/int tensors, for which ``torch.isnan`` would raise.
    if dones.is_floating_point() and torch.isnan(dones).any():
        raise ValueError(
            "The dones returned by the environment contain NaN values. This usually indicates a bug in the"
            " environment's termination logic."
        )
293+
294+
275295
def split_and_pad_trajectories(
276296
tensor: torch.Tensor | TensorDict, dones: torch.Tensor
277297
) -> tuple[torch.Tensor | TensorDict, torch.Tensor]:

0 commit comments

Comments
 (0)