diff --git a/cares_reinforcement_learning/algorithm/policy/CTD4.py b/cares_reinforcement_learning/algorithm/policy/CTD4.py index f6d5bd14..81cf8317 100644 --- a/cares_reinforcement_learning/algorithm/policy/CTD4.py +++ b/cares_reinforcement_learning/algorithm/policy/CTD4.py @@ -56,9 +56,12 @@ import cares_reinforcement_learning.util.helpers as hlp from cares_reinforcement_learning.algorithm.policy import TD3 +from cares_reinforcement_learning.memory.memory_buffer import SARLMemoryBuffer from cares_reinforcement_learning.networks.CTD4 import Actor, Critic +from cares_reinforcement_learning.types.episode import EpisodeContext from cares_reinforcement_learning.types.observation import SARLObservation from cares_reinforcement_learning.util.configurations import CTD4Config +from cares_reinforcement_learning.util.helpers import LinearScheduler class CTD4(TD3): @@ -81,6 +84,15 @@ def __init__( self.fusion_method = config.fusion_method + self.kalman_beta_scheduler = LinearScheduler( + start_value=config.kalman_beta_start, + end_value=config.kalman_beta_end, + decay_steps=config.kalman_beta_decay, + ) + self.kalman_beta = self.kalman_beta_scheduler.get_value(0) + + self.kalman_rho = config.kalman_rho + self.lr_ensemble_critic = config.critic_lr self.ensemble_critic_optimizers = [ torch.optim.Adam( @@ -111,88 +123,236 @@ def _calculate_value(self, state: SARLObservation, action: np.ndarray) -> float: q_u_set.append(actor_q_u) q_std_set.append(actor_q_std) - fusion_u_a, _ = self._fuse_critic_outputs(1, q_u_set, q_std_set) + fusion_u_a, _, _ = self._fuse_critic_outputs(1, q_u_set, q_std_set) return fusion_u_a.item() - def _fusion_kalman( + def _kalman_covariance( + self, + u_set: list[torch.Tensor], + std_set: list[torch.Tensor], + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Covariance Intersection (CI) fusion for UNKNOWN correlation. + Non-sequential, order-invariant. + + CI formula (1D): + P^{-1} = sum_i ω_i P_i^{-1} + μ = P * sum_i ω_i P_i^{-1} μ_i + with ω_i >= 0, sum_i ω_i = 1. 
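As a standalone illustration of the CI rule above (a sketch, not part of this patch; `ci_fuse_1d` is a hypothetical helper), the 1-D case with uniform ω reduces to a precision-weighted combination and is order-invariant:

```python
import torch


def ci_fuse_1d(mus: torch.Tensor, stds: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Covariance Intersection in 1D with uniform omega.

    mus, stds: (B, E) per-critic means and standard deviations.
    Returns the fused mean and std, each of shape (B, 1).
    """
    eps = 1e-12
    num_estimates = mus.shape[1]
    omega = 1.0 / num_estimates                            # uniform omega_i, sums to 1
    prec = 1.0 / (stds**2 + eps)                           # P_i^{-1}
    fused_prec = (omega * prec).sum(dim=1, keepdim=True)   # P^{-1} = sum_i omega_i P_i^{-1}
    fused_var = 1.0 / (fused_prec + eps)
    fused_mu = fused_var * (omega * prec * mus).sum(dim=1, keepdim=True)
    return fused_mu, fused_var.sqrt()


# Order-invariance check: shuffling the ensemble does not change the fused estimate.
mus = torch.tensor([[1.0, 3.0, 2.0]])
stds = torch.tensor([[0.5, 1.0, 2.0]])
mu_a, std_a = ci_fuse_1d(mus, stds)
mu_b, std_b = ci_fuse_1d(mus[:, [2, 0, 1]], stds[:, [2, 0, 1]])
assert torch.allclose(mu_a, mu_b) and torch.allclose(std_a, std_b)
```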
+ + Returns: + fusion_u: (B,1) + fusion_std: (B,1) + weights: (B,E) normalized information contributions (for logging/intuition) + """ + u_mat = torch.concat(u_set, dim=1) # (B,E) + std_mat = torch.concat(std_set, dim=1) # (B,E) + + eps = 1e-12 + var_mat = std_mat**2 + eps + prec_mat = 1.0 / var_mat # (B,E) + + batch_size, num_critics = u_mat.shape + + # ω vector over critics (must sum to 1) + # Default: uniform CI (recommended) + + omega_vec = torch.full( + (num_critics,), 1.0 / num_critics, device=u_mat.device, dtype=u_mat.dtype + ) + + omega_mat = omega_vec.view(1, num_critics).expand( + batch_size, num_critics + ) # (B,E) + + fused_prec = (omega_mat * prec_mat).sum(dim=1, keepdim=True) # (B,1) + fusion_var = 1.0 / (fused_prec + eps) + 1e-6 + fusion_std = torch.sqrt(fusion_var) + + fusion_u = fusion_var * (omega_mat * prec_mat * u_mat).sum(dim=1, keepdim=True) + + # "weights" as normalized information contribution (not a unique CI object, but very useful) + info_contrib = omega_mat * prec_mat + weights = info_contrib / (info_contrib.sum(dim=1, keepdim=True) + eps) + + return fusion_u, fusion_std, weights + + def _kalman_correlated( self, - std_1: torch.Tensor, - mean_1: torch.Tensor, - std_2: torch.Tensor, - mean_2: torch.Tensor, - ) -> tuple[torch.Tensor, torch.Tensor]: - kalman_gain = (std_1**2) / (std_1**2 + std_2**2) - fusion_mean = mean_1 + kalman_gain * (mean_2 - mean_1) - fusion_variance = ( - (1 - kalman_gain) * std_1**2 + kalman_gain * std_2**2 + 1e-6 - ) # 1e-6 was included to avoid values equal to 0 - fusion_std = torch.sqrt(fusion_variance) - return fusion_mean, fusion_std - - def _kalman( + u_set: list[torch.Tensor], + std_set: list[torch.Tensor], + batch_size: int, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Correlated Gaussian fusion (1D) using a simple correlation model: + Cov(fused, new) = rho * sqrt(P_fused * P_new) + + rho=0 recovers the independent Kalman update. + rho>0 makes fusion more conservative (less variance collapse). 
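A minimal sketch of this correlated update for two scalar Gaussians (illustrative only; the names are mine): the cross-covariance is modelled as ρ√(P₁P₂), and ρ = 0 recovers the standard independent Kalman gain.

```python
import torch


def fuse_correlated(mu1, var1, mu2, var2, rho: float = 0.0):
    """Fuse two 1D Gaussian estimates under an assumed correlation rho."""
    eps = 1e-12
    cov12 = rho * torch.sqrt(var1 * var2 + eps)                            # cross-covariance model
    gain = ((var1 - cov12) / (var1 + var2 - 2.0 * cov12 + eps)).clamp(0.0, 1.0)
    mu = mu1 + gain * (mu2 - mu1)
    var = var1 - gain * (var1 - cov12)                                     # correlated posterior variance
    return mu, var


mu1, var1 = torch.tensor([1.0]), torch.tensor([0.25])
mu2, var2 = torch.tensor([2.0]), torch.tensor([1.00])

# rho = 0: classic Kalman result, variance shrinks to var1*var2/(var1+var2) = 0.2
print(fuse_correlated(mu1, var1, mu2, var2, rho=0.0))
# rho = 0.5: more conservative, the fused variance collapses less
print(fuse_correlated(mu1, var1, mu2, var2, rho=0.5))
```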
+ """ + num_critics = len(u_set) + fusion_u = u_set[0] # (B,1) + fusion_std = std_set[0] # (B,1) + + weights = torch.zeros( + (batch_size, num_critics), device=self.device, dtype=torch.float32 + ) + weights[:, 0] = 1.0 + + eps = 1e-12 + + for i in range(1, num_critics): + x2 = u_set[i] # (B,1) + std2 = std_set[i] # (B,1) + + var1 = fusion_std**2 + var2 = std2**2 + + # Cross-covariance model + cov12 = self.kalman_rho * torch.sqrt(var1 * var2 + eps) # (B,1) + + den = (var1 + var2 - 2.0 * cov12) + eps + kalman_gain = (var1 - cov12) / den + kalman_gain = kalman_gain.clamp(0.0, 1.0) + + fusion_u = fusion_u + kalman_gain * (x2 - fusion_u) + + # Correlated posterior variance + fusion_variance = var1 - kalman_gain * (var1 - cov12) + 1e-6 + fusion_std = torch.sqrt(fusion_variance) + + # weights update (same structure) + weights = weights * (1.0 - kalman_gain) + weights[:, i : i + 1] = weights[:, i : i + 1] + kalman_gain + + return fusion_u, fusion_std, weights + + def _kalman_interpolated( + self, u_set: list[torch.Tensor], std_set: list[torch.Tensor], batch_size: int + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + + num_critics = len(u_set) + # Start from critic 0 + fusion_u = u_set[0] # (B,1) + fusion_std = std_set[0] # (B,1) + + # weights: (B,E), start fully on critic 0 + weights = torch.zeros( + (batch_size, num_critics), device=self.device, dtype=torch.float32 + ) + weights[:, 0] = 1.0 + + # Fuse critics 1..E-1 sequentially + for i in range(1, num_critics): + x2 = u_set[i] # (B,1) + std2 = std_set[i] # (B,1) + + # Kalman gain: trust weight on the NEW critic (x2) relative to current fused estimate + # K close to 1 -> new critic dominates; K close to 0 -> old fused dominates + kalman_gain = (fusion_std**2) / (fusion_std**2 + std2**2 + 1e-12) # (B,1) + + # Mean fusion: fused <- (1-K) * fused + K * x2 + fusion_u = fusion_u + kalman_gain * (x2 - fusion_u) + + # Variance fusion ("interpolated" rule): var <- (1-K)*var1 + K*var2 + # Correlations are ignored, so this is not a true Kalman update but an interpolation that avoids variance collapse. 
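For intuition, a toy comparison (not repository code) of the standard Kalman posterior variance (1-K)·var₁, which can only shrink, against the interpolated rule, which stays between the two input variances:

```python
import torch

var1 = torch.tensor(0.25)   # current fused estimate
var2 = torch.tensor(1.00)   # new critic
k = var1 / (var1 + var2)    # Kalman gain = 0.2

kalman_var = (1 - k) * var1             # 0.20 -> always <= var1, shrinks at every fusion step
interp_var = (1 - k) * var1 + k * var2  # 0.40 -> stays between var1 and var2

print(kalman_var.item(), interp_var.item())
```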
+ fusion_variance = ( + (1 - kalman_gain) * (fusion_std**2) + kalman_gain * (std2**2) + 1e-6 + ) + fusion_std = torch.sqrt(fusion_variance) + + # Weight update: + # - all existing contributions get down-weighted by (1-K) + # - new critic i gets weight K + weights = weights * (1 - kalman_gain) # broadcast (B,E) * (B,1) + weights[:, i : i + 1] = ( + weights[:, i : i + 1] + kalman_gain + ) # add (B,1) into column i + + return fusion_u, fusion_std, weights + + def _kalman_precision( self, u_set: list[torch.Tensor], std_set: list[torch.Tensor] - ) -> tuple[torch.Tensor, torch.Tensor]: - # Kalman fusion - for i in range(len(u_set) - 1): - if i == 0: - x_1, std_1 = u_set[i], std_set[i] - x_2, std_2 = u_set[i + 1], std_set[i + 1] - fusion_u, fusion_std = self._fusion_kalman(std_1, x_1, std_2, x_2) - else: - x_2, std_2 = u_set[i + 1], std_set[i + 1] - fusion_u, fusion_std = self._fusion_kalman( - fusion_std, fusion_u, std_2, x_2 - ) - return fusion_u, fusion_std + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + u_mat = torch.concat(u_set, dim=1) # (B,E) + std_mat = torch.concat(std_set, dim=1) # (B,E) + + eps = 1e-12 + precision = 1.0 / (std_mat**2 + eps) # (B,E) + + # Temper precision: beta=0 => all ones => uniform weights + precision_t = precision.pow(self.kalman_beta) + precision_sum = precision_t.sum(dim=1, keepdim=True) # (B,1) + + weights = precision_t / (precision_sum + eps) # (B,E) + fusion_u = (weights * u_mat).sum(dim=1, keepdim=True) # (B,1) + + # Tempered fused variance (consistent with tempered weights). + fusion_var = 1.0 / (precision_sum + eps) + fusion_std = torch.sqrt(fusion_var + 1e-6) + + return fusion_u, fusion_std, weights def _average( self, u_set: list[torch.Tensor], std_set: list[torch.Tensor], batch_size: int - ) -> tuple[torch.Tensor, torch.Tensor]: + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: # Average value among the critic predictions: - fusion_u = ( - torch.mean(torch.concat(u_set, dim=1), dim=1) - .unsqueeze(0) - .reshape(batch_size, 1) - ) - fusion_std = ( - torch.mean(torch.concat(std_set, dim=1), dim=1) - .unsqueeze(0) - .reshape(batch_size, 1) + u_mat = torch.concat(u_set, dim=1) # (B,E) + std_mat = torch.concat(std_set, dim=1) # (B,E) + + fusion_u = u_mat.mean(dim=1, keepdim=True) # (B,1) + fusion_std = std_mat.mean(dim=1, keepdim=True) # (B,1) + + num_critics = u_mat.shape[1] + weights = torch.full( + (batch_size, num_critics), 1.0 / num_critics, device=u_mat.device ) - return fusion_u, fusion_std + + return fusion_u, fusion_std, weights def _minimum( self, u_set: list[torch.Tensor], std_set: list[torch.Tensor], batch_size: int - ) -> tuple[torch.Tensor, torch.Tensor]: - fusion_min = torch.min(torch.concat(u_set, dim=1), dim=1) - fusion_u = fusion_min.values.unsqueeze(0).reshape(batch_size, 1) - # # This corresponds to the std of the min U index. 
That is; the min cannot be got between the stds - std_concat = torch.concat(std_set, dim=1) - fusion_std = ( - torch.stack( - [std_concat[i, fusion_min.indices[i]] for i in range(len(std_concat))] - ) - .unsqueeze(0) - .reshape(batch_size, 1) - ) - return fusion_u, fusion_std + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + u_mat = torch.concat(u_set, dim=1) # (B,E) + std_mat = torch.concat(std_set, dim=1) # (B,E) + + min_vals, min_idx = torch.min(u_mat, dim=1) # (B,), (B,) + fusion_u = min_vals.unsqueeze(1) # (B,1) + + # std of the selected critic + fusion_std = std_mat[torch.arange(batch_size), min_idx].unsqueeze(1) # (B,1) + + # one-hot weights + num_critics = u_mat.shape[1] + weights = torch.zeros((batch_size, num_critics), device=u_mat.device) + weights[torch.arange(batch_size), min_idx] = 1.0 + + return fusion_u, fusion_std, weights def _fuse_critic_outputs( self, batch_size: int, u_set: list[torch.Tensor], std_set: list[torch.Tensor] - ) -> tuple[torch.Tensor, torch.Tensor]: - if self.fusion_method == "kalman": - fusion_u, fusion_std = self._kalman(u_set, std_set) + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + if self.fusion_method == "precision": + fusion_u, fusion_std, weights = self._kalman_precision(u_set, std_set) + elif self.fusion_method == "interpolated": + fusion_u, fusion_std, weights = self._kalman_interpolated( + u_set, std_set, batch_size + ) + elif self.fusion_method == "correlated": + fusion_u, fusion_std, weights = self._kalman_correlated( + u_set, std_set, batch_size + ) + elif self.fusion_method == "covariance": + fusion_u, fusion_std, weights = self._kalman_covariance(u_set, std_set) elif self.fusion_method == "average": - fusion_u, fusion_std = self._average(u_set, std_set, batch_size) + fusion_u, fusion_std, weights = self._average(u_set, std_set, batch_size) elif self.fusion_method == "minimum": - fusion_u, fusion_std = self._minimum(u_set, std_set, batch_size) + fusion_u, fusion_std, weights = self._minimum(u_set, std_set, batch_size) else: - raise ValueError( - f"Invalid fusion method: {self.fusion_method}. Please choose between 'kalman', 'average', or 'minimum'." 
- ) + raise ValueError(f"Invalid fusion method: {self.fusion_method}.") - return fusion_u, fusion_std + return fusion_u, fusion_std, weights def _update_critic( self, @@ -203,6 +363,8 @@ def _update_critic( dones: torch.Tensor, weights: torch.Tensor, # pylint: disable=unused-argument ) -> tuple[dict[str, Any], np.ndarray]: + info: dict[str, Any] = {} + batch_size = len(states) with torch.no_grad(): @@ -225,7 +387,9 @@ def _update_critic( u_set.append(u) std_set.append(std) - fusion_u, fusion_std = self._fuse_critic_outputs(batch_size, u_set, std_set) + fusion_u, fusion_std, fusion_weights = self._fuse_critic_outputs( + batch_size, u_set, std_set + ) # Create the target distribution = aX+b u_target = rewards + self.gamma * fusion_u * (1 - dones) @@ -238,6 +402,10 @@ def _update_critic( critic_loss_totals = [] critic_loss_elementwise = [] + # --- current (s,a) ensemble output health (optional but useful) --- + current_mu_means: list[float] = [] + current_sigma_means: list[float] = [] + for critic_net, critic_net_optimiser in zip( self.critic_net.critics, self.ensemble_critic_optimizers ): @@ -246,7 +414,7 @@ def _update_critic( u_current, std_current ) - # Compute each critic los + # Compute each critic loss as KL divergence to the target distribution critic_elementwise_loss = torch.distributions.kl.kl_divergence( current_distribution, target_distribution ) @@ -255,28 +423,132 @@ def _update_critic( critic_loss = critic_elementwise_loss.mean() critic_loss_totals.append(critic_loss.item()) + # If σ collapses while KL stays high -> overconfident wrong critic (bad calibration) + current_mu_means.append(u_current.mean().item()) + current_sigma_means.append(std_current.mean().item()) + critic_net_optimiser.zero_grad() critic_loss.backward() critic_net_optimiser.step() - critic_losses = torch.stack(critic_loss_elementwise, dim=0) - critic_losses = torch.max(critic_losses, dim=0).values + kl_stack = torch.stack(critic_loss_elementwise, dim=0) + critic_max_per_sample = torch.max(kl_stack, dim=0).values # Update the Priorities - PER only priorities = ( - critic_losses.clamp(self.min_priority) + critic_max_per_sample.clamp(self.min_priority) .pow(self.per_alpha) .cpu() .data.numpy() .flatten() ) - critic_loss_total = np.mean(critic_loss_totals) + with torch.no_grad(): + # --- TD3-style smoothing diagnostics --- + # Noise diagnostics + # What it tells you: + # - target_noise_abs_mean: effective smoothing magnitude. + # - target_noise_clip_frac high early: noise often clipped (clip too small or noise too large). + target_noise_abs_mean = target_noise.abs().mean().item() + target_noise_clip_frac = ( + (target_noise.abs() >= self.policy_noise_clip).float().mean().item() + ) + info["target_noise_abs_mean"] = float(target_noise_abs_mean) + info["target_noise_clip_frac"] = float(target_noise_clip_frac) + + # How different are the critics’ average predicted means from each other (on the current batch)? + info["mu_std_across_critics"] = float(np.std(current_mu_means)) + info["sigma_std_across_critics"] = float(np.std(current_sigma_means)) - info = { - "critic_loss_total": critic_loss_total, - "critic_loss_totals": critic_loss_totals, - } + # --- Target ensemble diagnostics (s', a') --- + u_mat = torch.concat(u_set, dim=1) # (B, E) + std_mat = torch.concat(std_set, dim=1) # (B, E) + + # mu_std_mean is “ensemble disagreement” (epistemic spread). Spikes/growth often = divergence/OOD. 
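To make the epistemic/aleatoric distinction concrete, a sketch with made-up (B, E) tensors mirroring the `u_mat`/`std_mat` layout above: disagreement is the spread of the per-critic means for each sample, while sigma_mean averages the uncertainty each critic itself predicts.

```python
import torch

batch_size, num_critics = 4, 3
u_mat = torch.randn(batch_size, num_critics)          # per-critic predicted means, (B, E)
std_mat = torch.rand(batch_size, num_critics) + 0.1   # per-critic predicted stds, (B, E)

# Epistemic proxy: how much the critics disagree about the mean, per sample.
mu_disagreement = u_mat.std(dim=1, unbiased=False).mean()

# Aleatoric proxy: how uncertain the critics themselves claim to be, on average.
sigma_mean = std_mat.mean()

print(f"disagreement={mu_disagreement.item():.3f}  sigma_mean={sigma_mean.item():.3f}")
```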
+ info["target_ensemble_mu_mean"] = u_mat.mean().item() + info["target_ensemble_mu_std_mean"] = ( + u_mat.std(dim=1, unbiased=False).mean().item() + ) + + # sigma_mean is average predicted uncertainty; collapse = overconfidence; explosion = instability. + info["target_ensemble_sigma_mean"] = std_mat.mean().item() + info["target_ensemble_sigma_std"] = std_mat.std().item() + + # --- Fusion diagnostics (s', a') --- + # Fused μ is the value signal used for the target + # Fused σ is the “post-fusion uncertainty”; should not collapse too early + info["fusion_mu_mean"] = fusion_u.mean().item() + info["fusion_mu_std"] = fusion_u.std().item() + + info["fusion_sigma_mean"] = fusion_std.mean().item() + info["fusion_sigma_std"] = fusion_std.std().item() + + # weights: (B, E) -- contribution of each critic to fused estimate + eps = 1e-12 + + # Dominance: how much the most trusted critic contributes + # w_max ≈ 1/E -> equal trust + # w_max → 1.0 -> single critic dominating (ensemble collapse) + w_max = fusion_weights.max(dim=1).values # (B,) + + info["fusion_w_max_mean"] = w_max.mean().item() + info["fusion_w_max_p95"] = w_max.quantile(0.95).item() + + # Entropy of weights: distribution of trust + # High entropy -> distributed trust across critics + # Low entropy -> sharp trust concentration + entropy = -(fusion_weights * (fusion_weights + eps).log()).sum( + dim=1 + ) # (B,) + info["fusion_w_entropy_mean"] = entropy.mean().item() + info["fusion_w_entropy_std"] = entropy.std().item() + + # Effective Ensemble Size + # N_eff = 1 / sum(w_k^2) + # ≈ E -> all critics contributing + # ≈ 1 -> effectively a single critic + n_eff = 1.0 / (fusion_weights.pow(2).sum(dim=1) + eps) # (B,) + + info["fusion_n_eff_mean"] = n_eff.mean().item() + info["fusion_n_eff_p10"] = n_eff.quantile(0.10).item() + + # --- Target distribution diagnostics --- + # Drift upward without reward improvement: gamma/reward_scale/instability. + info["u_target_mean"] = u_target.mean().item() + info["u_target_std"] = u_target.std().item() + + # Collapse -> overconfident targets; explosion -> noisy targets / unstable critics. + info["std_target_mean"] = std_target.mean().item() + info["std_target_std"] = std_target.std().item() + + # --- Critic loss diagnostics (fit quality) --- + # If one critic stays high: “bad apple” critic, poor calibration, or optimizer issue. + info["critic_loss_total"] = float(np.mean(critic_loss_totals)) + info["critic_loss_totals"] = critic_loss_totals + + # --- KL diagnostics (more robust than mean loss alone) --- + # ---- Mean KL across critics (overall fit quality) ---- + kl_mean_per_sample = kl_stack.mean(dim=0) # (B,1) + + info["kl_mean"] = kl_mean_per_sample.mean().item() + info["kl_mean_std"] = kl_mean_per_sample.std().item() + + # ---- Max KL across critics (worst critic instability) ---- + # Spikes: distribution mismatch / exploding σ / unstable learning. + info["kl_max_mean"] = critic_max_per_sample.mean().item() + info["kl_max_std"] = critic_max_per_sample.std().item() + info["kl_max_p95"] = critic_max_per_sample.quantile(0.95).item() + + # --- Current ensemble health on replay (s, a) --- + # σ collapsing while KL remains high => overconfident wrong critics. + info["current_ensemble_mu_mean"] = float(np.mean(current_mu_means)) + info["current_ensemble_sigma_mean"] = float(np.mean(current_sigma_means)) + + # --- PER priority health --- + # priority_max exploding => PER may over-focus on a few transitions and destabilize training. 
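A self-contained sketch of how these priority statistics behave; the clamp/exponent mirrors the PER update above, while the parameter values and the max/mean ratio are my own additions:

```python
import numpy as np


def priority_stats(td_errors: np.ndarray, min_priority: float = 1.0, per_alpha: float = 0.6) -> dict:
    """Priorities from per-sample TD errors, plus simple health numbers."""
    priorities = np.clip(np.abs(td_errors), min_priority, None) ** per_alpha
    return {
        "priority_mean": float(priorities.mean()),
        "priority_p95": float(np.quantile(priorities, 0.95)),
        "priority_max": float(priorities.max()),
        # Ratio >> 1 means a handful of transitions dominate the sampling distribution.
        "priority_max_over_mean": float(priorities.max() / priorities.mean()),
    }


print(priority_stats(np.random.randn(256) * 2.0))
```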
+ info["priority_mean"] = float(np.mean(priorities)) + info["priority_p95"] = float(np.quantile(priorities, 0.95)) + info["priority_max"] = float(np.max(priorities)) return info, priorities @@ -285,10 +557,16 @@ def _update_actor( states: torch.Tensor, weights: torch.Tensor, # pylint: disable=unused-argument ) -> dict[str, Any]: + info: dict[str, Any] = {} + batch_size = len(states) - actor_q_u_set = [] - actor_q_std_set = [] + actor_q_u_set: list[torch.Tensor] = [] + actor_q_std_set: list[torch.Tensor] = [] + + # Track per-critic batch means for “bad apple” detection (like critic update) + current_mu_means: list[float] = [] + current_sigma_means: list[float] = [] actions = self.actor_net(states) with hlp.evaluating(self.critic_net): @@ -298,20 +576,109 @@ def _update_actor( actor_q_u_set.append(actor_q_u) actor_q_std_set.append(actor_q_std) - fusion_u_a, _ = self._fuse_critic_outputs( + current_mu_means.append(actor_q_u.mean().item()) + current_sigma_means.append(actor_q_std.mean().item()) + + fusion_u_a, fusion_std_a, fusion_weights_a = self._fuse_critic_outputs( batch_size, actor_q_u_set, actor_q_std_set ) actor_loss = -fusion_u_a.mean() + # --------------------------------------------------------- + # Deterministic Policy Gradient Strength (∇a Q(s,a)) + # --------------------------------------------------------- + # Measures how steep the critic surface is w.r.t. actions. + # ~0 early -> critic flat, actor receives no learning signal. + # Very large -> critic overly sharp, can cause unstable actor updates. + dq_da = torch.autograd.grad( + outputs=actor_loss, + inputs=actions, + retain_graph=True, # because we do backward(actor_loss) next + create_graph=False, # diagnostic only + allow_unused=False, + )[0] + with torch.no_grad(): + # - ~0 early: critic surface flat around actor actions (weak learning signal) + # - very large: critic surface sharp -> unstable / exploitative actor updates + info["dq_da_abs_mean"] = dq_da.abs().mean().item() + info["dq_da_norm_mean"] = dq_da.norm(dim=1).mean().item() + info["dq_da_norm_p95"] = dq_da.norm(dim=1).quantile(0.95).item() + self.actor_net_optimiser.zero_grad() actor_loss.backward() self.actor_net_optimiser.step() - info = { - "actor_loss": actor_loss.item(), - } + with torch.no_grad(): + + # Policy Action Health (tanh policies in [-1, 1]) + # pi_action_saturation_frac: + # High values (>0.8 early) often mean the actor is slamming bounds, + # reducing effective gradient flow through tanh. + info["pi_action_mean"] = actions.mean().item() + info["pi_action_std"] = actions.std().item() + info["pi_action_abs_mean"] = actions.abs().mean().item() + info["pi_action_saturation_frac"] = ( + (actions.abs() > 0.95).float().mean().item() + ) + + # --- Actor-side ensemble diagnostics (s, pi(s)) --- + u_mat = torch.concat(actor_q_u_set, dim=1) # (B,E) + std_mat = torch.concat(actor_q_std_set, dim=1) # (B,E) + + # Per-sample disagreement across critics on μ under current policy (epistemic spread). + info["actor_ensemble_mu_mean"] = u_mat.mean().item() + info["actor_ensemble_mu_std_mean"] = ( + u_mat.std(dim=1, unbiased=False).mean().item() + ) + + # Average predicted uncertainty under current policy; collapse/explosion are red flags. + info["actor_ensemble_sigma_mean"] = std_mat.mean().item() + info["actor_ensemble_sigma_std"] = std_mat.std(unbiased=False).item() + + # --- “Bad apple” ensemble drift (across critics, coarse) --- + # If one critic drifts, these rise even if the per-sample std looks OK. 
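One possible way to act on this signal (a hypothetical monitoring helper, not part of the library): keep the per-critic batch means and flag a critic whose mean sits far from the ensemble median.

```python
import numpy as np


def flag_drifting_critic(mu_means: list[float], k: float = 3.0) -> int | None:
    """Return the index of a critic whose batch-mean mu deviates strongly, else None.

    mu_means: one batch-mean per critic (e.g. the current_mu_means list above).
    """
    mus = np.asarray(mu_means)
    median = np.median(mus)
    mad = np.median(np.abs(mus - median)) + 1e-8      # robust spread estimate
    scores = np.abs(mus - median) / mad
    worst = int(np.argmax(scores))
    return worst if scores[worst] > k else None


print(flag_drifting_critic([0.51, 0.49, 0.50, 3.2]))   # -> 3
print(flag_drifting_critic([0.51, 0.49, 0.50, 0.52]))  # -> None
```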
+ info["actor_mu_std_across_critics"] = float(np.std(current_mu_means)) + info["actor_sigma_std_across_critics"] = float(np.std(current_sigma_means)) + # --- Actor-side fusion outputs --- + info["actor_fusion_mu_mean"] = fusion_u_a.mean().item() + info["actor_fusion_mu_std"] = fusion_u_a.std(unbiased=False).item() + info["actor_fusion_sigma_mean"] = fusion_std_a.mean().item() + info["actor_fusion_sigma_std"] = fusion_std_a.std(unbiased=False).item() + + # --- Fusion weight diagnostics (actor-side) --- + eps = 1e-12 + + # Dominance: near 1 => single critic dominating under actor actions. + w_max = fusion_weights_a.max(dim=1).values # (B,) + info["actor_fusion_w_max_mean"] = w_max.mean().item() + info["actor_fusion_w_max_p95"] = w_max.quantile(0.95).item() + + # Diversity of trust across critics. + entropy = -(fusion_weights_a * (fusion_weights_a + eps).log()).sum( + dim=1 + ) # (B,) + info["actor_fusion_w_entropy_mean"] = entropy.mean().item() + info["actor_fusion_w_entropy_std"] = entropy.std(unbiased=False).item() + + # Effective ensemble size: near 1 => effectively single-critic behavio + n_eff = 1.0 / (fusion_weights_a.pow(2).sum(dim=1) + eps) # (B,) + info["actor_fusion_n_eff_mean"] = n_eff.mean().item() + info["actor_fusion_n_eff_p10"] = n_eff.quantile(0.10).item() + + info["actor_loss"] = actor_loss.item() + + return info + + def train( + self, memory_buffer: SARLMemoryBuffer, episode_context: EpisodeContext + ) -> dict[str, Any]: + self.kalman_beta = self.kalman_beta_scheduler.get_value( + episode_context.training_step + ) + info = super().train(memory_buffer, episode_context) + info["kalman_beta"] = self.kalman_beta return info def save_models(self, filepath: str, filename: str) -> None: diff --git a/cares_reinforcement_learning/algorithm/policy/CrossQ.py b/cares_reinforcement_learning/algorithm/policy/CrossQ.py index cbf45d41..d2710064 100644 --- a/cares_reinforcement_learning/algorithm/policy/CrossQ.py +++ b/cares_reinforcement_learning/algorithm/policy/CrossQ.py @@ -91,6 +91,7 @@ def _update_critic( dones: torch.Tensor, weights: torch.Tensor, ) -> tuple[dict[str, Any], np.ndarray]: + info: dict[str, Any] = {} with torch.no_grad(): with hlp.evaluating(self.actor_net): @@ -104,15 +105,13 @@ def _update_critic( q_values_one, q_values_one_next = torch.chunk(cat_q_values_one, chunks=2, dim=0) q_values_two, q_values_two_next = torch.chunk(cat_q_values_two, chunks=2, dim=0) - target_q_values = ( - torch.minimum(q_values_one_next, q_values_two_next) - - self.alpha * next_log_pi - ) + with torch.no_grad(): + min_next_q = torch.minimum(q_values_one_next, q_values_two_next) + next_q_values = min_next_q - self.alpha * next_log_pi - q_target = ( - rewards * self.reward_scale + self.gamma * (1 - dones) * target_q_values - ) - torch.detach(q_target) + q_target = ( + rewards * self.reward_scale + self.gamma * (1 - dones) * next_q_values + ) td_error_one = (q_values_one - q_target).abs() td_error_two = (q_values_two - q_target).abs() @@ -139,10 +138,84 @@ def _update_critic( .flatten() ) - info = { - "critic_loss_one": critic_loss_one.item(), - "critic_loss_two": critic_loss_two.item(), - "critic_loss_total": critic_loss_total.item(), - } + with torch.no_grad(): + # --- Twin critic disagreement (stability/uncertainty) --- + # If this grows over training, critics are diverging / becoming inconsistent. 
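A sketch of turning the gap into an alert (the window size and growth ratio are arbitrary choices, and `TwinGapMonitor` is hypothetical): compare the recent average gap against an earlier baseline.

```python
from collections import deque

import numpy as np


class TwinGapMonitor:
    """Tracks |Q1 - Q2| batch means and reports relative growth."""

    def __init__(self, window: int = 1000):
        self.history: deque[float] = deque(maxlen=2 * window)
        self.window = window

    def update(self, gap_abs_mean: float) -> float | None:
        self.history.append(gap_abs_mean)
        if len(self.history) < 2 * self.window:
            return None
        old = np.mean(list(self.history)[: self.window])
        new = np.mean(list(self.history)[self.window :])
        return float(new / (old + 1e-8))   # >> 1.0 suggests the critics are diverging


monitor = TwinGapMonitor(window=3)
for gap in [0.1, 0.1, 0.1, 0.3, 0.3, 0.3]:
    ratio = monitor.update(gap)
print(ratio)  # ~3.0
```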
+ info["q1_mean"] = q_values_one.mean().item() + info["q2_mean"] = q_values_two.mean().item() + info["q_twin_gap_abs_mean"] = ( + (q_values_one - q_values_two).abs().mean().item() + ) + + # --------------------------------------------------------- + # CrossQ-specific diagnostics + # --------------------------------------------------------- + # (1) Self-bootstrap "optimism": if q_next is systematically larger than q_now + # it can indicate overestimation pressure since the same network supplies bootstrap values. + info["q_next_minus_q_mean"] = ( + (min_next_q - torch.minimum(q_values_one, q_values_two)).mean().item() + ) + + # (2) Next-Q magnitude vs current-Q magnitude (scale drift check) + info["q_next_abs_mean"] = min_next_q.abs().mean().item() + info["q_abs_mean"] = ( + torch.minimum(q_values_one, q_values_two).abs().mean().item() + ) + + # (3) CrossQ concatenation health: are the two halves numerically similar in distribution? + # Big discrepancies can indicate distribution shift or implementation bugs in cat/chunk wiring. + info["crossq_half_gap_abs_mean"] = ( + (torch.minimum(q_values_one, q_values_two).mean() - min_next_q.mean()) + .abs() + .item() + ) + + # --------------------------------------------------------- + # CrossQ "bootstrap-from-self" next critics (s',a') + # (these are NOT target critics; they're the next-half outputs) + # --------------------------------------------------------- + info["q1_next_mean"] = q_values_one_next.mean().item() + info["q2_next_mean"] = q_values_two_next.mean().item() + info["q_next_twin_gap_abs_mean"] = ( + (q_values_one_next - q_values_two_next).abs().mean().item() + ) + + # --------------------------------------------------------- + # Soft target decomposition (same as SAC, but using q_next from critic_net) + # --------------------------------------------------------- + # alpha_log_pi is typically negative; entropy_bonus is typically positive + alpha_log_pi = self.alpha * next_log_pi + # this is what gets ADDED to minQ in the target + entropy_bonus = -self.alpha * next_log_pi + + soft_target_value = min_next_q + entropy_bonus # == minQ - alpha*log_pi + + info["next_min_q_mean"] = min_next_q.mean().item() + info["alpha_log_pi_mean"] = alpha_log_pi.mean().item() + info["entropy_bonus_mean"] = entropy_bonus.mean().item() + info["soft_target_value_mean"] = soft_target_value.mean().item() + + # --- Bellman target scale (reward scaling / discount sanity) --- + # If q_target drifts upward without reward improvement, suspect reward_scale, gamma, or instability. + info["q_target_mean"] = q_target.mean().item() + info["q_target_std"] = q_target.std(unbiased=False).item() + + # --- TD error diagnostics (Bellman fit quality) --- + # td_abs_mean down over time is healthy; persistent growth/spikes often indicate critic instability. 
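The signed/absolute split matters here; a toy example (not repository code) showing how a biased critic and a merely noisy critic leave different signatures in these statistics:

```python
import torch

q_target = torch.zeros(1000)

q_biased = q_target + 0.5                       # systematic overestimation
q_noisy = q_target + torch.randn(1000) * 0.5    # zero-mean fitting noise

for name, q in [("biased", q_biased), ("noisy", q_noisy)]:
    td = q - q_target
    # biased: td_mean ~ 0.5, td_std ~ 0   |   noisy: td_mean ~ 0, td_std ~ 0.5
    print(name, round(td.mean().item(), 3), round(td.std().item(), 3), round(td.abs().mean().item(), 3))
```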
+ td1 = q_values_one - q_target # signed + td2 = q_values_two - q_target # signed + + info["td1_mean"] = td1.mean().item() + info["td1_std"] = td1.std(unbiased=False).item() + info["td1_abs_mean"] = td1.abs().mean().item() + + info["td2_mean"] = td2.mean().item() + info["td2_std"] = td2.std(unbiased=False).item() + info["td2_abs_mean"] = td2.abs().mean().item() + + # --- Losses (optimization progress; less diagnostic than TD/twin gaps) --- + info["critic_loss_one"] = critic_loss_one.item() + info["critic_loss_two"] = critic_loss_two.item() + info["critic_loss_total"] = critic_loss_total.item() return info, priorities diff --git a/cares_reinforcement_learning/algorithm/policy/DDPG.py b/cares_reinforcement_learning/algorithm/policy/DDPG.py index fd63bd35..9b8fda63 100644 --- a/cares_reinforcement_learning/algorithm/policy/DDPG.py +++ b/cares_reinforcement_learning/algorithm/policy/DDPG.py @@ -67,6 +67,7 @@ SARLObservationTensors, ) from cares_reinforcement_learning.util.configurations import DDPGConfig +from cares_reinforcement_learning.util.helpers import ExponentialScheduler class DDPG(SARLAlgorithm[np.ndarray]): @@ -88,6 +89,16 @@ def __init__( self.gamma = config.gamma self.tau = config.tau + # Action noise + self.action_noise_scheduler = ExponentialScheduler( + start_value=config.action_noise_start, + end_value=config.action_noise_end, + decay_steps=config.action_noise_decay, + ) + self.action_noise = self.action_noise_scheduler.get_value(0) + + self.action_num = self.actor_net.num_actions + self.actor_net_optimiser = torch.optim.Adam( self.actor_net.parameters(), lr=config.actor_lr ) @@ -95,20 +106,30 @@ def __init__( self.critic_net.parameters(), lr=config.critic_lr ) - # TODO add action noise for exploration + self.learn_counter = 0 + def act( self, observation: SARLObservation, evaluation: bool = False ) -> ActionSample[np.ndarray]: - # pylint: disable-next=unused-argument + self.actor_net.eval() + state = observation.vector_state - self.actor_net.eval() with torch.no_grad(): state_tensor = torch.FloatTensor(state).to(self.device) state_tensor = state_tensor.unsqueeze(0) action = self.actor_net(state_tensor) action = action.cpu().data.numpy().flatten() + if not evaluation: + # this is part the DDPG too, add noise to the action + noise = np.random.normal( + 0, scale=self.action_noise, size=self.action_num + ).astype(np.float32) + action = action + noise + action = np.clip(action, -1, 1) + self.actor_net.train() + return ActionSample(action=action, source="policy") def _update_critic( @@ -119,6 +140,8 @@ def _update_critic( next_states: torch.Tensor, dones: torch.Tensor, ) -> dict[str, Any]: + info: dict[str, Any] = {} + with torch.no_grad(): self.target_actor_net.eval() next_actions = self.target_actor_net(next_states) @@ -134,25 +157,81 @@ def _update_critic( critic_loss.backward() self.critic_net_optimiser.step() - info = { - "critic_loss": critic_loss.item(), - } + with torch.no_grad(): + td = q_values - q_target + + # --- Q statistics --- + info["q_mean"] = q_values.mean().item() + info["q_std"] = q_values.std().item() + + # --- Bellman target scale (reward scaling / discount sanity) --- + # If q_target drifts upward without reward improvement, suspect reward_scale, gamma, or instability. + info["q_target_mean"] = q_target.mean().item() + info["q_target_std"] = q_target.std().item() + + # --- TD error diagnostics (Bellman fit quality) --- + # td_abs_mean down over time is healthy; persistent growth/spikes often indicate critic instability. 
+ info["td_mean"] = td.mean().item() + info["td_std"] = td.std().item() + info["td_abs_mean"] = td.abs().mean().item() + + # --- Losses (optimization progress --- + info["critic_loss"] = critic_loss.item() return info def _update_actor(self, states: torch.Tensor) -> dict[str, Any]: - self.critic_net.eval() - actions_pred = self.actor_net(states) - actor_q = self.critic_net(states, actions_pred) - self.critic_net.train() + info: dict[str, Any] = {} - actor_loss = -actor_q.mean() + actions = self.actor_net(states) + + with hlp.evaluating(self.critic_net): + actor_q_values = self.critic_net(states, actions) + + actor_loss = -actor_q_values.mean() + + # --------------------------------------------------------- + # Deterministic Policy Gradient Strength (∇a Q(s,a)) + # --------------------------------------------------------- + # Measures how steep the critic surface is w.r.t. actions. + # ~0 early -> critic flat, actor receives no learning signal. + # Very large -> critic overly sharp, can cause unstable actor updates. + dq_da = torch.autograd.grad( + outputs=-actor_q_values.mean(), # NOTE: uses Q-term only, excludes regularizers + inputs=actions, + retain_graph=True, # needed because we will backward (actor_loss) next + create_graph=False, # diagnostic only + allow_unused=False, + )[0] + with torch.no_grad(): + # - ~0 early: critic surface flat around actor actions (weak learning signal) + # - very large: critic surface sharp -> unstable / exploitative actor updates + info["dq_da_abs_mean"] = dq_da.abs().mean().item() + info["dq_da_norm_mean"] = dq_da.norm(dim=1).mean().item() + info["dq_da_norm_p95"] = dq_da.norm(dim=1).quantile(0.95).item() self.actor_net_optimiser.zero_grad() actor_loss.backward() self.actor_net_optimiser.step() - info = {"actor_loss": actor_loss.item()} + with torch.no_grad(): + # Policy Action Health (tanh policies in [-1, 1]) + # pi_action_saturation_frac: + # High values (>0.8 early) often mean the actor is slamming bounds, + # reducing effective gradient flow through tanh. + info["pi_action_mean"] = actions.mean().item() + info["pi_action_std"] = actions.std().item() + info["pi_action_abs_mean"] = actions.abs().mean().item() + info["pi_action_saturation_frac"] = ( + (actions.abs() > 0.95).float().mean().item() + ) + + # actor_q_mean should generally increase over training. + # actor_q_std large + unstable may indicate critic inconsistency. 
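The ∇ₐQ probe used above can be reproduced in isolation (toy critic surface; real shapes and networks differ):

```python
import torch

batch_size, action_dim = 32, 4
actions = torch.randn(batch_size, action_dim, requires_grad=True)

# Stand-in for Q(s, a): any differentiable function of the actions works for the probe.
q_values = -(actions**2).sum(dim=1, keepdim=True)
actor_loss = -q_values.mean()

# Gradient of the actor objective w.r.t. the actions (not the actor parameters).
(dq_da,) = torch.autograd.grad(actor_loss, actions, retain_graph=False)

print(dq_da.norm(dim=1).mean().item())  # ~0 -> flat critic, no signal; large -> sharp critic surface
```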
+ info["actor_loss"] = actor_loss.item() + info["actor_q_mean"] = actor_q_values.mean().item() + info["actor_q_std"] = actor_q_values.std().item() + return info def update_from_batch( @@ -164,9 +243,13 @@ def update_from_batch( next_observation_tensor: SARLObservationTensors, dones_tensor: torch.Tensor, ) -> dict[str, Any]: + self.learn_counter += 1 + info: dict[str, Any] = {} - # TODO add the action noise for exploration with episode context and some decay mechanism + self.action_noise = self.action_noise_scheduler.get_value( + episode_context.training_step + ) # Update Critic critic_info = self._update_critic( diff --git a/cares_reinforcement_learning/algorithm/policy/IMARL.py b/cares_reinforcement_learning/algorithm/policy/IMARL.py index 2f79f6af..bd4ce5e6 100644 --- a/cares_reinforcement_learning/algorithm/policy/IMARL.py +++ b/cares_reinforcement_learning/algorithm/policy/IMARL.py @@ -215,7 +215,15 @@ def train( indices=indices, ) for key, value in agent_i_info.items(): - info[f"{agent_name}_{key}"] = value + info[f"agent_{i}_{key}"] = value + + metrics = list(agent_i_info.keys()) + for metric in metrics: + values = [info[f"agent_{i}_{metric}"] for i in range(self.num_agents)] + info[f"mean_{metric}"] = float(np.mean(values)) + info[f"std_{metric}"] = float(np.std(values)) + info[f"max_{metric}"] = float(np.max(values)) + info[f"min_{metric}"] = float(np.min(values)) return info diff --git a/cares_reinforcement_learning/algorithm/policy/LA3PSAC.py b/cares_reinforcement_learning/algorithm/policy/LA3PSAC.py index e360f523..400aec40 100644 --- a/cares_reinforcement_learning/algorithm/policy/LA3PSAC.py +++ b/cares_reinforcement_learning/algorithm/policy/LA3PSAC.py @@ -89,6 +89,7 @@ def _update_critic( # type: ignore[override] sample: Sample[SingleAgentExperience], uniform_sampling: bool, ) -> tuple[dict[str, Any], np.ndarray]: + info: dict[str, Any] = {} # Convert into tensors using helper method ( @@ -166,11 +167,59 @@ def _update_critic( # type: ignore[override] .flatten() ) - info = { - "critic_loss_one": critic_loss_one.item(), - "critic_loss_two": critic_loss_two.item(), - "critic_loss_total": critic_loss_total.item(), - } + with torch.no_grad(): + info["uniform_sampling"] = float(uniform_sampling) + + # --- Twin critic disagreement (stability/uncertainty) --- + # If this grows over training, critics are diverging / becoming inconsistent. + info["q1_mean"] = q_values_one.mean().item() + info["q2_mean"] = q_values_two.mean().item() + info["q_twin_gap_abs_mean"] = ( + (q_values_one - q_values_two).abs().mean().item() + ) + + # --- Target critics disagreement (target stability) --- + # Large/unstable gap here often means target critics are drifting or policy is visiting OOD actions. 
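For reference, the SAC soft target that these target critics feed into can be decomposed by hand (toy numbers, reward scaling omitted), mirroring the decomposition logged a few lines below:

```python
import torch

alpha, gamma = 0.2, 0.99
rewards = torch.tensor([[1.0]])
dones = torch.tensor([[0.0]])
target_q1 = torch.tensor([[10.0]])
target_q2 = torch.tensor([[9.0]])
next_log_pi = torch.tensor([[-1.5]])     # log-prob of the sampled next action, usually negative

min_target_q = torch.minimum(target_q1, target_q2)            # conservative bootstrap: 9.0
entropy_term = alpha * next_log_pi                             # -0.3
soft_target_value = min_target_q - entropy_term                # 9.3 = minQ - alpha * log_pi
q_target = rewards + gamma * (1 - dones) * soft_target_value   # 1 + 0.99 * 9.3 = 10.207

print(min_target_q.item(), entropy_term.item(), soft_target_value.item(), q_target.item())
```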
+ info["target_q1_mean"] = target_q_values_one.mean().item() + info["target_q2_mean"] = target_q_values_two.mean().item() + info["target_q_twin_gap_abs_mean"] = ( + (target_q_values_one - target_q_values_two).abs().mean().item() + ) + + # --- Soft target decomposition (SAC-specific) --- + # min_target_q_mean: the conservative bootstrap value from twin critics (pre-entropy) + # entropy_term_mean: magnitude of entropy regularization in the target (alpha * log_pi is usually negative) + # soft_target_value_mean: the exact term used inside the Bellman target before reward/discount + min_target_q = torch.minimum(target_q_values_one, target_q_values_two) + entropy_term = self.alpha * next_log_pi # typically negative + soft_target_value = min_target_q - entropy_term # == minQ - alpha*log_pi + + info["target_min_q_mean"] = min_target_q.mean().item() + info["entropy_term_mean"] = entropy_term.mean().item() + info["soft_target_value_mean"] = soft_target_value.mean().item() + + # --- Bellman target scale (reward scaling / discount sanity) --- + # If q_target drifts upward without reward improvement, suspect reward_scale, gamma, or instability. + info["q_target_mean"] = q_target.mean().item() + info["q_target_std"] = q_target.std().item() + + # --- TD error diagnostics (Bellman fit quality) --- + # td_abs_mean down over time is healthy; persistent growth/spikes often indicate critic instability. + td1 = q_values_one - q_target # signed + td2 = q_values_two - q_target # signed + + info["td1_mean"] = td1.mean().item() + info["td1_std"] = td1.std().item() + info["td1_abs_mean"] = td1.abs().mean().item() + + info["td2_mean"] = td2.mean().item() + info["td2_std"] = td2.std().item() + info["td2_abs_mean"] = td2.abs().mean().item() + + # --- Losses (optimization progress; less diagnostic than TD/twin gaps) --- + info["critic_loss_one"] = critic_loss_one.item() + info["critic_loss_two"] = critic_loss_two.item() + info["critic_loss_total"] = critic_loss_total.item() return info, priorities @@ -210,7 +259,6 @@ def train( observation_tensor.vector_state_tensor, weights_tensor ) info_uniform |= actor_info - info_uniform["alpha"] = self.alpha.item() if target_update: hlp.soft_update_params(self.critic_net, self.target_critic_net, self.tau) @@ -248,7 +296,6 @@ def train( observation_tensor.vector_state_tensor, weights_tensor ) info_priority |= actor_info - info_priority["alpha"] = self.alpha.item() info = {"uniform": info_uniform, "priority": info_priority} diff --git a/cares_reinforcement_learning/algorithm/policy/LA3PTD3.py b/cares_reinforcement_learning/algorithm/policy/LA3PTD3.py index ae0c4025..7c431316 100644 --- a/cares_reinforcement_learning/algorithm/policy/LA3PTD3.py +++ b/cares_reinforcement_learning/algorithm/policy/LA3PTD3.py @@ -94,6 +94,8 @@ def _update_critic( # type: ignore[override] sample: Sample[SingleAgentExperience], uniform_sampling: bool, ) -> tuple[dict[str, Any], np.ndarray]: + info: dict[str, Any] = {} + # Convert into tensors using helper method ( observation_tensor, @@ -173,11 +175,57 @@ def _update_critic( # type: ignore[override] .flatten() ) - info = { - "critic_loss_one": critic_loss_one.item(), - "critic_loss_two": critic_loss_two.item(), - "critic_loss_total": critic_loss_total.item(), - } + with torch.no_grad(): + # --- TD3-style smoothing diagnostics --- + # Noise diagnostics + # What it tells you: + # - target_noise_abs_mean: effective smoothing magnitude. + # - target_noise_clip_frac high early: noise often clipped (clip too small or noise too large). 
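The two noise statistics can be reproduced outside the agent; a sketch where `policy_noise_scale` and `policy_noise_clip` stand in for the config values:

```python
import torch

policy_noise_scale = 0.2
policy_noise_clip = 0.5
batch_size, action_dim = 256, 6

target_noise = (torch.randn(batch_size, action_dim) * policy_noise_scale).clamp(
    -policy_noise_clip, policy_noise_clip
)

noise_abs_mean = target_noise.abs().mean().item()                             # effective smoothing magnitude
clip_frac = (target_noise.abs() >= policy_noise_clip).float().mean().item()   # fraction hitting the clip

print(noise_abs_mean, clip_frac)  # with scale=0.2 and clip=0.5 the clip is rarely active (~1%)
```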
+ target_noise_abs_mean = target_noise.abs().mean().item() + target_noise_clip_frac = ( + (target_noise.abs() >= self.policy_noise_clip).float().mean().item() + ) + info["target_noise_abs_mean"] = float(target_noise_abs_mean) + info["target_noise_clip_frac"] = float(target_noise_clip_frac) + + # --- Twin critic disagreement (stability/uncertainty) --- + # If this grows over training, critics are diverging / becoming inconsistent. + info["q1_mean"] = q_values_one.mean().item() + info["q2_mean"] = q_values_two.mean().item() + info["q_twin_gap_abs_mean"] = ( + (q_values_one - q_values_two).abs().mean().item() + ) + + # --- Target critics disagreement (target stability) --- + # Large/unstable gap here often means target critics are drifting or policy is visiting OOD actions. + info["target_q1_mean"] = target_q_values_one.mean().item() + info["target_q2_mean"] = target_q_values_two.mean().item() + info["target_q_twin_gap_abs_mean"] = ( + (target_q_values_one - target_q_values_two).abs().mean().item() + ) + + # --- Bellman target scale (reward scaling / discount sanity) --- + # If q_target drifts upward without reward improvement, suspect reward_scale, gamma, or instability. + info["q_target_mean"] = q_target.mean().item() + info["q_target_std"] = q_target.std().item() + + # --- TD error diagnostics (Bellman fit quality) --- + # td_abs_mean down over time is healthy; persistent growth/spikes often indicate critic instability. + td1 = q_values_one - q_target # signed + td2 = q_values_two - q_target # signed + + info["td1_mean"] = td1.mean().item() + info["td1_std"] = td1.std().item() + info["td1_abs_mean"] = td1.abs().mean().item() + + info["td2_mean"] = td2.mean().item() + info["td2_std"] = td2.std().item() + info["td2_abs_mean"] = td2.abs().mean().item() + + # --- Losses (optimization progress; less diagnostic than TD/twin gaps) --- + info["critic_loss_one"] = critic_loss_one.item() + info["critic_loss_two"] = critic_loss_two.item() + info["critic_loss_total"] = critic_loss_total.item() return info, priorities @@ -215,10 +263,10 @@ def train( device=self.device, ) - actor_loss = self._update_actor( + actor_info = self._update_actor( observation_tensor.vector_state_tensor, weights_tensor ) - info_uniform["actor_loss"] = actor_loss + info_uniform |= actor_info self._update_target_network() diff --git a/cares_reinforcement_learning/algorithm/policy/LAPSAC.py b/cares_reinforcement_learning/algorithm/policy/LAPSAC.py index 4ddc3dd7..3131114d 100644 --- a/cares_reinforcement_learning/algorithm/policy/LAPSAC.py +++ b/cares_reinforcement_learning/algorithm/policy/LAPSAC.py @@ -84,6 +84,8 @@ def _update_critic( dones: torch.Tensor, weights: torch.Tensor, ) -> tuple[dict[str, Any], np.ndarray]: + info: dict[str, Any] = {} + with torch.no_grad(): with hlp.evaluating(self.actor_net): next_actions, next_log_pi, _ = self.actor_net(next_states) @@ -121,16 +123,63 @@ def _update_critic( priorities = ( torch.max(td_error_one, td_error_two) + .clamp(min=self.min_priority) .pow(self.per_alpha) .cpu() .data.numpy() .flatten() ) - info = { - "critic_loss_one": huber_lose_one.item(), - "critic_loss_two": huber_lose_two.item(), - "critic_loss_total": critic_loss_total.item(), - } + with torch.no_grad(): + # --- Twin critic disagreement (stability/uncertainty) --- + # If this grows over training, critics are diverging / becoming inconsistent. 
+ info["q1_mean"] = q_values_one.mean().item() + info["q2_mean"] = q_values_two.mean().item() + info["q_twin_gap_abs_mean"] = ( + (q_values_one - q_values_two).abs().mean().item() + ) + + # --- Target critics disagreement (target stability) --- + # Large/unstable gap here often means target critics are drifting or policy is visiting OOD actions. + info["target_q1_mean"] = target_q_values_one.mean().item() + info["target_q2_mean"] = target_q_values_two.mean().item() + info["target_q_twin_gap_abs_mean"] = ( + (target_q_values_one - target_q_values_two).abs().mean().item() + ) + + # --- Soft target decomposition (SAC-specific) --- + # min_target_q_mean: the conservative bootstrap value from twin critics (pre-entropy) + # entropy_term_mean: magnitude of entropy regularization in the target (alpha * log_pi is usually negative) + # soft_target_value_mean: the exact term used inside the Bellman target before reward/discount + min_target_q = torch.minimum(target_q_values_one, target_q_values_two) + entropy_term = self.alpha * next_log_pi # typically negative + soft_target_value = min_target_q - entropy_term # == minQ - alpha*log_pi + + info["target_min_q_mean"] = min_target_q.mean().item() + info["entropy_term_mean"] = entropy_term.mean().item() + info["soft_target_value_mean"] = soft_target_value.mean().item() + + # --- Bellman target scale (reward scaling / discount sanity) --- + # If q_target drifts upward without reward improvement, suspect reward_scale, gamma, or instability. + info["q_target_mean"] = q_target.mean().item() + info["q_target_std"] = q_target.std().item() + + # --- TD error diagnostics (Bellman fit quality) --- + # td_abs_mean down over time is healthy; persistent growth/spikes often indicate critic instability. + td1 = q_values_one - q_target # signed + td2 = q_values_two - q_target # signed + + info["td1_mean"] = td1.mean().item() + info["td1_std"] = td1.std().item() + info["td1_abs_mean"] = td1.abs().mean().item() + + info["td2_mean"] = td2.mean().item() + info["td2_std"] = td2.std().item() + info["td2_abs_mean"] = td2.abs().mean().item() + + # --- Losses (optimization progress; less diagnostic than TD/twin gaps) --- + info["critic_loss_one"] = huber_lose_one.item() + info["critic_loss_two"] = huber_lose_two.item() + info["critic_loss_total"] = critic_loss_total.item() return info, priorities diff --git a/cares_reinforcement_learning/algorithm/policy/LAPTD3.py b/cares_reinforcement_learning/algorithm/policy/LAPTD3.py index 8f7ddd67..f4e6cb30 100644 --- a/cares_reinforcement_learning/algorithm/policy/LAPTD3.py +++ b/cares_reinforcement_learning/algorithm/policy/LAPTD3.py @@ -84,6 +84,8 @@ def _update_critic( dones: torch.Tensor, weights: torch.Tensor, ) -> tuple[dict[str, Any], np.ndarray]: + info: dict[str, Any] = {} + with torch.no_grad(): next_actions = self.target_actor_net(next_states) @@ -128,10 +130,56 @@ def _update_critic( .flatten() ) - info = { - "critic_loss_one": huber_lose_one.item(), - "critic_loss_two": huber_lose_two.item(), - "critic_loss_total": critic_loss_total.item(), - } + with torch.no_grad(): + # --- TD3-style smoothing diagnostics --- + # Noise diagnostics + # What it tells you: + # - target_noise_abs_mean: effective smoothing magnitude. + # - target_noise_clip_frac high early: noise often clipped (clip too small or noise too large). 
+ target_noise_abs_mean = target_noise.abs().mean().item() + target_noise_clip_frac = ( + (target_noise.abs() >= self.policy_noise_clip).float().mean().item() + ) + info["target_noise_abs_mean"] = float(target_noise_abs_mean) + info["target_noise_clip_frac"] = float(target_noise_clip_frac) + + # --- Twin critic disagreement (stability/uncertainty) --- + # If this grows over training, critics are diverging / becoming inconsistent. + info["q1_mean"] = q_values_one.mean().item() + info["q2_mean"] = q_values_two.mean().item() + info["q_twin_gap_abs_mean"] = ( + (q_values_one - q_values_two).abs().mean().item() + ) + + # --- Target critics disagreement (target stability) --- + # Large/unstable gap here often means target critics are drifting or policy is visiting OOD actions. + info["target_q1_mean"] = target_q_values_one.mean().item() + info["target_q2_mean"] = target_q_values_two.mean().item() + info["target_q_twin_gap_abs_mean"] = ( + (target_q_values_one - target_q_values_two).abs().mean().item() + ) + + # --- Bellman target scale (reward scaling / discount sanity) --- + # If q_target drifts upward without reward improvement, suspect reward_scale, gamma, or instability. + info["q_target_mean"] = q_target.mean().item() + info["q_target_std"] = q_target.std().item() + + # --- TD error diagnostics (Bellman fit quality) --- + # td_abs_mean down over time is healthy; persistent growth/spikes often indicate critic instability. + td1 = q_values_one - q_target # signed + td2 = q_values_two - q_target # signed + + info["td1_mean"] = td1.mean().item() + info["td1_std"] = td1.std().item() + info["td1_abs_mean"] = td1.abs().mean().item() + + info["td2_mean"] = td2.mean().item() + info["td2_std"] = td2.std().item() + info["td2_abs_mean"] = td2.abs().mean().item() + + # --- Losses (optimization progress; less diagnostic than TD/twin gaps) --- + info["critic_loss_one"] = huber_lose_one.item() + info["critic_loss_two"] = huber_lose_two.item() + info["critic_loss_total"] = critic_loss_total.item() return info, priorities diff --git a/cares_reinforcement_learning/algorithm/policy/MADDPG.py b/cares_reinforcement_learning/algorithm/policy/MADDPG.py index a58340c5..796f1bf9 100644 --- a/cares_reinforcement_learning/algorithm/policy/MADDPG.py +++ b/cares_reinforcement_learning/algorithm/policy/MADDPG.py @@ -39,6 +39,7 @@ import torch.nn.functional as F import cares_reinforcement_learning.memory.memory_sampler as memory_sampler +import cares_reinforcement_learning.util.helpers as hlp from cares_reinforcement_learning.algorithm.algorithm import MARLAlgorithm from cares_reinforcement_learning.algorithm.policy.DDPG import DDPG from cares_reinforcement_learning.memory.memory_buffer import MARLMemoryBuffer @@ -201,7 +202,7 @@ def _compute_adversarial_actions( actions: torch.Tensor, # (batch, n_agents, act_dim) global_states: torch.Tensor, # (batch, state_dim) critic: torch.nn.Module, - ) -> torch.Tensor: + ) -> tuple[torch.Tensor, torch.Tensor]: """ Return actions_adv where for j != agent_index: a_j_adv = a_j + eps_j @@ -209,7 +210,7 @@ def _compute_adversarial_actions( """ if self.m3_alpha == 0.0: # Degenerates to original MADDPG - return actions.detach() + return actions.detach(), torch.zeros_like(actions) # Clone and mark for gradient wrt actions only actions_for_grad = actions.detach().clone().requires_grad_(True) @@ -240,7 +241,7 @@ def _compute_adversarial_actions( eps = eps * mask actions_adv = actions_for_grad + eps - return actions_adv.detach() # no gradients through perturbation + return 
actions_adv.detach(), eps.detach() # no gradients through perturbation def _update_critic( self, @@ -253,10 +254,12 @@ def _update_critic( next_actions_tensor: torch.Tensor, # (B, N, act_dim) from target actors dones_i: torch.Tensor, ): + info: dict[str, Any] = {} + # --- Step 1: build (possibly adversarial) next joint actions --- if self.use_m3: # M3DDPG: perturb OTHER agents' target actions for agent i - next_actions_adv = self._compute_adversarial_actions( + next_actions_adv, eps = self._compute_adversarial_actions( agent_index=agent_index, actions=next_actions_tensor, # (B, N, act_dim) global_states=next_global_states, # (B, state_dim) @@ -289,7 +292,34 @@ def _update_critic( agent.critic_net_optimiser.step() - return {"critic_loss": loss.item()} + with torch.no_grad(): + + td = q_values - q_target + + # --- Value scale --- + info["q_mean"] = q_values.mean().item() + info["q_std"] = q_values.std(unbiased=False).item() + + info["q_target_mean"] = q_target.mean().item() + info["q_target_std"] = q_target.std(unbiased=False).item() + + # --- TD error diagnostics --- + td_abs = td.abs() + info["td_abs_mean"] = td_abs.mean().item() + info["td_abs_p95"] = td_abs.quantile(0.95).item() + info["td_abs_max"] = td_abs.max().item() + + # --- Signed bias --- + info["td_mean"] = td.mean().item() + + if self.use_m3: + info["critic_m3_eps_norm_mean"] = eps.norm(dim=-1).mean().item() + info["critic_m3_eps_norm_p95"] = eps.norm(dim=-1).quantile(0.95).item() + + # --- Critic loss --- + info["critic_loss"] = loss.item() + + return info def _update_actor( self, @@ -304,6 +334,7 @@ def _update_actor( - For j ≠ agent_index: use replay-buffer actions - For j == agent_index: use current actor output """ + info: dict[str, Any] = {} agent_ids = list(obs_tensors.keys()) batch_size = global_states.shape[0] @@ -318,16 +349,16 @@ def _update_actor( # Step 2: Replace ONLY agent_i action with differentiable action # --------------------------------------------------------- obs_i = obs_tensors[agent_ids[agent_index]] # (B, obs_dim_i) - pred_action_i = agent.actor_net(obs_i) # differentiable + actions_i = agent.actor_net(obs_i) # differentiable - actions_all[:, agent_index, :] = pred_action_i # keep others from buffer + actions_all[:, agent_index, :] = actions_i # keep others from buffer # --------------------------------------------------------- # Step 3a: Apply M3DDPG adversarial perturbation (if enabled) # --------------------------------------------------------- if self.use_m3: # compute perturbation on ALL actions (but this returns detached) - actions_adv = self._compute_adversarial_actions( + actions_adv, eps = self._compute_adversarial_actions( agent_index=agent_index, actions=actions_all, global_states=global_states, @@ -335,7 +366,7 @@ def _update_actor( ) # reinsert differentiable action for agent i - actions_adv[:, agent_index, :] = pred_action_i + actions_adv[:, agent_index, :] = actions_i actions_all = actions_adv # --------------------------------------------------------- @@ -352,18 +383,32 @@ def _update_actor( norm=self.ernie_norm, ) pred_action_adv = agent.actor_net(obs_i + delta_adv) - ernie_reg = (pred_action_adv - pred_action_i).pow(2).mean() + ernie_reg = (pred_action_adv - actions_i).pow(2).mean() # --------------------------------------------------------- # Step 4: Compute actor loss: -Q_i(x, a_1,...,a_i,...,a_N) # --------------------------------------------------------- joint_actions_flat = actions_all.reshape(batch_size, -1) - q_val = agent.critic_net(global_states, joint_actions_flat) + with 
hlp.evaluating(agent.critic_net): + actor_q_values = agent.critic_net(global_states, joint_actions_flat) # regularization as in TF code - reg = (pred_action_i**2).mean() * 1e-3 + reg = (actions_i**2).mean() * 1e-3 + + actor_loss = -actor_q_values.mean() + reg + (self.ernie_lambda * ernie_reg) - actor_loss = -q_val.mean() + reg + (self.ernie_lambda * ernie_reg) + dq_da = torch.autograd.grad( + outputs=-actor_q_values.mean(), # NOTE: uses Q-term only, excludes regularizers + inputs=actions_i, + retain_graph=True, + create_graph=False, + )[0] + with torch.no_grad(): + # - ~0 early: critic surface flat around actor actions (weak learning signal) + # - very large: critic surface sharp -> unstable / exploitative actor updates + info["dq_da_abs_mean"] = dq_da.abs().mean().item() + info["dq_da_norm_mean"] = dq_da.norm(dim=1).mean().item() + info["dq_da_norm_p95"] = dq_da.norm(dim=1).quantile(0.95).item() # --------------------------------------------------------- # Step 5: Backprop @@ -378,7 +423,35 @@ def _update_actor( agent.actor_net_optimiser.step() - return {"actor_loss": actor_loss.item()} + with torch.no_grad(): + # Policy Action Health (tanh policies in [-1, 1]) + # pi_action_saturation_frac: + # High values (>0.8 early) often mean the actor is slamming bounds, + # reducing effective gradient flow through tanh. + info["pi_action_mean"] = actions_i.mean().item() + info["pi_action_std"] = actions_i.std().item() + info["pi_action_abs_mean"] = actions_i.abs().mean().item() + info["pi_action_saturation_frac"] = ( + (actions_i.abs() > 0.95).float().mean().item() + ) + + # actor_q_mean should generally increase over training. + # actor_q_std large + unstable may indicate critic inconsistency. + info["actor_loss"] = actor_loss.item() + info["actor_q_mean"] = actor_q_values.mean().item() + info["actor_q_std"] = actor_q_values.std().item() + + # --- ERNIE diagnostics --- + if self.use_ernie: + info["ernie_reg"] = ernie_reg.item() + + if self.use_m3: + info["actor_m3_eps_norm_mean"] = eps.norm(dim=-1).mean().item() + info["actor_m3_eps_norm_p95"] = eps.norm(dim=-1).quantile(0.95).item() + + info["actor_loss"] = actor_loss.item() + + return info def train( self, @@ -394,6 +467,16 @@ def train( # --------------------------------------------------------- # Update each agent # --------------------------------------------------------- + + # Update action noise for exploration (decayed over training) + current_agent.action_noise = current_agent.action_noise_scheduler.get_value( + episode_context.training_step + ) + + info[f"action_noise_agent_{agent_index}"] = float( + current_agent.action_noise + ) + ( observation_tensor, actions_tensor, @@ -434,6 +517,54 @@ def train( # Flatten replay-buffer actions for this batch joint_actions = actions_tensor.reshape(sample_size, -1) + with torch.no_grad(): + # --------------------------------------------------------- + # Batch-level multi-agent diagnostics (this agent's draw) + # --------------------------------------------------------- + # Joint action volatility in the replay batch (all agents) + info[f"agent_{agent_index}_joint_action_mean"] = ( + actions_tensor.mean().item() + ) + info[f"agent_{agent_index}_joint_action_std"] = actions_tensor.std( + unbiased=False + ).item() + + # Per-agent action magnitude (detect frozen/saturated agent in replay) + # actions_tensor: (B, N, A) + per_agent_abs_mean = actions_tensor.abs().mean(dim=(0, 2)) # (N,) + per_agent_std = actions_tensor.std(dim=(0, 2), unbiased=False) # (N,) + + 
info[f"agent_{agent_index}_replay_action_abs_mean"] = ( + per_agent_abs_mean.mean().item() + ) + info[f"agent_{agent_index}_replay_action_abs_std_across_agents"] = ( + per_agent_abs_mean.std(unbiased=False).item() + ) + info[f"agent_{agent_index}_replay_action_std_mean"] = ( + per_agent_std.mean().item() + ) + + # Coordination proxy: how aligned are agents' actions? (cheap) + # Cos similarity between agents' action vectors per sample, averaged. + # Flatten each agent action: (B, N, A) -> (B, N, A) + a = actions_tensor # (B,N,A) + a_norm = a / a.norm(dim=2, keepdim=True).clamp_min(1e-6) + + # pairwise cosine for all agent pairs + cos = torch.einsum("bna,bma->bnm", a_norm, a_norm) # (B,N,N) + # ignore diagonal + n = cos.shape[1] + mask = ~torch.eye(n, device=cos.device, dtype=torch.bool) + info[f"agent_{agent_index}_replay_action_cos_mean"] = ( + cos[:, mask].mean().item() + ) + + # Reward/done scale sanity for this agent (helps catch mis-scaling) + info[f"agent_{agent_index}_reward_mean"] = rewards_tensor.mean().item() + info[f"agent_{agent_index}_done_frac"] = ( + dones_tensor.float().mean().item() + ) + # --------------------------------------------------------- # Critic update for this agent # --------------------------------------------------------- @@ -450,6 +581,7 @@ def train( next_actions_tensor=next_actions_tensor, dones_i=dones_i, ) + info.update({f"agent_{agent_index}_{k}": v for k, v in critic_info.items()}) # --------------------------------------------------------- # Actor update @@ -461,10 +593,18 @@ def train( global_states=states_tensors, actions_tensor=actions_tensor, ) - - info[f"critic_loss_agent_{agent_index}"] = critic_info["critic_loss"] - info[f"actor_loss_agent_{agent_index}"] = actor_info["actor_loss"] - + info.update({f"agent_{agent_index}_{k}": v for k, v in actor_info.items()}) + + # --- Cross-agent diagnostics --- + metrics = list(critic_info.keys()) + list(actor_info.keys()) + for metric in metrics: + values = [info[f"agent_{i}_{metric}"] for i in range(self.num_agents)] + info[f"mean_{metric}"] = float(np.mean(values)) + info[f"std_{metric}"] = float(np.std(values)) + info[f"max_{metric}"] = float(np.max(values)) + info[f"min_{metric}"] = float(np.min(values)) + + # Update Target networks with soft update for current_agent in self.agent_networks: current_agent.update_target_networks() diff --git a/cares_reinforcement_learning/algorithm/policy/MAPERSAC.py b/cares_reinforcement_learning/algorithm/policy/MAPERSAC.py index a470bc65..c37957ea 100644 --- a/cares_reinforcement_learning/algorithm/policy/MAPERSAC.py +++ b/cares_reinforcement_learning/algorithm/policy/MAPERSAC.py @@ -122,20 +122,22 @@ def _update_critic( dones: torch.Tensor, weights: torch.Tensor, ) -> tuple[dict[str, Any], np.ndarray]: + info: dict[str, Any] = {} + # Get current Q estimates output_one, output_two = self.critic_net(states.detach(), actions.detach()) - q_value_one, predicted_reward_one, next_states_one = self._split_output( + q_values_one, predicted_rewards_one, next_states_one = self._split_output( output_one ) - q_value_two, predicted_reward_two, next_states_two = self._split_output( + q_values_two, predicted_rewards_two, next_states_two = self._split_output( output_two ) diff_reward_one = 0.5 * torch.pow( - predicted_reward_one.reshape(-1, 1) - rewards.reshape(-1, 1), 2.0 + predicted_rewards_one.reshape(-1, 1) - rewards.reshape(-1, 1), 2.0 ).reshape(-1, 1) diff_reward_two = 0.5 * torch.pow( - predicted_reward_two.reshape(-1, 1) - rewards.reshape(-1, 1), 2.0 + 
predicted_rewards_two.reshape(-1, 1) - rewards.reshape(-1, 1), 2.0 ).reshape(-1, 1) diff_next_states_one = 0.5 * torch.mean( @@ -160,28 +162,32 @@ def _update_critic( with hlp.evaluating(self.actor_net): next_actions, next_log_pi, _ = self.actor_net(next_states) - target_q_values_one, target_q_values_two = self.target_critic_net( + target_output_one, target_output_two = self.target_critic_net( next_states, next_actions ) - next_values_one, _, _ = self._split_output(target_q_values_one) - next_values_two, _, _ = self._split_output(target_q_values_two) - min_next_target = torch.minimum(next_values_one, next_values_two).reshape( - -1, 1 - ) + target_q_values_one, _, _ = self._split_output(target_output_one) + target_q_values_two, _, _ = self._split_output(target_output_two) + min_next_target = torch.minimum( + target_q_values_one, target_q_values_two + ).reshape(-1, 1) target_q_values = min_next_target - self.alpha * next_log_pi predicted_rewards = ( ( - predicted_reward_one.reshape(-1, 1) - + predicted_reward_two.reshape(-1, 1) + predicted_rewards_one.reshape(-1, 1) + + predicted_rewards_two.reshape(-1, 1) ) / 2 ).reshape(-1, 1) q_target = predicted_rewards + self.gamma * (1 - dones) * target_q_values - diff_td_one = F.mse_loss(q_value_one.reshape(-1, 1), q_target, reduction="none") - diff_td_two = F.mse_loss(q_value_two.reshape(-1, 1), q_target, reduction="none") + diff_td_one = F.mse_loss( + q_values_one.reshape(-1, 1), q_target, reduction="none" + ) + diff_td_two = F.mse_loss( + q_values_two.reshape(-1, 1), q_target, reduction="none" + ) critic_loss_one = ( diff_td_one @@ -242,23 +248,157 @@ def _update_critic( # Update Scales if self.learn_counter == 1: self.scale_r = np.mean(numpy_td_mean) / ( - np.mean(diff_next_state_mean_numpy) + np.mean(diff_reward_mean_numpy) + 1e-12 ) self.scale_s = np.mean(numpy_td_mean) / ( - np.mean(diff_next_state_mean_numpy) + np.mean(diff_next_state_mean_numpy) + 1e-12 ) - info = { - "critic_loss_one": critic_loss_one.item(), - "critic_loss_two": critic_loss_two.item(), - "critic_loss_total": critic_loss_total.item(), - } + with torch.no_grad(): + + # --- Component losses (unweighted, per-sample means) --- + # These tell you whether the model heads are actually learning and on what scale. + info["td_mse_one_mean"] = diff_td_one.mean().item() + info["td_mse_two_mean"] = diff_td_two.mean().item() + info["td_mse_mean"] = ( + 0.5 * (diff_td_one.mean() + diff_td_two.mean()) + ).item() + + info["reward_pred_mse_one_mean"] = diff_reward_one.mean().item() + info["reward_pred_mse_two_mean"] = diff_reward_two.mean().item() + info["reward_pred_mse_mean"] = ( + 0.5 * (diff_reward_one.mean() + diff_reward_two.mean()) + ).item() + + info["next_state_pred_mse_one_mean"] = diff_next_states_one.mean().item() + info["next_state_pred_mse_two_mean"] = diff_next_states_two.mean().item() + info["next_state_pred_mse_mean"] = ( + 0.5 * (diff_next_states_one.mean() + diff_next_states_two.mean()) + ).item() + + # --- Scales (very important to log; they define the tradeoff) --- + info["scale_r"] = float(self.scale_r) + info["scale_s"] = float(self.scale_s) + + # --- Weighted contribution ratios (are aux losses dominating TD?) --- + # These approximate how much each term contributes inside the critic loss before IS weighting. 
+ td_term_mean = (0.5 * (diff_td_one.mean() + diff_td_two.mean())).item() + r_term_mean = ( + 0.5 * (diff_reward_one.mean() + diff_reward_two.mean()) + ).item() + s_term_mean = ( + 0.5 * (diff_next_states_one.mean() + diff_next_states_two.mean()) + ).item() + + info["loss_term_td_mean"] = float(td_term_mean) + info["loss_term_r_scaled_mean"] = float(self.scale_r * r_term_mean) + info["loss_term_s_scaled_mean"] = float(self.scale_s * s_term_mean) + + den = ( + td_term_mean + + (self.scale_r * r_term_mean) + + (self.scale_s * s_term_mean) + + 1e-12 + ) + info["loss_td_frac"] = float(td_term_mean / den) + info["loss_r_frac"] = float((self.scale_r * r_term_mean) / den) + info["loss_s_frac"] = float((self.scale_s * s_term_mean) / den) + + # --- Twin critic disagreement (stability/uncertainty) --- + # If this grows over training, critics are diverging / becoming inconsistent. + info["q1_mean"] = q_values_one.mean().item() + info["q2_mean"] = q_values_two.mean().item() + info["q_twin_gap_abs_mean"] = ( + (q_values_one - q_values_two).abs().mean().item() + ) + + # --- Target critics disagreement (target stability) --- + # Large/unstable gap here often means target critics are drifting or policy is visiting OOD actions. + info["target_q1_mean"] = target_q_values_one.mean().item() + info["target_q2_mean"] = target_q_values_two.mean().item() + info["target_q_twin_gap_abs_mean"] = ( + (target_q_values_one - target_q_values_two).abs().mean().item() + ) + + # --- Soft target decomposition (SAC-specific) --- + # min_target_q_mean: the conservative bootstrap value from twin critics (pre-entropy) + # entropy_term_mean: magnitude of entropy regularization in the target (alpha * log_pi is usually negative) + # soft_target_value_mean: the exact term used inside the Bellman target before reward/discount + min_target_q = torch.minimum( + target_q_values_one, target_q_values_two + ).reshape(-1, 1) + + # alpha_log_pi is typically negative; entropy_bonus is typically positive + alpha_log_pi = (self.alpha * next_log_pi).reshape(-1, 1) + entropy_bonus = (-self.alpha * next_log_pi).reshape(-1, 1) + + soft_target_value = min_target_q + entropy_bonus # == minQ - alpha*log_pi + + info["target_min_q_mean"] = min_target_q.mean().item() + info["alpha_log_pi_mean"] = alpha_log_pi.mean().item() + info["entropy_bonus_mean"] = entropy_bonus.mean().item() + info["soft_target_value_mean"] = soft_target_value.mean().item() + + # --- Predicted reward bias (MaPER/MAPERSAC-specific) --- + # Since the Bellman target uses predicted reward instead of env reward, + # any systematic bias here directly shifts Q-targets and can cause + # value inflation or suppression. + predicted_reward_mean = predicted_rewards.mean().item() + env_reward_mean = rewards.mean().item() + + info["predicted_reward_mean"] = predicted_reward_mean + info["env_reward_mean"] = env_reward_mean + info["predicted_reward_bias"] = predicted_reward_mean - env_reward_mean + + # --- Bellman target scale (reward scaling / discount sanity) --- + # If q_target drifts upward without reward improvement, suspect reward_scale, gamma, or instability. + info["q_target_mean"] = q_target.mean().item() + info["q_target_std"] = q_target.std().item() + + # --- TD error diagnostics (Bellman fit quality) --- + # td_abs_mean down over time is healthy; persistent growth/spikes often indicate critic instability. 
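+            # Here q_target = r_hat + gamma * (1 - done) * (min(Q1', Q2') - alpha * log_pi'),
+            # where r_hat is the mean of the twin predicted-reward heads, so a persistently
+            # positive td1/td2 mean indicates overestimation relative to this soft, model-predicted target.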
+ td1 = q_values_one - q_target # signed + td2 = q_values_two - q_target # signed + + info["td1_mean"] = td1.mean().item() + info["td1_std"] = td1.std().item() + info["td1_abs_mean"] = td1.abs().mean().item() + + info["td2_mean"] = td2.mean().item() + info["td2_std"] = td2.std().item() + info["td2_abs_mean"] = td2.abs().mean().item() + + # ---Priority diagnostics (raw + final PER priorities) --- + prio_td = diff_td_mean.squeeze(1) # (B,) + prio_r = self.scale_r * diff_reward_mean.squeeze(1) # (B,) + prio_s = self.scale_s * diff_next_state_mean.squeeze(1) # (B,) + prio_raw = prio_td + prio_r + prio_s # (B,) + + info["priority_raw_mean"] = prio_raw.mean().item() + info["priority_raw_p95"] = prio_raw.quantile(0.95).item() + + prio_den = prio_raw.mean().item() + 1e-12 + info["priority_td_frac"] = float(prio_td.mean().item() / prio_den) + info["priority_r_frac"] = float(prio_r.mean().item() / prio_den) + info["priority_s_frac"] = float(prio_s.mean().item() / prio_den) + + prio_post = prio_raw.clamp(min=self.min_priority).pow(self.per_alpha) + info["priority_mean"] = prio_post.mean().item() + info["priority_p95"] = prio_post.quantile(0.95).item() + info["priority_max"] = prio_post.max().item() + + # --- Losses (optimization progress; less diagnostic than TD/twin gaps) --- + info["critic_loss_one"] = critic_loss_one.item() + info["critic_loss_two"] = critic_loss_two.item() + info["critic_loss_total"] = critic_loss_total.item() return info, priorities def _update_actor_alpha( self, states: torch.Tensor, weights: torch.Tensor ) -> dict[str, Any]: + info: dict[str, Any] = {} + pi, log_pi, _ = self.actor_net(states) with hlp.evaluating(self.critic_net): @@ -272,6 +412,31 @@ def _update_actor_alpha( (torch.exp(self.log_alpha).detach() * log_pi - min_qf_pi) * weights ) + # --------------------------------------------------------- + # Stochastic Policy Gradient Strength (∇a [α log π(a|s) − Q(s,a)]) + # --------------------------------------------------------- + # Measures how steep the entropy-regularized critic objective is + # w.r.t. the sampled policy actions. + # + # ~0 early -> critic surface and entropy term nearly flat; + # actor receives weak learning signal. + # + # Very large -> critic or entropy term is very sharp around policy + # actions; can lead to unstable or overly aggressive + # actor updates. + dq_da = torch.autograd.grad( + outputs=actor_loss, + inputs=pi, + retain_graph=True, + create_graph=False, + allow_unused=False, + )[0] + + with torch.no_grad(): + info["dq_da_abs_mean"] = dq_da.abs().mean().item() + info["dq_da_norm_mean"] = dq_da.norm(dim=1).mean().item() + info["dq_da_norm_p95"] = dq_da.norm(dim=1).quantile(0.95).item() + self.actor_net_optimiser.zero_grad() actor_loss.backward() self.actor_net_optimiser.step() @@ -285,9 +450,36 @@ def _update_actor_alpha( alpha_loss.backward() self.log_alpha_optimizer.step() - info = { - "actor_loss": actor_loss.item(), - "alpha_loss": alpha_loss.item(), - } + with torch.no_grad(): + # --- Policy entropy diagnostics (exploration health) --- + # log_pi more negative -> higher entropy (more stochastic). Less negative -> lower entropy (more deterministic). + info["log_pi_mean"] = log_pi.mean().item() + info["log_pi_std"] = log_pi.std().item() + + # --- Action magnitude/saturation (tanh policies) --- + # High saturation fraction can indicate the policy is slamming bounds; may reduce effective gradients. 
+ info["pi_action_abs_mean"] = pi.abs().mean().item() + info["pi_action_std"] = pi.std().item() + info["pi_action_saturation_frac"] = (pi.abs() > 0.95).float().mean().item() + + # --- On-policy critic signal --- + # min_qf_pi_mean should generally increase as the policy improves (higher value actions under the policy). + info["min_qf_pi_mean"] = min_qf_pi.mean().item() + + # --- Twin critics disagreement at policy actions (more relevant than replay actions) --- + # Large gap here means critics disagree on what the current policy is doing (can destabilize actor updates). + info["qf_pi_gap_abs_mean"] = (qf_pi_one - qf_pi_two).abs().mean().item() + + # --- Entropy gap (alpha tuning health) --- + # entropy_gap ~ 0 means entropy matches target. + # > 0: entropy too low -> alpha should increase; < 0: entropy too high -> alpha should decrease. + entropy_gap = -(log_pi + self.target_entropy) + info["entropy_gap_mean"] = entropy_gap.mean().item() + + # --- Losses and temperature --- + info["actor_loss"] = actor_loss.item() + info["alpha_loss"] = alpha_loss.item() + info["alpha"] = self.alpha.item() + info["log_alpha"] = self.log_alpha.item() return info diff --git a/cares_reinforcement_learning/algorithm/policy/MAPERTD3.py b/cares_reinforcement_learning/algorithm/policy/MAPERTD3.py index 1dd82592..ad312b3c 100644 --- a/cares_reinforcement_learning/algorithm/policy/MAPERTD3.py +++ b/cares_reinforcement_learning/algorithm/policy/MAPERTD3.py @@ -127,22 +127,24 @@ def _update_critic( dones: torch.Tensor, weights: torch.Tensor, ) -> tuple[dict[str, Any], np.ndarray]: + info: dict[str, Any] = {} + # Get current Q estimates output_one, output_two = self.critic_net(states, actions) - q_value_one, predicted_reward_one, next_states_one = self._split_output( + q_values_one, predicted_rewards_one, next_states_one = self._split_output( output_one ) - q_value_two, predicted_reward_two, next_states_two = self._split_output( + q_values_two, predicted_rewards_two, next_states_two = self._split_output( output_two ) # Difference in rewards diff_reward_one = 0.5 * torch.pow( - predicted_reward_one.reshape(-1, 1) - rewards.reshape(-1, 1), 2.0 + predicted_rewards_one.reshape(-1, 1) - rewards.reshape(-1, 1), 2.0 ).reshape(-1, 1) diff_reward_two = 0.5 * torch.pow( - predicted_reward_two.reshape(-1, 1) - rewards.reshape(-1, 1), 2.0 + predicted_rewards_two.reshape(-1, 1) - rewards.reshape(-1, 1), 2.0 ).reshape(-1, 1) # Difference in next states @@ -171,28 +173,32 @@ def _update_critic( next_actions = next_actions + target_noise next_actions = torch.clamp(next_actions, min=-1, max=1) - target_q_values_one, target_q_values_two = self.target_critic_net( + target_output_one, target_output_two = self.target_critic_net( next_states, next_actions ) - next_values_one, _, _ = self._split_output(target_q_values_one) - next_values_two, _, _ = self._split_output(target_q_values_two) + target_q_values_one, _, _ = self._split_output(target_output_one) + target_q_values_two, _, _ = self._split_output(target_output_two) target_q_values = torch.minimum( - next_values_one.reshape(-1, 1), next_values_two.reshape(-1, 1) + target_q_values_one.reshape(-1, 1), target_q_values_two.reshape(-1, 1) ) predicted_rewards = ( ( - predicted_reward_one.reshape(-1, 1) - + predicted_reward_two.reshape(-1, 1) + predicted_rewards_one.reshape(-1, 1) + + predicted_rewards_two.reshape(-1, 1) ) / 2 ).reshape(-1, 1) q_target = predicted_rewards + self.gamma * (1 - dones) * target_q_values - diff_td_one = F.mse_loss(q_value_one.reshape(-1, 1), q_target, 
reduction="none") - diff_td_two = F.mse_loss(q_value_two.reshape(-1, 1), q_target, reduction="none") + diff_td_one = F.mse_loss( + q_values_one.reshape(-1, 1), q_target, reduction="none" + ) + diff_td_two = F.mse_loss( + q_values_two.reshape(-1, 1), q_target, reduction="none" + ) critic_loss_one = ( diff_td_one @@ -255,23 +261,149 @@ def _update_critic( # Update Scales if self.learn_counter == 1: self.scale_r = np.mean(numpy_td_mean) / ( - np.mean(diff_next_state_mean_numpy) + np.mean(diff_reward_mean_numpy) + 1e-12 ) self.scale_s = np.mean(numpy_td_mean) / ( - np.mean(diff_next_state_mean_numpy) + np.mean(diff_next_state_mean_numpy) + 1e-12 + ) + + with torch.no_grad(): + # --- TD3-style smoothing diagnostics --- + # Noise diagnostics + # What it tells you: + # - target_noise_abs_mean: effective smoothing magnitude. + # - target_noise_clip_frac high early: noise often clipped (clip too small or noise too large). + target_noise_abs_mean = target_noise.abs().mean().item() + target_noise_clip_frac = ( + (target_noise.abs() >= self.policy_noise_clip).float().mean().item() + ) + info["target_noise_abs_mean"] = float(target_noise_abs_mean) + info["target_noise_clip_frac"] = float(target_noise_clip_frac) + + # --- Component losses (unweighted, per-sample means) --- + # These tell you whether the model heads are actually learning and on what scale. + info["td_mse_one_mean"] = diff_td_one.mean().item() + info["td_mse_two_mean"] = diff_td_two.mean().item() + info["td_mse_mean"] = ( + 0.5 * (diff_td_one.mean() + diff_td_two.mean()) + ).item() + + info["reward_pred_mse_one_mean"] = diff_reward_one.mean().item() + info["reward_pred_mse_two_mean"] = diff_reward_two.mean().item() + info["reward_pred_mse_mean"] = ( + 0.5 * (diff_reward_one.mean() + diff_reward_two.mean()) + ).item() + + info["next_state_pred_mse_one_mean"] = diff_next_states_one.mean().item() + info["next_state_pred_mse_two_mean"] = diff_next_states_two.mean().item() + info["next_state_pred_mse_mean"] = ( + 0.5 * (diff_next_states_one.mean() + diff_next_states_two.mean()) + ).item() + + # --- Scales (very important to log; they define the tradeoff) --- + info["scale_r"] = float(self.scale_r) + info["scale_s"] = float(self.scale_s) + + # --- Weighted contribution ratios (are aux losses dominating TD?) --- + # These approximate how much each term contributes inside the critic loss before IS weighting. + td_term_mean = (0.5 * (diff_td_one.mean() + diff_td_two.mean())).item() + r_term_mean = ( + 0.5 * (diff_reward_one.mean() + diff_reward_two.mean()) + ).item() + s_term_mean = ( + 0.5 * (diff_next_states_one.mean() + diff_next_states_two.mean()) + ).item() + + info["loss_term_td_mean"] = float(td_term_mean) + info["loss_term_r_scaled_mean"] = float(self.scale_r * r_term_mean) + info["loss_term_s_scaled_mean"] = float(self.scale_s * s_term_mean) + + den = ( + td_term_mean + + (self.scale_r * r_term_mean) + + (self.scale_s * s_term_mean) + + 1e-12 + ) + info["loss_td_frac"] = float(td_term_mean / den) + info["loss_r_frac"] = float((self.scale_r * r_term_mean) / den) + info["loss_s_frac"] = float((self.scale_s * s_term_mean) / den) + + # --- Twin critic disagreement (stability/uncertainty) --- + # If this grows over training, critics are diverging / becoming inconsistent. 
+ info["q1_mean"] = q_values_one.mean().item() + info["q2_mean"] = q_values_two.mean().item() + info["q_twin_gap_abs_mean"] = ( + (q_values_one - q_values_two).abs().mean().item() ) - info = { - "critic_loss_one": critic_loss_one.item(), - "critic_loss_two": critic_loss_two.item(), - "critic_loss_total": critic_loss_total.item(), - } + # --- Target critics disagreement (target stability) --- + # Large/unstable gap here often means target critics are drifting or policy is visiting OOD actions. + info["target_q1_mean"] = target_q_values_one.mean().item() + info["target_q2_mean"] = target_q_values_two.mean().item() + info["target_q_twin_gap_abs_mean"] = ( + (target_q_values_one - target_q_values_two).abs().mean().item() + ) + + # --- Predicted reward bias (MaPER/MAPERTD3-specific) --- + # Since the Bellman target uses predicted reward instead of env reward, + # any systematic bias here directly shifts Q-targets and can cause + # value inflation or suppression. + predicted_reward_mean = predicted_rewards.mean().item() + env_reward_mean = rewards.mean().item() + + info["predicted_reward_mean"] = predicted_reward_mean + info["env_reward_mean"] = env_reward_mean + info["predicted_reward_bias"] = predicted_reward_mean - env_reward_mean + + # --- Bellman target scale (reward scaling / discount sanity) --- + # If q_target drifts upward without reward improvement, suspect reward_scale, gamma, or instability. + info["q_target_mean"] = q_target.mean().item() + info["q_target_std"] = q_target.std().item() + + # --- TD error diagnostics (Bellman fit quality) --- + # td_abs_mean down over time is healthy; persistent growth/spikes often indicate critic instability. + td1 = q_values_one - q_target # signed + td2 = q_values_two - q_target # signed + + info["td1_mean"] = td1.mean().item() + info["td1_std"] = td1.std().item() + info["td1_abs_mean"] = td1.abs().mean().item() + + info["td2_mean"] = td2.mean().item() + info["td2_std"] = td2.std().item() + info["td2_abs_mean"] = td2.abs().mean().item() + + # ---Priority diagnostics (raw + final PER priorities) --- + prio_td = diff_td_mean.squeeze(1) # (B,) + prio_r = self.scale_r * diff_reward_mean.squeeze(1) # (B,) + prio_s = self.scale_s * diff_next_state_mean.squeeze(1) # (B,) + prio_raw = prio_td + prio_r + prio_s # (B,) + + info["priority_raw_mean"] = prio_raw.mean().item() + info["priority_raw_p95"] = prio_raw.quantile(0.95).item() + + prio_den = prio_raw.mean().item() + 1e-12 + info["priority_td_frac"] = float(prio_td.mean().item() / prio_den) + info["priority_r_frac"] = float(prio_r.mean().item() / prio_den) + info["priority_s_frac"] = float(prio_s.mean().item() / prio_den) + + prio_post = prio_raw.clamp(min=self.min_priority).pow(self.per_alpha) + info["priority_mean"] = prio_post.mean().item() + info["priority_p95"] = prio_post.quantile(0.95).item() + info["priority_max"] = prio_post.max().item() + + # --- Losses (optimization progress; less diagnostic than TD/twin gaps) --- + info["critic_loss_one"] = critic_loss_one.item() + info["critic_loss_two"] = critic_loss_two.item() + info["critic_loss_total"] = critic_loss_total.item() return info, priorities def _update_actor( self, states: torch.Tensor, weights: torch.Tensor ) -> dict[str, Any]: + info: dict[str, Any] = {} + actions = self.actor_net(states.detach()) with hlp.evaluating(self.critic_net): @@ -280,16 +412,50 @@ def _update_actor( actor_q_one, _, _ = self._split_output(output_one) actor_q_two, _, _ = self._split_output(output_two) - actor_val = torch.minimum(actor_q_one, actor_q_two) - - 
actor_loss = -(actor_val * weights).mean() + actor_q_values = torch.minimum(actor_q_one, actor_q_two) + + actor_loss = -(actor_q_values * weights).mean() + + # --------------------------------------------------------- + # Deterministic Policy Gradient Strength (∇a Q(s,a)) + # --------------------------------------------------------- + # Measures how steep the critic surface is w.r.t. actions. + # ~0 early -> critic flat, actor receives no learning signal. + # Very large -> critic overly sharp, can cause unstable actor updates. + dq_da = torch.autograd.grad( + outputs=actor_loss, + inputs=actions, + retain_graph=True, # because we do backward(actor_loss) next + create_graph=False, # diagnostic only + allow_unused=False, + )[0] + with torch.no_grad(): + # - ~0 early: critic surface flat around actor actions (weak learning signal) + # - very large: critic surface sharp -> unstable / exploitative actor updates + info["dq_da_abs_mean"] = dq_da.abs().mean().item() + info["dq_da_norm_mean"] = dq_da.norm(dim=1).mean().item() + info["dq_da_norm_p95"] = dq_da.norm(dim=1).quantile(0.95).item() # Optimize the actor self.actor_net_optimiser.zero_grad() actor_loss.backward() self.actor_net_optimiser.step() - info = { - "actor_loss": actor_loss.item(), - } + with torch.no_grad(): + # Policy Action Health (tanh policies in [-1, 1]) + # pi_action_saturation_frac: + # High values (>0.8 early) often mean the actor is slamming bounds, + # reducing effective gradient flow through tanh. + info["pi_action_mean"] = actions.mean().item() + info["pi_action_std"] = actions.std().item() + info["pi_action_abs_mean"] = actions.abs().mean().item() + info["pi_action_saturation_frac"] = ( + (actions.abs() > 0.95).float().mean().item() + ) + + # actor_q_mean should generally increase over training. + # actor_q_std large + unstable may indicate critic inconsistency. 
+ info["actor_loss"] = actor_loss.item() + info["actor_q_mean"] = actor_q_values.mean().item() + info["actor_q_std"] = actor_q_values.std().item() return info diff --git a/cares_reinforcement_learning/algorithm/policy/MAPPO.py b/cares_reinforcement_learning/algorithm/policy/MAPPO.py index bcf8fcec..882bfc7a 100644 --- a/cares_reinforcement_learning/algorithm/policy/MAPPO.py +++ b/cares_reinforcement_learning/algorithm/policy/MAPPO.py @@ -71,7 +71,7 @@ SARLObservation, ) from cares_reinforcement_learning.util.configurations import MAPPOConfig -from cares_reinforcement_learning.util.helpers import EpsilonScheduler +from cares_reinforcement_learning.util.helpers import ExponentialScheduler class MAPPO(MARLAlgorithm[list[np.ndarray]]): @@ -90,13 +90,13 @@ def __init__( self.minibatch_size = config.minibatch_size self.updates_per_iteration = config.updates_per_iteration - self.epsilon_scheduler = EpsilonScheduler( - start_epsilon=config.entropy_start, - end_epsilon=config.entropy_end, + self.entropy_scheduler = ExponentialScheduler( + start_value=config.entropy_start, + end_value=config.entropy_end, decay_steps=config.entropy_decay, ) # initial entropy coefficient - self.entropy_coef = self.epsilon_scheduler.get_epsilon(0) + self.entropy_coef = self.entropy_scheduler.get_value(0) self.target_kl = config.target_kl @@ -150,12 +150,12 @@ def train( info: dict[str, Any] = {} - self.entropy_coef = self.epsilon_scheduler.get_epsilon( + self.entropy_coef = self.entropy_scheduler.get_value( episode_context.training_step ) # --------------------------------------------------------- - # Sample ONCE for all agents (recommended for TD3/SAC) + # Sample ONCE for all agents (recommended for PPO/TD3/SAC) # Shared minibatch: We draw one minibatch per training iteration and reuse it across agent updates. # This preserves an unbiased estimator of each update while reducing sampling-induced variance and # keeping joint transitions consistent for centralized critics. 
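For reference, a minimal sketch of the scheduler behaviour assumed here: the constructor arguments and the get_value(step) API match the diff above, but the exact decay law of the real ExponentialScheduler in util/helpers.py may differ.

class ExponentialSchedulerSketch:
    """Illustrative stand-in only: decays geometrically from start_value to end_value."""

    def __init__(self, start_value: float, end_value: float, decay_steps: int):
        self.start_value = start_value
        self.end_value = end_value
        self.decay_steps = decay_steps

    def get_value(self, step: int) -> float:
        # Geometric interpolation, clamped to end_value once decay_steps is reached.
        if step >= self.decay_steps:
            return self.end_value
        frac = step / self.decay_steps
        return self.start_value * (self.end_value / self.start_value) ** frac

# e.g. an entropy coefficient decaying from 1e-2 to 1e-3 over 100k training steps:
sched = ExponentialSchedulerSketch(start_value=1e-2, end_value=1e-3, decay_steps=100_000)
print(sched.get_value(0), sched.get_value(50_000), sched.get_value(200_000))
# -> 0.01, ~0.00316 (geometric midpoint), 0.001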
@@ -235,26 +235,7 @@ def train( critic_loss_sum = 0.0 num_critic_mb = 0 - agent_sums = [ - { - k: 0.0 - for k in [ - "actor_loss", - "entropy", - "approx_kl", - "clip_frac", - "ratio_mean", - "ratio_std", - "action_sat_rate", - "u_abs_mean", - "u_abs_max", - "log_ratio_mean", - "log_ratio_std", - "log_ratio_max_abs", - ] - } - for _ in range(self.num_agents) - ] + agent_actor_sums: list[dict[str, float]] = [{} for _ in range(self.num_agents)] agent_max_kl = [0.0 for _ in range(self.num_agents)] # minibatches that actually updated (for averaging stats) agent_updates = [0 for _ in range(self.num_agents)] @@ -289,8 +270,10 @@ def train( # Accumulate only if update happened if not kl_early_stop: agent_updates[agent_idx] += 1 - for k in agent_sums[agent_idx].keys(): - agent_sums[agent_idx][k] += float(actor_info[k]) + for k in actor_info.keys(): + if k not in agent_actor_sums[agent_idx]: + agent_actor_sums[agent_idx][k] = 0.0 + agent_actor_sums[agent_idx][k] += float(actor_info[k]) # Track max KL seen regardless (if target_kl is enabled, approx_kl is meaningful) if self.target_kl is not None: @@ -325,41 +308,61 @@ def train( td_err = returns_all - values # [T, N] for i in range(self.num_agents): - info[f"agent{i}_adv_mean"] = float(advantages_all[:, i].mean().item()) - info[f"agent{i}_adv_std"] = float( + info[f"agent_{i}_adv_mean"] = float(advantages_all[:, i].mean().item()) + info[f"agent_{i}_adv_std"] = float( advantages_all[:, i].std(unbiased=False).item() ) - info[f"agent{i}_ret_mean"] = float(returns_all[:, i].mean().item()) - info[f"agent{i}_ret_std"] = float( + info[f"agent_{i}_ret_mean"] = float(returns_all[:, i].mean().item()) + info[f"agent_{i}_ret_std"] = float( returns_all[:, i].std(unbiased=False).item() ) - info[f"agent{i}_v_mean"] = float(values[:, i].mean().item()) - info[f"agent{i}_v_std"] = float(values[:, i].std(unbiased=False).item()) + info[f"agent_{i}_v_mean"] = float(values[:, i].mean().item()) + info[f"agent_{i}_v_std"] = float( + values[:, i].std(unbiased=False).item() + ) - info[f"agent{i}_td_mean"] = float(td_err[:, i].mean().item()) - info[f"agent{i}_td_std"] = float( + info[f"agent_{i}_td_mean"] = float(td_err[:, i].mean().item()) + info[f"agent_{i}_td_std"] = float( td_err[:, i].std(unbiased=False).item() ) - info[f"agent{i}_td_mae"] = float(td_err[:, i].abs().mean().item()) + info[f"agent_{i}_td_mae"] = float(td_err[:, i].abs().mean().item()) y = returns_all[:, i] yhat = values[:, i] var_y = torch.var(y, unbiased=False) ev = 1.0 - torch.var(y - yhat, unbiased=False) / (var_y + 1e-8) - info[f"agent{i}_explained_var"] = float(ev.item()) + info[f"agent_{i}_explained_var"] = float(ev.item()) denom = max(agent_updates[i], 1) - info[f"agent{i}_actor_updates"] = int(agent_updates[i]) - info[f"agent{i}_kl_early_stop"] = int(agent_kl_early_stop[i]) + info[f"agent_{i}_actor_updates"] = int(agent_updates[i]) + info[f"agent_{i}_kl_early_stop"] = int(agent_kl_early_stop[i]) - for k, v in agent_sums[i].items(): - info[f"agent{i}_{k}"] = v / denom + for k, v in agent_actor_sums[i].items(): + info[f"agent_{i}_{k}"] = v / denom if self.target_kl is not None: - info[f"agent{i}_max_kl_seen"] = agent_max_kl[i] + info[f"agent_{i}_max_kl_seen"] = agent_max_kl[i] + + for k in agent_actor_sums[0].keys(): + values = [info[f"agent_{i}_{k}"] for i in range(self.num_agents)] + info[f"mean_{k}"] = float(np.mean(values)) + + for metric in [ + "ret_mean", + "ret_std", + "v_mean", + "v_std", + "td_mae", + "explained_var", + ]: + values = [info[f"agent_{i}_{metric}"] for i in 
range(self.num_agents)] + info[f"mean_{metric}"] = float(np.mean(values)) + info[f"std_{metric}"] = float(np.std(values)) + info[f"max_{metric}"] = float(np.max(values)) + info[f"min_{metric}"] = float(np.min(values)) stopped = sum(int(x) for x in agent_kl_early_stop) info["num_agents_kl_stopped"] = stopped diff --git a/cares_reinforcement_learning/algorithm/policy/MASAC.py b/cares_reinforcement_learning/algorithm/policy/MASAC.py index f46398c2..0056e179 100644 --- a/cares_reinforcement_learning/algorithm/policy/MASAC.py +++ b/cares_reinforcement_learning/algorithm/policy/MASAC.py @@ -39,6 +39,7 @@ import torch.nn.functional as F import cares_reinforcement_learning.memory.memory_sampler as memory_sampler +import cares_reinforcement_learning.util.helpers as hlp from cares_reinforcement_learning.algorithm.algorithm import MARLAlgorithm from cares_reinforcement_learning.algorithm.policy.SAC import SAC from cares_reinforcement_learning.memory.memory_buffer import MARLMemoryBuffer @@ -110,13 +111,14 @@ def _update_critic( next_logp_i: torch.Tensor, # (B, 1) log pi_i(a_i' | o_i') for NEXT state dones_i: torch.Tensor, # (B,) or (B, 1) ): + info: dict[str, Any] = {} # ---- Step 1: TD target with entropy term ---- with torch.no_grad(): - target_q1, target_q2 = agent.target_critic_net( + target_q_values_one, target_q_values_two = agent.target_critic_net( next_global_states, next_joint_actions ) - target_q = torch.min(target_q1, target_q2) + target_q = torch.min(target_q_values_one, target_q_values_two) q_target = rewards_i + self.gamma * (1.0 - dones_i) * ( target_q - agent.alpha * next_logp_i @@ -140,11 +142,68 @@ def _update_critic( agent.critic_net_optimiser.step() - return { - "critic_loss_one": critic_loss_one.item(), - "critic_loss_two": critic_loss_two.item(), - "critic_loss_total": critic_loss_total.item(), - } + # --------------------------------------------------------- + # Step 3: diagnostics (collated at bottom) + # --------------------------------------------------------- + with torch.no_grad(): + # --- Twin critic disagreement (stability/uncertainty) --- + # If this grows over training, critics are diverging / becoming inconsistent. + info["q1_mean"] = q_values_one.mean().item() + info["q2_mean"] = q_values_two.mean().item() + info["q_twin_gap_abs_mean"] = ( + (q_values_one - q_values_two).abs().mean().item() + ) + + # --- Target critics disagreement (target stability) --- + # Large/unstable gap here often means target critics are drifting or policy is visiting OOD actions. 
+ info["target_q1_mean"] = target_q_values_one.mean().item() + info["target_q2_mean"] = target_q_values_two.mean().item() + info["target_q_twin_gap_abs_mean"] = ( + (target_q_values_one - target_q_values_two).abs().mean().item() + ) + + # --- Soft target decomposition (SAC-specific) --- + # min_target_q_mean: the conservative bootstrap value from twin critics (pre-entropy) + # entropy_term_mean: magnitude of entropy regularization in the target (alpha * log_pi is usually negative) + # soft_target_value_mean: the exact term used inside the Bellman target before reward/discount + min_target_q = torch.minimum(target_q_values_one, target_q_values_two) + + # alpha_log_pi is typically negative; entropy_bonus is typically positive + alpha_log_pi = agent.alpha * next_logp_i + # this is what gets ADDED to minQ in the target + entropy_bonus = -agent.alpha * next_logp_i + + soft_target_value = min_target_q + entropy_bonus # == minQ - alpha*log_pi + + info["target_min_q_mean"] = min_target_q.mean().item() + info["alpha_log_pi_mean"] = alpha_log_pi.mean().item() + info["entropy_bonus_mean"] = entropy_bonus.mean().item() + info["soft_target_value_mean"] = soft_target_value.mean().item() + + # --- Bellman target scale (reward scaling / discount sanity) --- + # If q_target drifts upward without reward improvement, suspect reward_scale, gamma, or instability. + info["q_target_mean"] = q_target.mean().item() + info["q_target_std"] = q_target.std().item() + + # --- TD error diagnostics (Bellman fit quality) --- + # td_abs_mean down over time is healthy; persistent growth/spikes often indicate critic instability. + td1 = q_values_one - q_target # signed + td2 = q_values_two - q_target # signed + + info["td1_mean"] = td1.mean().item() + info["td1_std"] = td1.std().item() + info["td1_abs_mean"] = td1.abs().mean().item() + + info["td2_mean"] = td2.mean().item() + info["td2_std"] = td2.std().item() + info["td2_abs_mean"] = td2.abs().mean().item() + + # --- Losses (optimization progress; less diagnostic than TD/twin gaps) --- + info["critic_loss_one"] = critic_loss_one.item() + info["critic_loss_two"] = critic_loss_two.item() + info["critic_loss_total"] = critic_loss_total.item() + + return info def _update_actor_alpha( self, @@ -154,6 +213,8 @@ def _update_actor_alpha( global_states: torch.Tensor, current_actions_tensor: torch.Tensor, # (B, N, act_dim) sampled under no_grad ): + info: dict[str, Any] = {} + agent_ids = list(obs_tensors.keys()) batch_size = global_states.shape[0] @@ -163,20 +224,46 @@ def _update_actor_alpha( # Sample CURRENT actions for all agents (detach others) # --------------------------------------------------------- obs_i = obs_tensors[agent_ids[agent_index]] - action_i, logp_i, _ = agent.actor_net(obs_i) # grads for i only + pi_i, log_pi_i, _ = agent.actor_net(obs_i) # grads for i only - actions_all[:, agent_index, :] = action_i # only i is live + actions_all[:, agent_index, :] = pi_i # only i is live joint_actions_flat = actions_all.reshape(batch_size, -1) # --------------------------------------------------------- # Step 4: Compute actor loss: -Q_i(x, a_1,...,a_i,...,a_N) # --------------------------------------------------------- - q_val_one, q_val_two = agent.critic_net(global_states, joint_actions_flat) + with hlp.evaluating(agent.critic_net): + qf_pi_one, q_pi_two = agent.critic_net(global_states, joint_actions_flat) - q_val = torch.min(q_val_one, q_val_two) + min_qf_pi = torch.min(qf_pi_one, q_pi_two) - actor_loss = (agent.alpha * logp_i - q_val).mean() + actor_loss = (agent.alpha * 
log_pi_i - min_qf_pi).mean() + + # --------------------------------------------------------- + # Stochastic Policy Gradient Strength (∇a [α log π(a|s) − Q(s,a)]) + # --------------------------------------------------------- + # Measures how steep the entropy-regularized critic objective is + # w.r.t. the sampled policy actions. + # + # ~0 early -> critic surface and entropy term nearly flat; + # actor receives weak learning signal. + # + # Very large -> critic or entropy term is very sharp around policy + # actions; can lead to unstable or overly aggressive + # actor updates. + dq_da = torch.autograd.grad( + outputs=actor_loss, + inputs=pi_i, + retain_graph=True, + create_graph=False, + allow_unused=False, + )[0] + + with torch.no_grad(): + info["dq_da_abs_mean"] = dq_da.abs().mean().item() + info["dq_da_norm_mean"] = dq_da.norm(dim=1).mean().item() + info["dq_da_norm_p95"] = dq_da.norm(dim=1).quantile(0.95).item() # --------------------------------------------------------- # Step 5: Backprop @@ -195,18 +282,48 @@ def _update_actor_alpha( # Step 6: Alpha loss and update # --------------------------------------------------------- alpha_loss = -( - agent.log_alpha * (logp_i + agent.target_entropy).detach() + agent.log_alpha * (log_pi_i + agent.target_entropy).detach() ).mean() agent.log_alpha_optimizer.zero_grad(set_to_none=True) alpha_loss.backward() agent.log_alpha_optimizer.step() - return { - "actor_loss": actor_loss.item(), - "alpha_loss": alpha_loss.item(), - "alpha": agent.alpha.item(), - } + with torch.no_grad(): + # --- Policy entropy diagnostics (exploration health) --- + # log_pi more negative -> higher entropy (more stochastic). Less negative -> lower entropy (more deterministic). + info["log_pi_mean"] = log_pi_i.mean().item() + info["log_pi_std"] = log_pi_i.std().item() + + # --- Action magnitude/saturation (tanh policies) --- + # High saturation fraction can indicate the policy is slamming bounds; may reduce effective gradients. + info["pi_action_abs_mean"] = pi_i.abs().mean().item() + info["pi_action_std"] = pi_i.std().item() + info["pi_action_saturation_frac"] = ( + (pi_i.abs() > 0.95).float().mean().item() + ) + + # --- On-policy critic signal --- + # min_qf_pi_mean should generally increase as the policy improves (higher value actions under the policy). + info["min_qf_pi_mean"] = min_qf_pi.mean().item() + + # --- Twin critics disagreement at policy actions (more relevant than replay actions) --- + # Large gap here means critics disagree on what the current policy is doing (can destabilize actor updates). + info["qf_pi_gap_abs_mean"] = (qf_pi_one - q_pi_two).abs().mean().item() + + # --- Entropy gap (alpha tuning health) --- + # entropy_gap ~ 0 means entropy matches target. + # > 0: entropy too low -> alpha should increase; < 0: entropy too high -> alpha should decrease. + entropy_gap = -(log_pi_i + agent.target_entropy) + info["entropy_gap_mean"] = entropy_gap.mean().item() + + # --- Losses and temperature --- + info["actor_loss"] = actor_loss.item() + info["alpha_loss"] = alpha_loss.item() + info["alpha"] = agent.alpha.item() + info["log_alpha"] = agent.log_alpha.item() + + return info def train( self, @@ -257,17 +374,18 @@ def train( # Build NEXT actions by sampling from CURRENT policies (not target) # --------------------------------------------------------- # In SAC, targets use a' ~ pi(·|o') (reparameterized), then evaluate target critics. + # Computing next_joint_actions once outside ensures every agent sees the same bootstrapping sample for that minibatch. 
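+        # These samples are only used to form the critic targets, so they are drawn under
+        # no_grad: no gradient flows back into any actor through the bootstrap term.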
next_actions = [] next_logps = [] for agent, agent_id in zip(self.agent_networks, agent_ids): obs_next = next_agent_states[agent_id] - # You need a method like: action, logp = agent.actor_net.sample(obs_next) - # where action is already in env action space (or consistently scaled) - next_action_j, next_logp_j, _ = agent.actor_net( - obs_next - ) # (B, act_dim), (B, 1) + with torch.no_grad(): + with hlp.evaluating(agent.actor_net): + next_action_j, next_logp_j, _ = agent.actor_net( + obs_next + ) # (B, act_dim), (B, 1) next_actions.append(next_action_j) next_logps.append(next_logp_j) @@ -305,7 +423,8 @@ def train( # --------------------------------------------------------- # ACTOR + ALPHA UPDATES — usually every step in SAC # --------------------------------------------------------- - if self.learn_counter % self.policy_update_freq == 0: + update_actor = self.learn_counter % self.policy_update_freq == 0 + if update_actor: # --------------------------------------------------------- # For MASAC, we sample current actions from all agents when # computing each agent’s actor loss, detaching other agents’ samples. @@ -339,6 +458,54 @@ def train( for agent in self.agent_networks: agent.update_target_networks() + with torch.no_grad(): + # --- Joint action distribution (from replay) --- + # Detects action collapse / saturation / scaling issues across agents + info["joint_action_abs_mean"] = actions_tensor.abs().mean().item() + info["joint_action_std"] = actions_tensor.std(unbiased=False).item() + info["action_saturation_frac"] = ( + (actions_tensor.abs() > 0.95).float().mean().item() + ) + + # --- Coordination proxy on replay actions --- + # Cos similarity between agents' action vectors per sample (mean over pairs) + a = actions_tensor # (B,N,A) + a_norm = a / (a.norm(dim=2, keepdim=True) + 1e-12) + cos = torch.einsum("bna,bma->bnm", a_norm, a_norm) # (B,N,N) + n = cos.shape[1] + mask = ~torch.eye(n, device=cos.device, dtype=torch.bool) + info["replay_action_cos_mean"] = cos[:, mask].mean().item() + info["replay_action_cos_p95"] = cos[:, mask].quantile(0.95).item() + + # --- Next-policy sampling health (SAC target actions) --- + # These catch entropy collapse and alpha/logp pathologies early + info["next_logp_mean_all_agents"] = next_logps_tensor.mean().item() + info["next_logp_std_all_agents"] = next_logps_tensor.std( + unbiased=False + ).item() + info["next_entropy_mean_all_agents"] = (-next_logps_tensor).mean().item() + + info["next_action_abs_mean_all_agents"] = ( + next_actions_tensor.abs().mean().item() + ) + info["next_action_std_all_agents"] = next_actions_tensor.std( + unbiased=False + ).item() + info["next_action_saturation_frac_all_agents"] = ( + (next_actions_tensor.abs() > 0.95).float().mean().item() + ) + + # --- Cross-agent diagnostics --- + metrics = list(critic_info.keys()) + if update_actor: + metrics += list(actor_info.keys()) + for metric in metrics: + values = [info[f"agent_{i}_{metric}"] for i in range(self.num_agents)] + info[f"mean_{metric}"] = float(np.mean(values)) + info[f"std_{metric}"] = float(np.std(values)) + info[f"max_{metric}"] = float(np.max(values)) + info[f"min_{metric}"] = float(np.min(values)) + return info def save_models(self, filepath: str, filename: str) -> None: diff --git a/cares_reinforcement_learning/algorithm/policy/MATD3.py b/cares_reinforcement_learning/algorithm/policy/MATD3.py index d3c73a07..0efa8a31 100644 --- a/cares_reinforcement_learning/algorithm/policy/MATD3.py +++ b/cares_reinforcement_learning/algorithm/policy/MATD3.py @@ -39,6 +39,7 @@ 
import torch.nn.functional as F import cares_reinforcement_learning.memory.memory_sampler as memory_sampler +import cares_reinforcement_learning.util.helpers as hlp from cares_reinforcement_learning.algorithm.algorithm import MARLAlgorithm from cares_reinforcement_learning.algorithm.policy.TD3 import TD3 from cares_reinforcement_learning.memory.memory_buffer import MARLMemoryBuffer @@ -49,6 +50,7 @@ SARLObservation, ) from cares_reinforcement_learning.util.configurations import MATD3Config +from cares_reinforcement_learning.util.helpers import ExponentialScheduler class MATD3(MARLAlgorithm[list[np.ndarray]]): @@ -68,7 +70,14 @@ def __init__( self.policy_update_freq = config.policy_update_freq - self.policy_noise = config.policy_noise + # Policy noise + self.policy_noise_clip = config.policy_noise_clip + self.policy_noise_scheduler = ExponentialScheduler( + start_value=config.policy_noise_start, + end_value=config.policy_noise_end, + decay_steps=config.policy_noise_decay, + ) + self.policy_noise = self.policy_noise_scheduler.get_value(0) self.policy_noise_clip = config.policy_noise_clip self.max_grad_norm = config.max_grad_norm @@ -111,6 +120,7 @@ def _update_critic( next_actions_tensor: torch.Tensor, # (B, N, act_dim) from target actors dones_i: torch.Tensor, ): + info: dict[str, Any] = {} # --- Step 1: build next joint actions --- next_joint_actions = next_actions_tensor.view(next_actions_tensor.size(0), -1) @@ -140,11 +150,47 @@ def _update_critic( agent.critic_net_optimiser.step() - return { - "critic_loss_one": critic_loss_one.item(), - "critic_loss_two": critic_loss_two.item(), - "critic_loss_total": critic_loss_total.item(), - } + with torch.no_grad(): + # --- Twin critic disagreement (stability/uncertainty) --- + # If this grows over training, critics are diverging / becoming inconsistent. + info["q1_mean"] = q_values_one.mean().item() + info["q2_mean"] = q_values_two.mean().item() + info["q_twin_gap_abs_mean"] = ( + (q_values_one - q_values_two).abs().mean().item() + ) + + # --- Target critics disagreement (target stability) --- + # Large/unstable gap here often means target critics are drifting or policy is visiting OOD actions. + info["target_q1_mean"] = target_q_values_one.mean().item() + info["target_q2_mean"] = target_q_values_two.mean().item() + info["target_q_twin_gap_abs_mean"] = ( + (target_q_values_one - target_q_values_two).abs().mean().item() + ) + + # --- Bellman target scale (reward scaling / discount sanity) --- + # If q_target drifts upward without reward improvement, suspect reward_scale, gamma, or instability. + info["q_target_mean"] = q_target.mean().item() + info["q_target_std"] = q_target.std().item() + + # --- TD error diagnostics (Bellman fit quality) --- + # td_abs_mean down over time is healthy; persistent growth/spikes often indicate critic instability. 
+ td1 = q_values_one - q_target # signed + td2 = q_values_two - q_target # signed + + info["td1_mean"] = td1.mean().item() + info["td1_std"] = td1.std().item() + info["td1_abs_mean"] = td1.abs().mean().item() + + info["td2_mean"] = td2.mean().item() + info["td2_std"] = td2.std().item() + info["td2_abs_mean"] = td2.abs().mean().item() + + # --- Losses (optimization progress; less diagnostic than TD/twin gaps) --- + info["critic_loss_one"] = critic_loss_one.item() + info["critic_loss_two"] = critic_loss_two.item() + info["critic_loss_total"] = critic_loss_total.item() + + return info def _update_actor( self, @@ -159,6 +205,7 @@ def _update_actor( - For j ≠ agent_index: use replay-buffer actions - For j == agent_index: use current actor output """ + info: dict[str, Any] = {} agent_ids = list(obs_tensors.keys()) batch_size = global_states.shape[0] @@ -173,20 +220,37 @@ def _update_actor( # Step 2: Replace ONLY agent_i action with differentiable action # --------------------------------------------------------- obs_i = obs_tensors[agent_ids[agent_index]] # (B, obs_dim_i) - pred_action_i = agent.actor_net(obs_i) # differentiable + actions_i = agent.actor_net(obs_i) # differentiable - actions_all[:, agent_index, :] = pred_action_i # keep others from buffer + actions_all[:, agent_index, :] = actions_i # keep others from buffer # --------------------------------------------------------- # Step 4: Compute actor loss: -Q_i(x, a_1,...,a_i,...,a_N) # --------------------------------------------------------- joint_actions_flat = actions_all.reshape(batch_size, -1) - q_val, _ = agent.critic_net(global_states, joint_actions_flat) + actor_q_values, _ = agent.critic_net(global_states, joint_actions_flat) - # regularization as in TF code - reg = (pred_action_i**2).mean() * 1e-3 + actor_loss = -actor_q_values.mean() - actor_loss = -q_val.mean() + reg + # --------------------------------------------------------- + # Deterministic Policy Gradient Strength (∇a Q(s,a)) + # --------------------------------------------------------- + # Measures how steep the critic surface is w.r.t. actions. + # ~0 early -> critic flat, actor receives no learning signal. + # Very large -> critic overly sharp, can cause unstable actor updates. + dq_da = torch.autograd.grad( + outputs=actor_loss, + inputs=actions_i, + retain_graph=True, # because we do backward(actor_loss) next + create_graph=False, # diagnostic only + allow_unused=False, + )[0] + with torch.no_grad(): + # - ~0 early: critic surface flat around actor actions (weak learning signal) + # - very large: critic surface sharp -> unstable / exploitative actor updates + info["dq_da_abs_mean"] = dq_da.abs().mean().item() + info["dq_da_norm_mean"] = dq_da.norm(dim=1).mean().item() + info["dq_da_norm_p95"] = dq_da.norm(dim=1).quantile(0.95).item() # --------------------------------------------------------- # Step 5: Backprop @@ -201,7 +265,25 @@ def _update_actor( agent.actor_net_optimiser.step() - return {"actor_loss": actor_loss.item()} + with torch.no_grad(): + # Policy Action Health (tanh policies in [-1, 1]) + # pi_action_saturation_frac: + # High values (>0.8 early) often mean the actor is slamming bounds, + # reducing effective gradient flow through tanh. + info["pi_action_mean"] = actions_i.mean().item() + info["pi_action_std"] = actions_i.std().item() + info["pi_action_abs_mean"] = actions_i.abs().mean().item() + info["pi_action_saturation_frac"] = ( + (actions_i.abs() > 0.95).float().mean().item() + ) + + # actor_q_mean should generally increase over training. 
+ # actor_q_std large + unstable may indicate critic inconsistency. + info["actor_loss"] = actor_loss.item() + info["actor_q_mean"] = actor_q_values.mean().item() + info["actor_q_std"] = actor_q_values.std().item() + + return info def train( self, @@ -213,6 +295,19 @@ def train( info: dict[str, Any] = {} + # Update per agent action noise for exploration (decayed over training) + for i, current_agent in enumerate(self.agent_networks): + current_agent.action_noise = current_agent.action_noise_scheduler.get_value( + episode_context.training_step + ) + info[f"agent_{i}_current_action_noise"] = current_agent.action_noise + + # Update TD3 target policy smoothing noise (decayed over training) + self.policy_noise = self.policy_noise_scheduler.get_value( + episode_context.training_step + ) + info["current_policy_noise"] = self.policy_noise + # --------------------------------------------------------- # Sample ONCE for all agents (recommended for TD3/SAC) # Shared minibatch: We draw one minibatch per training iteration and reuse it across agent updates. @@ -254,7 +349,8 @@ def train( next_actions = [] for agent, agent_id in zip(self.agent_networks, agent_ids): obs_next = next_agent_states[agent_id] - next_actions.append(agent.target_actor_net(obs_next)) + with hlp.evaluating(agent.target_actor_net): + next_actions.append(agent.target_actor_net(obs_next)) # (B, N, act_dim) next_actions_tensor = torch.stack(next_actions, dim=1) @@ -263,11 +359,25 @@ def train( # TD3 TARGET POLICY SMOOTHING (ONCE) # --------------------------------------------------------- # This affects ONLY critic targets - noise = torch.randn_like(next_actions_tensor) * self.policy_noise - noise = noise.clamp(-self.policy_noise_clip, self.policy_noise_clip) + target_noise = torch.randn_like(next_actions_tensor) * self.policy_noise + target_noise = target_noise.clamp( + -self.policy_noise_clip, self.policy_noise_clip + ) + + # --- TD3-style smoothing diagnostics --- + # Noise diagnostics + # What it tells you: + # - target_noise_abs_mean: effective smoothing magnitude. + # - target_noise_clip_frac high early: noise often clipped (clip too small or noise too large). 
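+        # Reference values for Gaussian noise with std = policy_noise:
+        # E[|noise|] = policy_noise * sqrt(2 / pi) ~= 0.80 * policy_noise, and
+        # P(|noise| >= clip) = 2 * (1 - Phi(clip / policy_noise)), e.g. ~4.6% when clip = 2 * policy_noise.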
+ target_noise_abs_mean = target_noise.abs().mean().item() + target_noise_clip_frac = ( + (target_noise.abs() >= self.policy_noise_clip).float().mean().item() + ) + info["target_noise_abs_mean"] = float(target_noise_abs_mean) + info["target_noise_clip_frac"] = float(target_noise_clip_frac) # assumes tanh policy -> [-1, 1] - next_actions_noisy = (next_actions_tensor + noise).clamp(-1.0, 1.0) + next_actions_noisy = (next_actions_tensor + target_noise).clamp(-1.0, 1.0) # --------------------------------------------------------- # CRITIC UPDATES (every step) @@ -291,7 +401,8 @@ def train( # --------------------------------------------------------- # ACTOR + TARGET UPDATES (DELAYED — TD3) # --------------------------------------------------------- - if self.learn_counter % self.policy_update_freq == 0: + update_actor = self.learn_counter % self.policy_update_freq == 0 + if update_actor: for agent_index, agent in enumerate(self.agent_networks): actor_info = self._update_actor( agent=agent, @@ -307,6 +418,46 @@ def train( for agent in self.agent_networks: agent.update_target_networks() + with torch.no_grad(): + # --- Joint action distribution (from replay) --- + # Detects action collapse / saturation / scaling issues across agents + info["joint_action_abs_mean"] = actions_tensor.abs().mean().item() + info["joint_action_std"] = actions_tensor.std(unbiased=False).item() + info["action_saturation_frac"] = ( + (actions_tensor.abs() > 0.95).float().mean().item() + ) + + # --- Coordination proxy on replay actions --- + # Cos similarity between agents' action vectors per sample (mean over pairs) + a = actions_tensor # (B,N,A) + a_norm = a / (a.norm(dim=2, keepdim=True) + 1e-12) + cos = torch.einsum("bna,bma->bnm", a_norm, a_norm) # (B,N,N) + n = cos.shape[1] + mask = ~torch.eye(n, device=cos.device, dtype=torch.bool) + info["replay_action_cos_mean"] = cos[:, mask].mean().item() + info["replay_action_cos_p95"] = cos[:, mask].quantile(0.95).item() + + info["next_action_abs_mean_all_agents"] = ( + next_actions_tensor.abs().mean().item() + ) + info["next_action_std_all_agents"] = next_actions_tensor.std( + unbiased=False + ).item() + info["next_action_saturation_frac_all_agents"] = ( + (next_actions_tensor.abs() > 0.95).float().mean().item() + ) + + # --- Cross-agent diagnostics --- + metrics = list(critic_info.keys()) + if update_actor: + metrics += list(actor_info.keys()) + for metric in metrics: + values = [info[f"agent_{i}_{metric}"] for i in range(self.num_agents)] + info[f"mean_{metric}"] = float(np.mean(values)) + info[f"std_{metric}"] = float(np.std(values)) + info[f"max_{metric}"] = float(np.max(values)) + info[f"min_{metric}"] = float(np.min(values)) + return info def save_models(self, filepath: str, filename: str) -> None: diff --git a/cares_reinforcement_learning/algorithm/policy/NaSATD3.py b/cares_reinforcement_learning/algorithm/policy/NaSATD3.py index e1365a7c..98dc7b0a 100644 --- a/cares_reinforcement_learning/algorithm/policy/NaSATD3.py +++ b/cares_reinforcement_learning/algorithm/policy/NaSATD3.py @@ -106,6 +106,7 @@ SARLObservationTensors, ) from cares_reinforcement_learning.util.configurations import NaSATD3Config +from cares_reinforcement_learning.util.helpers import ExponentialScheduler class NaSATD3(SARLAlgorithm[np.ndarray]): @@ -128,36 +129,41 @@ def __init__( self.policy_update_freq = config.policy_update_freq # Policy noise - self.min_policy_noise = config.min_policy_noise - self.policy_noise = config.policy_noise - self.policy_noise_decay = config.policy_noise_decay - 
self.policy_noise_clip = config.policy_noise_clip + self.policy_noise_scheduler = ExponentialScheduler( + start_value=config.policy_noise_start, + end_value=config.policy_noise_end, + decay_steps=config.policy_noise_decay, + ) + self.policy_noise = self.policy_noise_scheduler.get_value(0) # Action noise - self.min_action_noise = config.min_action_noise - self.action_noise = config.action_noise - self.action_noise_decay = config.action_noise_decay + self.action_noise_scheduler = ExponentialScheduler( + start_value=config.action_noise_start, + end_value=config.action_noise_end, + decay_steps=config.action_noise_decay, + ) + self.action_noise = self.action_noise_scheduler.get_value(0) # Doesn't matter which autoencoder is used, as long as it is the same for all networks self.autoencoder: VanillaAutoencoder | BurgessAutoencoder = ( actor_network.autoencoder ) - self.actor = actor_network.to(device) - self.critic = critic_network.to(device) + self.actor_net = actor_network.to(device) + self.critic_net = critic_network.to(device) - self.actor_target = copy.deepcopy(self.actor).to(device) - self.critic_target = copy.deepcopy(self.critic).to(device) + self.actor_target = copy.deepcopy(self.actor_net).to(device) + self.critic_target = copy.deepcopy(self.critic_net).to(device) # Necessary to make the same autoencoder in the whole algorithm - self.actor.autoencoder = self.autoencoder - self.critic.autoencoder = self.autoencoder + self.actor_net.autoencoder = self.autoencoder + self.critic_net.autoencoder = self.autoencoder self.actor_target.autoencoder = self.autoencoder self.critic_target.autoencoder = self.autoencoder - self.action_num = self.actor.num_actions + self.action_num = self.actor_net.num_actions self.ensemble_predictive_model = nn.ModuleList() networks = [ @@ -170,10 +176,10 @@ def __init__( self.actor_lr = config.actor_lr self.critic_lr = config.critic_lr self.actor_optimizer = torch.optim.Adam( - self.actor.parameters(), lr=self.actor_lr + self.actor_net.parameters(), lr=self.actor_lr ) self.critic_optimizer = torch.optim.Adam( - self.critic.parameters(), lr=self.critic_lr + self.critic_net.parameters(), lr=self.critic_lr ) self.epm_lr = config.epm_lr @@ -191,7 +197,7 @@ def act( observation: SARLObservation, evaluation: bool = False, ) -> ActionSample[np.ndarray]: - self.actor.eval() + self.actor_net.eval() self.autoencoder.eval() with torch.no_grad(): @@ -199,7 +205,7 @@ def act( [observation], self.device ) - action = self.actor(observation_tensors) + action = self.actor_net(observation_tensors) action = action.cpu().data.numpy().flatten() if not evaluation: # this is part the TD3 too, add noise to the action @@ -210,7 +216,7 @@ def act( action = action + noise action = np.clip(action, -1, 1) - self.actor.train() + self.actor_net.train() self.autoencoder.train() return ActionSample(action=action, source="policy") @@ -221,7 +227,9 @@ def _update_critic( rewards: torch.Tensor, next_states: SARLObservationTensors, dones: torch.Tensor, - ) -> tuple[float, float, float]: + ) -> dict[str, Any]: + info: dict[str, Any] = {} + with torch.no_grad(): next_actions = self.actor_target(next_states) target_noise = self.policy_noise * torch.randn_like(next_actions) @@ -238,7 +246,7 @@ def _update_critic( q_target = rewards + self.gamma * (1 - dones) * target_q_values - q_values_one, q_values_two = self.critic(states, actions) + q_values_one, q_values_two = self.critic_net(states, actions) critic_loss_one = F.mse_loss(q_values_one, q_target) critic_loss_two = F.mse_loss(q_values_two, q_target) @@ 
-248,25 +256,121 @@ def _update_critic( critic_loss_total.backward() self.critic_optimizer.step() - return critic_loss_one.item(), critic_loss_two.item(), critic_loss_total.item() + with torch.no_grad(): + # --- TD3-style smoothing diagnostics --- + # Noise diagnostics + # What it tells you: + # - target_noise_abs_mean: effective smoothing magnitude. + # - target_noise_clip_frac high early: noise often clipped (clip too small or noise too large). + target_noise_abs_mean = target_noise.abs().mean().item() + target_noise_clip_frac = ( + (target_noise.abs() >= self.policy_noise_clip).float().mean().item() + ) + info["target_noise_abs_mean"] = float(target_noise_abs_mean) + info["target_noise_clip_frac"] = float(target_noise_clip_frac) + + # --- Twin critic disagreement (stability/uncertainty) --- + # If this grows over training, critics are diverging / becoming inconsistent. + info["q1_mean"] = q_values_one.mean().item() + info["q2_mean"] = q_values_two.mean().item() + info["q_twin_gap_abs_mean"] = ( + (q_values_one - q_values_two).abs().mean().item() + ) + + # --- Target critics disagreement (target stability) --- + # Large/unstable gap here often means target critics are drifting or policy is visiting OOD actions. + info["target_q1_mean"] = target_q_values_one.mean().item() + info["target_q2_mean"] = target_q_values_two.mean().item() + info["target_q_twin_gap_abs_mean"] = ( + (target_q_values_one - target_q_values_two).abs().mean().item() + ) + + # --- Bellman target scale (reward scaling / discount sanity) --- + # If q_target drifts upward without reward improvement, suspect reward_scale, gamma, or instability. + info["q_target_mean"] = q_target.mean().item() + info["q_target_std"] = q_target.std().item() + + # --- TD error diagnostics (Bellman fit quality) --- + # td_abs_mean down over time is healthy; persistent growth/spikes often indicate critic instability. 
+ td1 = q_values_one - q_target # signed + td2 = q_values_two - q_target # signed + + info["td1_mean"] = td1.mean().item() + info["td1_std"] = td1.std().item() + info["td1_abs_mean"] = td1.abs().mean().item() + + info["td2_mean"] = td2.mean().item() + info["td2_std"] = td2.std().item() + info["td2_abs_mean"] = td2.abs().mean().item() + + # --- Losses (optimization progress; less diagnostic than TD/twin gaps) --- + info["critic_loss_one"] = critic_loss_one.item() + info["critic_loss_two"] = critic_loss_two.item() + info["critic_loss_total"] = critic_loss_total.item() + + return info def _update_autoencoder(self, states: torch.Tensor) -> float: # Leaving this function in case this needs to be extended again in the future ae_loss = self.autoencoder.update_autoencoder(states) return ae_loss.item() - def _update_actor(self, states: SARLObservationTensors) -> float: - actor_q_one, actor_q_two = self.critic( - states, self.actor(states, detach_encoder=True), detach_encoder=True - ) - actor_q_values = torch.minimum(actor_q_one, actor_q_two) + def _update_actor(self, states: SARLObservationTensors) -> dict[str, Any]: + info: dict[str, Any] = {} + + actions = self.actor_net(states, detach_encoder=True) + + with hlp.evaluating(self.critic_net): + actor_q_values_one, actor_q_values_two = self.critic_net( + states, actions, detach_encoder=True + ) + actor_q_values = torch.minimum(actor_q_values_one, actor_q_values_two) + actor_loss = -actor_q_values.mean() + # --------------------------------------------------------- + # Deterministic Policy Gradient Strength (∇a Q(s,a)) + # --------------------------------------------------------- + # Measures how steep the critic surface is w.r.t. actions. + # ~0 early -> critic flat, actor receives no learning signal. + # Very large -> critic overly sharp, can cause unstable actor updates. + dq_da = torch.autograd.grad( + outputs=actor_loss, + inputs=actions, + retain_graph=True, # because we do backward(actor_loss) next + create_graph=False, # diagnostic only + allow_unused=False, + )[0] + with torch.no_grad(): + # - ~0 early: critic surface flat around actor actions (weak learning signal) + # - very large: critic surface sharp -> unstable / exploitative actor updates + info["dq_da_abs_mean"] = dq_da.abs().mean().item() + info["dq_da_norm_mean"] = dq_da.norm(dim=1).mean().item() + info["dq_da_norm_p95"] = dq_da.norm(dim=1).quantile(0.95).item() + self.actor_optimizer.zero_grad() actor_loss.backward() self.actor_optimizer.step() - return actor_loss.item() + with torch.no_grad(): + # Policy Action Health (tanh policies in [-1, 1]) + # pi_action_saturation_frac: + # High values (>0.8 early) often mean the actor is slamming bounds, + # reducing effective gradient flow through tanh. + info["pi_action_mean"] = actions.mean().item() + info["pi_action_std"] = actions.std().item() + info["pi_action_abs_mean"] = actions.abs().mean().item() + info["pi_action_saturation_frac"] = ( + (actions.abs() > 0.95).float().mean().item() + ) + + # actor_q_mean should generally increase over training. + # actor_q_std large + unstable may indicate critic inconsistency. 
+ info["actor_loss"] = actor_loss.item() + info["actor_q_mean"] = actor_q_values.mean().item() + info["actor_q_std"] = actor_q_values.std().item() + + return info def _get_latent_state( self, states: torch.Tensor, detach_output: bool, sample_latent: bool = True @@ -325,19 +429,21 @@ def train( memory_buffer: SARLMemoryBuffer, episode_context: EpisodeContext, ) -> dict[str, Any]: - self.actor.train() - self.critic.train() + self.actor_net.train() + self.critic_net.train() self.autoencoder.train() self.autoencoder.encoder.train() self.autoencoder.decoder.train() self.learn_counter += 1 - self.policy_noise *= self.policy_noise_decay - self.policy_noise = max(self.min_policy_noise, self.policy_noise) + self.policy_noise = self.policy_noise_scheduler.get_value( + episode_context.training_step + ) - self.action_noise *= self.action_noise_decay - self.action_noise = max(self.min_action_noise, self.action_noise) + self.action_noise = self.action_noise_scheduler.get_value( + episode_context.training_step + ) # Convert to tensors using multimodal batch conversion ( @@ -362,16 +468,14 @@ def train( info: dict[str, Any] = {} # Update the Critic - critic_loss_one, critic_loss_two, critic_loss_total = self._update_critic( + critic_info = self._update_critic( observation_tensor, actions_tensor, rewards_tensor, next_observation_tensor, dones_tensor, ) - info["critic_loss_one"] = critic_loss_one - info["critic_loss_two"] = critic_loss_two - info["critic_loss_total"] = critic_loss_total + info.update(critic_info) # Update Autoencoder ae_loss = self._update_autoencoder(observation_tensor.image_state_tensor) @@ -379,19 +483,21 @@ def train( if self.learn_counter % self.policy_update_freq == 0: # Update Actor - actor_loss = self._update_actor(observation_tensor) - info["actor_loss"] = actor_loss + actor_info = self._update_actor(observation_tensor) + info.update(actor_info) # Update target network params # Note: the encoders in target networks are the same of main networks, so I wont update them hlp.soft_update_params( - self.critic.critic.Q1, self.critic_target.critic.Q1, self.tau + self.critic_net.critic.Q1, self.critic_target.critic.Q1, self.tau ) hlp.soft_update_params( - self.critic.critic.Q2, self.critic_target.critic.Q2, self.tau + self.critic_net.critic.Q2, self.critic_target.critic.Q2, self.tau ) - hlp.soft_update_params(self.actor.actor, self.actor_target.actor, self.tau) + hlp.soft_update_params( + self.actor_net.actor, self.actor_target.actor, self.tau + ) # Update intrinsic models if self.intrinsic_on: @@ -520,8 +626,8 @@ def save_models(self, filepath: str, filename: str) -> None: if not os.path.exists(filepath): os.makedirs(filepath) checkpoint = { - "actor": self.actor.state_dict(), - "critic": self.critic.state_dict(), + "actor": self.actor_net.state_dict(), + "critic": self.critic_net.state_dict(), "actor_target": self.actor_target.state_dict(), "critic_target": self.critic_target.state_dict(), "encoder": self.autoencoder.encoder.state_dict(), @@ -540,8 +646,8 @@ def save_models(self, filepath: str, filename: str) -> None: def load_models(self, filepath: str, filename: str) -> None: checkpoint = torch.load(f"{filepath}/{filename}_checkpoint.pth") - self.actor.load_state_dict(checkpoint["actor"]) - self.critic.load_state_dict(checkpoint["critic"]) + self.actor_net.load_state_dict(checkpoint["actor"]) + self.critic_net.load_state_dict(checkpoint["critic"]) self.actor_target.load_state_dict(checkpoint["actor_target"]) self.critic_target.load_state_dict(checkpoint["critic_target"]) diff --git 
a/cares_reinforcement_learning/algorithm/policy/PALTD3.py b/cares_reinforcement_learning/algorithm/policy/PALTD3.py index f9e99d10..917226ff 100644 --- a/cares_reinforcement_learning/algorithm/policy/PALTD3.py +++ b/cares_reinforcement_learning/algorithm/policy/PALTD3.py @@ -83,6 +83,8 @@ def _update_critic( dones: torch.Tensor, weights: torch.Tensor, ) -> tuple[dict[str, Any], np.ndarray]: + info: dict[str, Any] = {} + with torch.no_grad(): next_actions = self.target_actor_net(next_states) target_noise = self.policy_noise * torch.randn_like(next_actions) @@ -127,10 +129,56 @@ def _update_critic( batch_size = states.shape[0] priorities = np.array([1.0] * batch_size) - info = { - "critic_loss_one": pal_loss_one.item(), - "critic_loss_two": pal_loss_two.item(), - "critic_loss_total": critic_loss_total.item(), - } + with torch.no_grad(): + # --- TD3-style smoothing diagnostics --- + # Noise diagnostics + # What it tells you: + # - target_noise_abs_mean: effective smoothing magnitude. + # - target_noise_clip_frac high early: noise often clipped (clip too small or noise too large). + target_noise_abs_mean = target_noise.abs().mean().item() + target_noise_clip_frac = ( + (target_noise.abs() >= self.policy_noise_clip).float().mean().item() + ) + info["target_noise_abs_mean"] = float(target_noise_abs_mean) + info["target_noise_clip_frac"] = float(target_noise_clip_frac) + + # --- Twin critic disagreement (stability/uncertainty) --- + # If this grows over training, critics are diverging / becoming inconsistent. + info["q1_mean"] = q_values_one.mean().item() + info["q2_mean"] = q_values_two.mean().item() + info["q_twin_gap_abs_mean"] = ( + (q_values_one - q_values_two).abs().mean().item() + ) + + # --- Target critics disagreement (target stability) --- + # Large/unstable gap here often means target critics are drifting or policy is visiting OOD actions. + info["target_q1_mean"] = target_q_values_one.mean().item() + info["target_q2_mean"] = target_q_values_two.mean().item() + info["target_q_twin_gap_abs_mean"] = ( + (target_q_values_one - target_q_values_two).abs().mean().item() + ) + + # --- Bellman target scale (reward scaling / discount sanity) --- + # If q_target drifts upward without reward improvement, suspect reward_scale, gamma, or instability. + info["q_target_mean"] = q_target.mean().item() + info["q_target_std"] = q_target.std().item() + + # --- TD error diagnostics (Bellman fit quality) --- + # td_abs_mean down over time is healthy; persistent growth/spikes often indicate critic instability. 
+ td1 = q_values_one - q_target # signed + td2 = q_values_two - q_target # signed + + info["td1_mean"] = td1.mean().item() + info["td1_std"] = td1.std().item() + info["td1_abs_mean"] = td1.abs().mean().item() + + info["td2_mean"] = td2.mean().item() + info["td2_std"] = td2.std().item() + info["td2_abs_mean"] = td2.abs().mean().item() + + # --- Losses (optimization progress; less diagnostic than TD/twin gaps) --- + info["critic_loss_one"] = pal_loss_one.item() + info["critic_loss_two"] = pal_loss_two.item() + info["critic_loss_total"] = critic_loss_total.item() return info, priorities diff --git a/cares_reinforcement_learning/algorithm/policy/PPO.py b/cares_reinforcement_learning/algorithm/policy/PPO.py index 91b8f1e7..0eb4152e 100644 --- a/cares_reinforcement_learning/algorithm/policy/PPO.py +++ b/cares_reinforcement_learning/algorithm/policy/PPO.py @@ -94,7 +94,7 @@ SARLObservationTensors, ) from cares_reinforcement_learning.util.configurations import PPOConfig -from cares_reinforcement_learning.util.helpers import EpsilonScheduler +from cares_reinforcement_learning.util.helpers import LinearScheduler class PPO(SARLAlgorithm[np.ndarray]): @@ -119,13 +119,13 @@ def __init__( self.target_kl = config.target_kl - self.epsilon_scheduler = EpsilonScheduler( - start_epsilon=config.entropy_start, - end_epsilon=config.entropy_end, + self.epsilon_scheduler = LinearScheduler( + start_value=config.entropy_start, + end_value=config.entropy_end, decay_steps=config.entropy_decay, ) # initial entropy coefficient - self.entropy_coef = self.epsilon_scheduler.get_epsilon(0) + self.entropy_coef = self.epsilon_scheduler.get_value(0) self.max_grad_norm = config.max_grad_norm self.min_log_std = config.log_std_bounds[0] @@ -365,7 +365,7 @@ def update_from_batch( batch_size = len(observation_tensor.vector_state_tensor) - self.entropy_coef = self.epsilon_scheduler.get_epsilon( + self.entropy_coef = self.epsilon_scheduler.get_value( episode_context.training_step ) @@ -505,39 +505,93 @@ def update_from_batch( num_critic_mbs += 1 info: dict[str, Any] = {} + + # --------------------------------------------------------- + # Core Losses + # --------------------------------------------------------- + # critic_loss: + # Should generally decrease over time. + # Large spikes -> value instability or bad returns scaling. info["critic_loss"] = sum_critic_loss / max(num_critic_mbs, 1) + + # actor_loss: + # Not directly interpretable in magnitude. + # Watch stability and trend, not absolute value. info["actor_loss"] = sum_actor_loss / max(num_actor_mbs, 1) - # Batch-level stats + # --------------------------------------------------------- + # Advantage & Return Statistics (signal quality) + # --------------------------------------------------------- + # adv_mean should be ~0 if advantages are normalized. + # adv_std too small -> weak learning signal. info["adv_mean"] = float(info_adv_mean.item()) info["adv_std"] = float(info_adv_std.item()) + + # returns stats reflect reward scale. + # Large drift without performance improvement -> value mismatch. info["returns_mean"] = float(info_returns_mean.item()) info["returns_std"] = float(info_returns_std.item()) + + # explained_variance: + # ~1.0 -> value function predicts returns well + # ~0.0 -> value function no better than baseline + # < 0 -> value predictions actively harmful info["explained_variance"] = float(explained_var.item()) + + # entropy_coef controls exploration strength. + # Too small early -> premature convergence. 
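# Illustrative sketch (not part of this patch): how an explained-variance diagnostic like
# the one logged just above is conventionally computed. The actual explained_var tensor is
# produced earlier in update_from_batch and is not shown in this hunk; `values` and
# `returns` below are stand-in names.
import torch

def explained_variance(values: torch.Tensor, returns: torch.Tensor) -> float:
    """1 - Var[returns - values] / Var[returns]: ~1 good fit, ~0 baseline, < 0 harmful."""
    var_returns = returns.var(unbiased=False)
    if var_returns.item() == 0.0:
        return float("nan")
    return float(1.0 - (returns - values).var(unbiased=False) / var_returns)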
info["entropy_coef"] = self.entropy_coef - # Exploration + # --------------------------------------------------------- + # Policy Exploration (Gaussian std diagnostics) + # --------------------------------------------------------- + # log_std controls action stochasticity. + # If collapsing too early -> exploration dies. + # If very large -> noisy policy. info["log_std_mean"] = float(log_std_mean.item()) info["log_std_min"] = float(log_std_min.item()) info["log_std_max"] = float(log_std_max.item()) - # Update health (averaged over minibatches) + # --------------------------------------------------------- + # Policy Update Health (averaged over minibatches) + # --------------------------------------------------------- if num_actor_mbs > 0: + # Should decrease gradually during training. + # Sudden collapse -> entropy_coef too low. info["entropy"] = sum_entropy / num_actor_mbs + + # Fraction of samples hitting PPO clip. + # ~0.1–0.3 typical. + # ~0 -> updates too small. + # >0.5 -> updates too aggressive. info["clip_frac"] = sum_clip_frac / num_actor_mbs + + # ratio stats (π_new / π_old): + # ratio_mean ~1 is healthy. + # Large std -> unstable updates. info["ratio_mean"] = sum_ratio_mean / num_actor_mbs info["ratio_std"] = sum_ratio_std / num_actor_mbs + # High -> policy frequently at action bounds (tanh saturation). info["action_sat_rate"] = sum_sat_rate / num_actor_mbs + + # u_abs_*: magnitude of pre-squashed action. + # Large values -> pushing into tanh extremes. info["u_abs_mean"] = sum_u_abs_mean / num_actor_mbs info["u_abs_max"] = sum_u_abs_max / num_actor_mbs + # log_ratio diagnostics: + # Large values indicate aggressive policy shifts. info["log_ratio_mean"] = sum_log_ratio_mean / num_actor_mbs info["log_ratio_std"] = sum_log_ratio_std / num_actor_mbs info["log_ratio_max_abs"] = sum_log_ratio_max_abs / num_actor_mbs - # KL (only if enabled) + # --------------------------------------------------------- + # KL Diagnostics (if enabled) + # --------------------------------------------------------- if self.target_kl is not None and num_actor_mbs > 0: + # Measures how far new policy moved from old. + # Should remain near target_kl. 
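# Illustrative sketch (not part of this patch): a common estimator for the approx_kl value
# logged just below. The exact estimator accumulated into sum_kl is defined earlier in
# update_from_batch and is not shown in this hunk; this is the standard low-variance
# (r - 1) - log r form, with r = pi_new / pi_old and actions sampled from the old policy.
import torch

def approx_kl(new_log_prob: torch.Tensor, old_log_prob: torch.Tensor) -> float:
    log_ratio = new_log_prob - old_log_prob
    return float(((log_ratio.exp() - 1.0) - log_ratio).mean())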
info["approx_kl"] = sum_kl / num_actor_mbs info["kl_early_stop"] = int(kl_early_stop) info["max_kl_seen"] = max_kl_seen diff --git a/cares_reinforcement_learning/algorithm/policy/REDQ.py b/cares_reinforcement_learning/algorithm/policy/REDQ.py index fa6e5f59..5d9bf1be 100644 --- a/cares_reinforcement_learning/algorithm/policy/REDQ.py +++ b/cares_reinforcement_learning/algorithm/policy/REDQ.py @@ -124,6 +124,7 @@ def _update_critic( # type: ignore[override] next_states: torch.Tensor, dones: torch.Tensor, ) -> dict[str, Any]: + info: dict[str, Any] = {} # replace=False so that not picking the same idx twice idx = np.random.choice( self.ensemble_size, self.num_sample_critics, replace=False @@ -148,13 +149,24 @@ def _update_critic( # type: ignore[override] q_target = rewards + self.gamma * (1 - dones) * target_q_values - critic_loss_totals = [] + critic_loss_totals: list[float] = [] + critic_td_abs_means: list[float] = [] + critic_q_means: list[float] = [] + + # For ensemble diagnostics (store per-critic outputs on this batch) + q_set: list[torch.Tensor] = [] for critic_net, critic_net_optimiser in zip( self.critic_net.critics, self.ensemble_critic_optimizers ): q_values = critic_net(states, actions) + q_set.append(q_values) + + td = q_values - q_target # signed TD error + critic_td_abs_means.append(td.abs().mean().item()) + critic_q_means.append(q_values.mean().item()) + critic_loss = 0.5 * F.mse_loss(q_values, q_target) critic_net_optimiser.zero_grad() @@ -163,12 +175,72 @@ def _update_critic( # type: ignore[override] critic_loss_totals.append(critic_loss.item()) - critic_loss_total = np.mean(critic_loss_totals) - info = { - "idx": idx, - "critic_loss_total": critic_loss_total, - "critic_loss_totals": critic_loss_totals, - } + with torch.no_grad(): + # Which target critics were sampled (for debugging + reproducibility) + info["idx0"] = int(idx[0]) + info["idx1"] = int(idx[1]) + + # --- Target-side diagnostics (s', pi(s')) --- + info["target_q1_mean"] = target_q_values_one.mean().item() + info["target_q2_mean"] = target_q_values_two.mean().item() + info["target_min_q_mean"] = target_q_values.mean().item() + + # Disagreement between the sampled target critics + target_gap = (target_q_values_one - target_q_values_two).abs() + info["target_q_gap_abs_mean"] = target_gap.mean().item() + info["target_q_gap_abs_p95"] = target_gap.quantile(0.95).item() + + # --- Soft target decomposition (SAC-specific) --- + # min_target_q_mean: the conservative bootstrap value from twin critics (pre-entropy) + # entropy_term_mean: magnitude of entropy regularization in the target (alpha * log_pi is usually negative) + # soft_target_value_mean: the exact term used inside the Bellman target before reward/discount + min_target_q = torch.minimum(target_q_values_one, target_q_values_two) + + # alpha_log_pi is typically negative; entropy_bonus is typically positive + alpha_log_pi = self.alpha * next_log_pi + # this is what gets ADDED to minQ in the target + entropy_bonus = -self.alpha * next_log_pi + + soft_target_value = min_target_q + entropy_bonus # == minQ - alpha*log_pi + + info["target_min_q_mean"] = min_target_q.mean().item() + info["alpha_log_pi_mean"] = alpha_log_pi.mean().item() + info["entropy_bonus_mean"] = entropy_bonus.mean().item() + info["soft_target_value_mean"] = soft_target_value.mean().item() + + # Bellman target scale + info["q_target_mean"] = q_target.mean().item() + info["q_target_std"] = q_target.std(unbiased=False).item() + + # --- Critic loss diagnostics --- + info["critic_loss_total"] = 
float(np.mean(critic_loss_totals)) + info["critic_loss_totals"] = critic_loss_totals # per-critic scalars + + # “Bad apple” detection across critics + info["critic_loss_std_across_critics"] = float(np.std(critic_loss_totals)) + info["critic_td_abs_mean_across_critics"] = float( + np.mean(critic_td_abs_means) + ) + info["critic_td_abs_std_across_critics"] = float( + np.std(critic_td_abs_means) + ) + info["critic_q_mean_across_critics"] = float(np.mean(critic_q_means)) + info["critic_q_std_across_critics"] = float(np.std(critic_q_means)) + + # --- Ensemble disagreement on replay batch (s,a) --- + q_mat = torch.cat(q_set, dim=1) # (B,E) + info["current_ensemble_q_mean"] = q_mat.mean().item() + info["current_ensemble_q_std_mean"] = ( + q_mat.std(dim=1, unbiased=False).mean().item() + ) + # If this grows: see critic divergence / epistemic spread. + + # TD tail risk (more sensitive than mean) + td_mat = q_mat - q_target # broadcast (B,E) - (B,1) + td_abs_max = td_mat.abs().max(dim=1).values # (B,) + info["td_abs_max_mean"] = td_abs_max.mean().item() + info["td_abs_max_p95"] = td_abs_max.quantile(0.95).item() + info["td_abs_max_max"] = td_abs_max.max().item() return info @@ -177,13 +249,40 @@ def _update_actor_alpha( # type: ignore[override] self, states: torch.Tensor, ) -> dict[str, Any]: + info: dict[str, Any] = {} + pi, log_pi, _ = self.actor_net(states) with hlp.evaluating(self.critic_net): q_values = self.critic_net(states, pi) - min_qf_pi = q_values.mean(dim=1) + q_mean = q_values.mean(dim=1) + + actor_loss = ((self.alpha * log_pi) - q_mean).mean() + + # --------------------------------------------------------- + # Stochastic Policy Gradient Strength (∇a [α log π(a|s) − Q(s,a)]) + # --------------------------------------------------------- + # Measures how steep the entropy-regularized critic objective is + # w.r.t. the sampled policy actions. + # + # ~0 early -> critic surface and entropy term nearly flat; + # actor receives weak learning signal. + # + # Very large -> critic or entropy term is very sharp around policy + # actions; can lead to unstable or overly aggressive + # actor updates. + dq_da = torch.autograd.grad( + outputs=actor_loss, + inputs=pi, + retain_graph=True, + create_graph=False, + allow_unused=False, + )[0] - actor_loss = ((self.alpha * log_pi) - min_qf_pi).mean() + with torch.no_grad(): + info["dq_da_abs_mean"] = dq_da.abs().mean().item() + info["dq_da_norm_mean"] = dq_da.norm(dim=1).mean().item() + info["dq_da_norm_p95"] = dq_da.norm(dim=1).quantile(0.95).item() self.actor_net_optimiser.zero_grad() actor_loss.backward() @@ -195,10 +294,44 @@ def _update_actor_alpha( # type: ignore[override] alpha_loss.backward() self.log_alpha_optimizer.step() - info = { - "actor_loss": actor_loss.item(), - "alpha_loss": alpha_loss.item(), - } + with torch.no_grad(): + # --- Policy entropy diagnostics (exploration health) --- + # log_pi more negative -> higher entropy (more stochastic). Less negative -> lower entropy (more deterministic). + info["log_pi_mean"] = log_pi.mean().item() + info["log_pi_std"] = log_pi.std().item() + + # --- Action magnitude/saturation (tanh policies) --- + # High saturation fraction can indicate the policy is slamming bounds; may reduce effective gradients. 
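# Illustrative sketch (not part of this patch): the ensemble-spread and TD tail-risk
# metrics added in _update_critic above, reduced to a standalone example with random
# stand-in tensors. q_mat stacks per-critic Q(s, a) estimates; q_target is the shared
# Bellman target.
import torch

batch_size, ensemble_size = 8, 10
q_mat = torch.randn(batch_size, ensemble_size)            # per-critic Q estimates on one batch
q_target = torch.randn(batch_size, 1)                     # shared bootstrapped target

ensemble_std = q_mat.std(dim=1, unbiased=False)           # epistemic spread per sample
td_abs_max = (q_mat - q_target).abs().max(dim=1).values   # worst-critic TD error per sample

print(ensemble_std.mean().item())                         # ~ current_ensemble_q_std_mean
print(td_abs_max.quantile(0.95).item())                   # ~ td_abs_max_p95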
+ info["pi_action_abs_mean"] = pi.abs().mean().item() + info["pi_action_std"] = pi.std().item() + info["pi_action_saturation_frac"] = (pi.abs() > 0.95).float().mean().item() + + # --- On-policy critic signal (REDQ uses ensemble mean) --- + # REDQ actor uses mean over ensemble as value signal + info["q_pi_mean"] = q_mean.mean().item() + info["q_pi_std"] = q_mean.std(unbiased=False).item() + + # --- Ensemble disagreement at policy actions (REDQ analogue of twin-gap) --- + # If this grows, critics disagree on current policy behaviour (instability / epistemic spread). + q_std_across_critics = q_values.std(dim=1, unbiased=False) # (B,) + info["q_pi_ensemble_std_mean"] = q_std_across_critics.mean().item() + info["q_pi_ensemble_std_p95"] = q_std_across_critics.quantile(0.95).item() + + # You can also track dominance extremes if you ever use weighted fusion later + info["q_pi_ensemble_min_mean"] = q_values.min(dim=1).values.mean().item() + info["q_pi_ensemble_max_mean"] = q_values.max(dim=1).values.mean().item() + + # --- Entropy gap (alpha tuning health) --- + # entropy_gap ~ 0 means entropy matches target. + # > 0: entropy too low -> alpha should increase; < 0: entropy too high -> alpha should decrease. + entropy_gap = -(log_pi + self.target_entropy) + info["entropy_gap_mean"] = entropy_gap.mean().item() + + # --- Losses and temperature --- + info["actor_loss"] = actor_loss.item() + info["alpha_loss"] = alpha_loss.item() + info["alpha"] = self.alpha.item() + info["log_alpha"] = self.log_alpha.item() return info @@ -243,7 +376,6 @@ def train( observation_tensor.vector_state_tensor ) info |= actor_info - info["alpha"] = self.alpha.item() if self.learn_counter % self.target_update_freq == 0: # Update ensemble of target critics diff --git a/cares_reinforcement_learning/algorithm/policy/SAC.py b/cares_reinforcement_learning/algorithm/policy/SAC.py index 520e59b6..49c25357 100644 --- a/cares_reinforcement_learning/algorithm/policy/SAC.py +++ b/cares_reinforcement_learning/algorithm/policy/SAC.py @@ -180,6 +180,7 @@ def _update_critic( dones: torch.Tensor, weights: torch.Tensor, ) -> tuple[dict[str, Any], np.ndarray]: + info: dict[str, Any] = {} with torch.no_grad(): with hlp.evaluating(self.actor_net): @@ -224,11 +225,63 @@ def _update_critic( .flatten() ) - info = { - "critic_loss_one": critic_loss_one.item(), - "critic_loss_two": critic_loss_two.item(), - "critic_loss_total": critic_loss_total.item(), - } + with torch.no_grad(): + # --- Twin critic disagreement (stability/uncertainty) --- + # If this grows over training, critics are diverging / becoming inconsistent. + info["q1_mean"] = q_values_one.mean().item() + info["q2_mean"] = q_values_two.mean().item() + info["q_twin_gap_abs_mean"] = ( + (q_values_one - q_values_two).abs().mean().item() + ) + + # --- Target critics disagreement (target stability) --- + # Large/unstable gap here often means target critics are drifting or policy is visiting OOD actions. 
+ info["target_q1_mean"] = target_q_values_one.mean().item() + info["target_q2_mean"] = target_q_values_two.mean().item() + info["target_q_twin_gap_abs_mean"] = ( + (target_q_values_one - target_q_values_two).abs().mean().item() + ) + + # --- Soft target decomposition (SAC-specific) --- + # min_target_q_mean: the conservative bootstrap value from twin critics (pre-entropy) + # entropy_term_mean: magnitude of entropy regularization in the target (alpha * log_pi is usually negative) + # soft_target_value_mean: the exact term used inside the Bellman target before reward/discount + min_target_q = torch.minimum(target_q_values_one, target_q_values_two) + + # alpha_log_pi is typically negative; entropy_bonus is typically positive + alpha_log_pi = self.alpha * next_log_pi + # this is what gets ADDED to minQ in the target + entropy_bonus = -self.alpha * next_log_pi + + soft_target_value = min_target_q + entropy_bonus # == minQ - alpha*log_pi + + info["target_min_q_mean"] = min_target_q.mean().item() + info["alpha_log_pi_mean"] = alpha_log_pi.mean().item() + info["entropy_bonus_mean"] = entropy_bonus.mean().item() + info["soft_target_value_mean"] = soft_target_value.mean().item() + + # --- Bellman target scale (reward scaling / discount sanity) --- + # If q_target drifts upward without reward improvement, suspect reward_scale, gamma, or instability. + info["q_target_mean"] = q_target.mean().item() + info["q_target_std"] = q_target.std().item() + + # --- TD error diagnostics (Bellman fit quality) --- + # td_abs_mean down over time is healthy; persistent growth/spikes often indicate critic instability. + td1 = q_values_one - q_target # signed + td2 = q_values_two - q_target # signed + + info["td1_mean"] = td1.mean().item() + info["td1_std"] = td1.std().item() + info["td1_abs_mean"] = td1.abs().mean().item() + + info["td2_mean"] = td2.mean().item() + info["td2_std"] = td2.std().item() + info["td2_abs_mean"] = td2.abs().mean().item() + + # --- Losses (optimization progress; less diagnostic than TD/twin gaps) --- + info["critic_loss_one"] = critic_loss_one.item() + info["critic_loss_two"] = critic_loss_two.item() + info["critic_loss_total"] = critic_loss_total.item() return info, priorities @@ -238,15 +291,42 @@ def _update_actor_alpha( states: torch.Tensor, weights: torch.Tensor, # pylint: disable=unused-argument ) -> dict[str, Any]: + info: dict[str, Any] = {} + pi, log_pi, _ = self.actor_net(states) with hlp.evaluating(self.critic_net): - qf1_pi, qf2_pi = self.critic_net(states, pi) + qf_pi_one, qf_pi_two = self.critic_net(states, pi) - min_qf_pi = torch.minimum(qf1_pi, qf2_pi) + min_qf_pi = torch.minimum(qf_pi_one, qf_pi_two) actor_loss = ((self.alpha * log_pi) - min_qf_pi).mean() + # --------------------------------------------------------- + # Stochastic Policy Gradient Strength (∇a [α log π(a|s) − Q(s,a)]) + # --------------------------------------------------------- + # Measures how steep the entropy-regularized critic objective is + # w.r.t. the sampled policy actions. + # + # ~0 early -> critic surface and entropy term nearly flat; + # actor receives weak learning signal. + # + # Very large -> critic or entropy term is very sharp around policy + # actions; can lead to unstable or overly aggressive + # actor updates. 
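# Illustrative note (not part of this patch): because actor_loss below is mean-reduced over
# the batch, torch.autograd.grad(actor_loss, pi) returns per-sample gradients scaled by
# 1/batch_size, so the dq_da_* magnitudes are only comparable between runs that use the
# same batch size. A minimal check with a toy quadratic objective:
import torch

batch_size, action_dim = 4, 2
actions = torch.randn(batch_size, action_dim, requires_grad=True)
toy_loss = (actions ** 2).sum(dim=1).mean()               # stand-in for the mean-reduced actor loss

(dq_da,) = torch.autograd.grad(toy_loss, actions)
assert torch.allclose(dq_da, 2.0 * actions / batch_size)  # per-sample gradient divided by B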
+ dq_da = torch.autograd.grad( + outputs=actor_loss, + inputs=pi, + retain_graph=True, + create_graph=False, + allow_unused=False, + )[0] + + with torch.no_grad(): + info["dq_da_abs_mean"] = dq_da.abs().mean().item() + info["dq_da_norm_mean"] = dq_da.norm(dim=1).mean().item() + info["dq_da_norm_p95"] = dq_da.norm(dim=1).quantile(0.95).item() + self.actor_net_optimiser.zero_grad() actor_loss.backward() self.actor_net_optimiser.step() @@ -258,11 +338,37 @@ def _update_actor_alpha( alpha_loss.backward() self.log_alpha_optimizer.step() - info = { - "actor_loss": actor_loss.item(), - "alpha_loss": alpha_loss.item(), - "log_pi": log_pi.mean().item(), - } + with torch.no_grad(): + # --- Policy entropy diagnostics (exploration health) --- + # log_pi more negative -> higher entropy (more stochastic). Less negative -> lower entropy (more deterministic). + info["log_pi_mean"] = log_pi.mean().item() + info["log_pi_std"] = log_pi.std().item() + + # --- Action magnitude/saturation (tanh policies) --- + # High saturation fraction can indicate the policy is slamming bounds; may reduce effective gradients. + info["pi_action_abs_mean"] = pi.abs().mean().item() + info["pi_action_std"] = pi.std().item() + info["pi_action_saturation_frac"] = (pi.abs() > 0.95).float().mean().item() + + # --- On-policy critic signal --- + # min_qf_pi_mean should generally increase as the policy improves (higher value actions under the policy). + info["min_qf_pi_mean"] = min_qf_pi.mean().item() + + # --- Twin critics disagreement at policy actions (more relevant than replay actions) --- + # Large gap here means critics disagree on what the current policy is doing (can destabilize actor updates). + info["qf_pi_gap_abs_mean"] = (qf_pi_one - qf_pi_two).abs().mean().item() + + # --- Entropy gap (alpha tuning health) --- + # entropy_gap ~ 0 means entropy matches target. + # > 0: entropy too low -> alpha should increase; < 0: entropy too high -> alpha should decrease. + entropy_gap = -(log_pi + self.target_entropy) + info["entropy_gap_mean"] = entropy_gap.mean().item() + + # --- Losses and temperature --- + info["actor_loss"] = actor_loss.item() + info["alpha_loss"] = alpha_loss.item() + info["alpha"] = self.alpha.item() + info["log_alpha"] = self.log_alpha.item() return info @@ -296,7 +402,6 @@ def update_from_batch( observation_tensor.vector_state_tensor, weights_tensor ) info |= actor_info - info["alpha"] = self.alpha.item() if self.learn_counter % self.target_update_freq == 0: self.update_target_networks() diff --git a/cares_reinforcement_learning/algorithm/policy/SACAE.py b/cares_reinforcement_learning/algorithm/policy/SACAE.py index 18480d7a..7b9a6139 100644 --- a/cares_reinforcement_learning/algorithm/policy/SACAE.py +++ b/cares_reinforcement_learning/algorithm/policy/SACAE.py @@ -188,6 +188,7 @@ def _update_critic( dones: torch.Tensor, weights: torch.Tensor, ) -> tuple[dict[str, Any], np.ndarray]: + info: dict[str, Any] = {} with torch.no_grad(): with hlp.evaluating(self.actor_net): @@ -232,23 +233,102 @@ def _update_critic( .flatten() ) - info = { - "critic_loss_one": critic_loss_one.item(), - "critic_loss_two": critic_loss_two.item(), - "critic_loss_total": critic_loss_total.item(), - } + with torch.no_grad(): + # --- Twin critic disagreement (stability/uncertainty) --- + # If this grows over training, critics are diverging / becoming inconsistent. 
+ info["q1_mean"] = q_values_one.mean().item() + info["q2_mean"] = q_values_two.mean().item() + info["q_twin_gap_abs_mean"] = ( + (q_values_one - q_values_two).abs().mean().item() + ) + + # --- Target critics disagreement (target stability) --- + # Large/unstable gap here often means target critics are drifting or policy is visiting OOD actions. + info["target_q1_mean"] = target_q_values_one.mean().item() + info["target_q2_mean"] = target_q_values_two.mean().item() + info["target_q_twin_gap_abs_mean"] = ( + (target_q_values_one - target_q_values_two).abs().mean().item() + ) + + # --- Soft target decomposition (SAC-specific) --- + # min_target_q_mean: the conservative bootstrap value from twin critics (pre-entropy) + # entropy_term_mean: magnitude of entropy regularization in the target (alpha * log_pi is usually negative) + # soft_target_value_mean: the exact term used inside the Bellman target before reward/discount + min_target_q = torch.minimum(target_q_values_one, target_q_values_two) + + # alpha_log_pi is typically negative; entropy_bonus is typically positive + alpha_log_pi = self.alpha * next_log_pi + # this is what gets ADDED to minQ in the target + entropy_bonus = -self.alpha * next_log_pi + + soft_target_value = min_target_q + entropy_bonus # == minQ - alpha*log_pi + + info["target_min_q_mean"] = min_target_q.mean().item() + info["alpha_log_pi_mean"] = alpha_log_pi.mean().item() + info["entropy_bonus_mean"] = entropy_bonus.mean().item() + info["soft_target_value_mean"] = soft_target_value.mean().item() + + # --- Bellman target scale (reward scaling / discount sanity) --- + # If q_target drifts upward without reward improvement, suspect reward_scale, gamma, or instability. + info["q_target_mean"] = q_target.mean().item() + info["q_target_std"] = q_target.std().item() + + # --- TD error diagnostics (Bellman fit quality) --- + # td_abs_mean down over time is healthy; persistent growth/spikes often indicate critic instability. + td1 = q_values_one - q_target # signed + td2 = q_values_two - q_target # signed + + info["td1_mean"] = td1.mean().item() + info["td1_std"] = td1.std().item() + info["td1_abs_mean"] = td1.abs().mean().item() + + info["td2_mean"] = td2.mean().item() + info["td2_std"] = td2.std().item() + info["td2_abs_mean"] = td2.abs().mean().item() + + # --- Losses (optimization progress; less diagnostic than TD/twin gaps) --- + info["critic_loss_one"] = critic_loss_one.item() + info["critic_loss_two"] = critic_loss_two.item() + info["critic_loss_total"] = critic_loss_total.item() return info, priorities def _update_actor_alpha(self, states: SARLObservationTensors) -> dict[str, Any]: + info: dict[str, Any] = {} + pi, log_pi, _ = self.actor_net(states, detach_encoder=True) with hlp.evaluating(self.critic_net): - qf1_pi, qf2_pi = self.critic_net(states, pi, detach_encoder=True) + qf_pi_one, qf_pi_two = self.critic_net(states, pi, detach_encoder=True) - min_qf_pi = torch.minimum(qf1_pi, qf2_pi) + min_qf_pi = torch.minimum(qf_pi_one, qf_pi_two) actor_loss = ((self.alpha * log_pi) - min_qf_pi).mean() + # --------------------------------------------------------- + # Stochastic Policy Gradient Strength (∇a [α log π(a|s) − Q(s,a)]) + # --------------------------------------------------------- + # Measures how steep the entropy-regularized critic objective is + # w.r.t. the sampled policy actions. + # + # ~0 early -> critic surface and entropy term nearly flat; + # actor receives weak learning signal. 
+ # + # Very large -> critic or entropy term is very sharp around policy + # actions; can lead to unstable or overly aggressive + # actor updates. + dq_da = torch.autograd.grad( + outputs=actor_loss, + inputs=pi, + retain_graph=True, + create_graph=False, + allow_unused=False, + )[0] + + with torch.no_grad(): + info["dq_da_abs_mean"] = dq_da.abs().mean().item() + info["dq_da_norm_mean"] = dq_da.norm(dim=1).mean().item() + info["dq_da_norm_p95"] = dq_da.norm(dim=1).quantile(0.95).item() + self.actor_net_optimiser.zero_grad() actor_loss.backward() self.actor_net_optimiser.step() @@ -260,10 +340,37 @@ def _update_actor_alpha(self, states: SARLObservationTensors) -> dict[str, Any]: alpha_loss.backward() self.log_alpha_optimizer.step() - info = { - "actor_loss": actor_loss.item(), - "alpha_loss": alpha_loss.item(), - } + with torch.no_grad(): + # --- Policy entropy diagnostics (exploration health) --- + # log_pi more negative -> higher entropy (more stochastic). Less negative -> lower entropy (more deterministic). + info["log_pi_mean"] = log_pi.mean().item() + info["log_pi_std"] = log_pi.std().item() + + # --- Action magnitude/saturation (tanh policies) --- + # High saturation fraction can indicate the policy is slamming bounds; may reduce effective gradients. + info["pi_action_abs_mean"] = pi.abs().mean().item() + info["pi_action_std"] = pi.std().item() + info["pi_action_saturation_frac"] = (pi.abs() > 0.95).float().mean().item() + + # --- On-policy critic signal --- + # min_qf_pi_mean should generally increase as the policy improves (higher value actions under the policy). + info["min_qf_pi_mean"] = min_qf_pi.mean().item() + + # --- Twin critics disagreement at policy actions (more relevant than replay actions) --- + # Large gap here means critics disagree on what the current policy is doing (can destabilize actor updates). + info["qf_pi_gap_abs_mean"] = (qf_pi_one - qf_pi_two).abs().mean().item() + + # --- Entropy gap (alpha tuning health) --- + # entropy_gap ~ 0 means entropy matches target. + # > 0: entropy too low -> alpha should increase; < 0: entropy too high -> alpha should decrease. 
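# Illustrative sketch (not part of this patch): the standard SAC temperature update these
# diagnostics describe, using the commonly seen loss
#     alpha_loss = -(log_alpha * (log_pi + target_entropy).detach()).mean()
# (the exact form used in this codebase sits outside this hunk and may differ). Under this
# loss, average entropy (-log_pi) below target_entropy yields a negative gradient on
# log_alpha, so gradient descent raises alpha (more exploration pressure), while entropy
# above target lowers it.
import torch

target_entropy = -4.0                                  # e.g. -action_dim
log_alpha = torch.zeros(1, requires_grad=True)
log_pi = torch.full((256, 1), 5.0)                     # very deterministic policy: entropy -5 < target -4

alpha_loss = -(log_alpha * (log_pi + target_entropy).detach()).mean()
alpha_loss.backward()
print(log_alpha.grad)                                  # tensor([-1.]) -> descent increases log_alpha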
+ entropy_gap = -(log_pi + self.target_entropy) + info["entropy_gap_mean"] = entropy_gap.mean().item() + + # --- Losses and temperature --- + info["actor_loss"] = actor_loss.item() + info["alpha_loss"] = alpha_loss.item() + info["alpha"] = self.alpha.item() + info["log_alpha"] = self.log_alpha.item() return info @@ -333,7 +440,6 @@ def train( if self.learn_counter % self.policy_update_freq == 0: actor_info = self._update_actor_alpha(observation_tensor) info |= actor_info - info["alpha"] = self.alpha.item() if self.learn_counter % self.target_update_freq == 0: # Update the target networks - Soft Update diff --git a/cares_reinforcement_learning/algorithm/policy/SACD.py b/cares_reinforcement_learning/algorithm/policy/SACD.py index 027db894..594b2831 100644 --- a/cares_reinforcement_learning/algorithm/policy/SACD.py +++ b/cares_reinforcement_learning/algorithm/policy/SACD.py @@ -163,23 +163,23 @@ def _update_critic( rewards: torch.Tensor, next_states: torch.Tensor, dones: torch.Tensor, - ) -> float: + ) -> dict[str, Any]: + info: dict[str, Any] = {} + with torch.no_grad(): with hlp.evaluating(self.actor_net): _, (action_probs, log_actions_probs), _ = self.actor_net(next_states) - qf1_next_target, qf2_next_target = self.target_critic_net(next_states) + qf1_next, qf2_next = self.target_critic_net(next_states) + min_q_next = torch.minimum(qf1_next, qf2_next) - min_qf_next_target = action_probs * ( - torch.minimum(qf1_next_target, qf2_next_target) - - self.alpha * log_actions_probs - ) + # Soft value: expectation over discrete actions + soft_value = ( + action_probs * (min_q_next - self.alpha * log_actions_probs) + ).sum(dim=1, keepdim=True) - min_qf_next_target = min_qf_next_target.sum(dim=1).unsqueeze(-1) - # TODO: Investigate next_q_value = ( - rewards * self.reward_scale - + (1.0 - dones) * min_qf_next_target * self.gamma + rewards * self.reward_scale + (1.0 - dones) * self.gamma * soft_value ) q_values_one, q_values_two = self.critic_net(states) @@ -195,9 +195,42 @@ def _update_critic( critic_loss_total.backward() self.critic_net_optimiser.step() - return critic_loss_total.item() + with torch.no_grad(): + # --- Target decomposition --- + info["target_min_q_mean"] = min_q_next.mean().item() + info["entropy_bonus_mean"] = (-self.alpha * log_actions_probs).mean().item() + info["soft_value_mean"] = soft_value.mean().item() + + # --- Bellman target scale --- + info["q_target_mean"] = next_q_value.mean().item() + info["q_target_std"] = next_q_value.std(unbiased=False).item() + + # --- Critic value scale --- + info["q1_mean"] = q_values_one.mean().item() + info["q2_mean"] = q_values_two.mean().item() + info["q_twin_gap_abs_mean"] = ( + (q_values_one - q_values_two).abs().mean().item() + ) + + # --- TD error diagnostics --- + td1 = gathered_q_values_one - next_q_value + td2 = gathered_q_values_two - next_q_value + + td_abs = torch.maximum(td1.abs(), td2.abs()).squeeze(1) + info["td_abs_mean"] = td_abs.mean().item() + info["td_abs_p95"] = td_abs.quantile(0.95).item() + info["td_abs_max"] = td_abs.max().item() + + # --- Loss --- + info["critic_loss_one"] = critic_loss_one.item() + info["critic_loss_two"] = critic_loss_two.item() + info["critic_loss_total"] = critic_loss_total.item() + + return info + + def _update_actor_alpha(self, states: torch.Tensor) -> dict[str, Any]: + info: dict[str, Any] = {} - def _update_actor_alpha(self, states: torch.Tensor) -> tuple[float, float]: _, (action_probs, log_action_probs), _ = self.actor_net(states) with hlp.evaluating(self.critic_net): @@ -208,7 +241,7 @@ def 
_update_actor_alpha(self, states: torch.Tensor) -> tuple[float, float]: inside_term = self.alpha * log_action_probs - min_qf_pi actor_loss = (action_probs * inside_term).sum(dim=1).mean() - new_log_action_probs = torch.sum(log_action_probs * action_probs, dim=1) + expected_log_prob = torch.sum(log_action_probs * action_probs, dim=1) self.actor_net_optimiser.zero_grad() actor_loss.backward() @@ -216,14 +249,42 @@ def _update_actor_alpha(self, states: torch.Tensor) -> tuple[float, float]: # update the temperature (alpha) alpha_loss = -( - self.log_alpha * (new_log_action_probs + self.target_entropy).detach() + self.log_alpha * (expected_log_prob + self.target_entropy).detach() ).mean() self.log_alpha_optimizer.zero_grad() alpha_loss.backward() self.log_alpha_optimizer.step() - return actor_loss.item(), alpha_loss.item() + with torch.no_grad(): + # --- Policy distribution health --- + entropy = -(action_probs * log_action_probs).sum(dim=1) + + info["entropy_mean"] = entropy.mean().item() + info["entropy_std"] = entropy.std(unbiased=False).item() + + # Action distribution sharpness + max_prob = action_probs.max(dim=1).values + info["max_action_prob_mean"] = max_prob.mean().item() + info["max_action_prob_p95"] = max_prob.quantile(0.95).item() + + info["policy_prob_std_mean"] = action_probs.std(dim=1).mean().item() + + # --- Q signal to actor --- + info["min_q_pi_mean"] = min_qf_pi.mean().item() + info["min_q_pi_std"] = min_qf_pi.std(unbiased=False).item() + + # --- Entropy calibration --- + entropy_gap = -(expected_log_prob + self.target_entropy) + info["entropy_gap_mean"] = entropy_gap.mean().item() + + # --- Losses & temperature --- + info["actor_loss"] = actor_loss.item() + info["alpha_loss"] = alpha_loss.item() + info["alpha"] = self.alpha.item() + info["log_alpha"] = self.log_alpha.item() + + return info def train( self, @@ -251,23 +312,21 @@ def train( info = {} # Update the Critic - critic_loss_total = self._update_critic( + critic_info = self._update_critic( observation_tensor.vector_state_tensor, actions_tensor, rewards_tensor, next_observation_tensor.vector_state_tensor, dones_tensor, ) - info["critic_loss"] = critic_loss_total + info.update(critic_info) if self.learn_counter % self.policy_update_freq == 0: # Update the Actor and Alpha - actor_loss, alpha_loss = self._update_actor_alpha( + actor_info = self._update_actor_alpha( observation_tensor.vector_state_tensor ) - info["actor_loss"] = actor_loss - info["alpha_loss"] = alpha_loss - info["alpha"] = self.alpha.item() + info.update(actor_info) if self.learn_counter % self.target_update_freq == 0: hlp.soft_update_params(self.critic_net, self.target_critic_net, self.tau) diff --git a/cares_reinforcement_learning/algorithm/policy/SDAR.py b/cares_reinforcement_learning/algorithm/policy/SDAR.py index 87766191..70b79253 100644 --- a/cares_reinforcement_learning/algorithm/policy/SDAR.py +++ b/cares_reinforcement_learning/algorithm/policy/SDAR.py @@ -201,7 +201,7 @@ def _update_critic( # type: ignore[override] dones: torch.Tensor, weights: torch.Tensor, ) -> tuple[dict[str, Any], np.ndarray]: - + info: dict[str, Any] = {} with torch.no_grad(): with hlp.evaluating(self.actor_net): next_actions, next_log_pi, *_ = self.actor_net( @@ -247,11 +247,63 @@ def _update_critic( # type: ignore[override] .flatten() ) - info = { - "critic_loss_one": critic_loss_one.item(), - "critic_loss_two": critic_loss_two.item(), - "critic_loss_total": critic_loss_total.item(), - } + with torch.no_grad(): + # --- Twin critic disagreement (stability/uncertainty) --- 
+ # If this grows over training, critics are diverging / becoming inconsistent. + info["q1_mean"] = q_values_one.mean().item() + info["q2_mean"] = q_values_two.mean().item() + info["q_twin_gap_abs_mean"] = ( + (q_values_one - q_values_two).abs().mean().item() + ) + + # --- Target critics disagreement (target stability) --- + # Large/unstable gap here often means target critics are drifting or policy is visiting OOD actions. + info["target_q1_mean"] = target_q_values_one.mean().item() + info["target_q2_mean"] = target_q_values_two.mean().item() + info["target_q_twin_gap_abs_mean"] = ( + (target_q_values_one - target_q_values_two).abs().mean().item() + ) + + # --- Soft target decomposition (SAC-specific) --- + # min_target_q_mean: the conservative bootstrap value from twin critics (pre-entropy) + # entropy_term_mean: magnitude of entropy regularization in the target (alpha * log_pi is usually negative) + # soft_target_value_mean: the exact term used inside the Bellman target before reward/discount + min_target_q = torch.minimum(target_q_values_one, target_q_values_two) + + # alpha_log_pi is typically negative; entropy_bonus is typically positive + alpha_log_pi = self.alpha * next_log_pi + # this is what gets ADDED to minQ in the target + entropy_bonus = -self.alpha * next_log_pi + + soft_target_value = min_target_q + entropy_bonus # == minQ - alpha*log_pi + + info["target_min_q_mean"] = min_target_q.mean().item() + info["alpha_log_pi_mean"] = alpha_log_pi.mean().item() + info["entropy_bonus_mean"] = entropy_bonus.mean().item() + info["soft_target_value_mean"] = soft_target_value.mean().item() + + # --- Bellman target scale (reward scaling / discount sanity) --- + # If q_target drifts upward without reward improvement, suspect reward_scale, gamma, or instability. + info["q_target_mean"] = q_target.mean().item() + info["q_target_std"] = q_target.std().item() + + # --- TD error diagnostics (Bellman fit quality) --- + # td_abs_mean down over time is healthy; persistent growth/spikes often indicate critic instability. 
+ td1 = q_values_one - q_target # signed + td2 = q_values_two - q_target # signed + + info["td1_mean"] = td1.mean().item() + info["td1_std"] = td1.std().item() + info["td1_abs_mean"] = td1.abs().mean().item() + + info["td2_mean"] = td2.mean().item() + info["td2_std"] = td2.std().item() + info["td2_abs_mean"] = td2.abs().mean().item() + + # --- Losses (optimization progress; less diagnostic than TD/twin gaps) --- + info["critic_loss_one"] = critic_loss_one.item() + info["critic_loss_two"] = critic_loss_two.item() + info["critic_loss_total"] = critic_loss_total.item() return info, priorities @@ -262,8 +314,10 @@ def _update_actor_alpha( # type: ignore[override] prev_actions: torch.Tensor, weights: torch.Tensor, # pylint: disable=unused-argument ) -> dict[str, Any]: + info: dict[str, Any] = {} + ( - sample_action, + pi, log_pi, _, act_probs, @@ -272,12 +326,37 @@ def _update_actor_alpha( # type: ignore[override] ) = self.actor_net(states, prev_actions, force_act=False) with hlp.evaluating(self.critic_net): - qf1_pi, qf2_pi = self.critic_net(states, sample_action) + qf_pi_one, qf_pi_two = self.critic_net(states, pi) - min_qf_pi = torch.minimum(qf1_pi, qf2_pi) + min_qf_pi = torch.minimum(qf_pi_one, qf_pi_two) actor_loss = ((self.alpha * log_pi) + (self.beta * log_beta) - min_qf_pi).mean() + # --------------------------------------------------------- + # Stochastic Policy Gradient Strength (∇a [α log π(a|s) − Q(s,a)]) + # --------------------------------------------------------- + # Measures how steep the entropy-regularized critic objective is + # w.r.t. the sampled policy actions. + # + # ~0 early -> critic surface and entropy term nearly flat; + # actor receives weak learning signal. + # + # Very large -> critic or entropy term is very sharp around policy + # actions; can lead to unstable or overly aggressive + # actor updates. + dq_da = torch.autograd.grad( + outputs=actor_loss, + inputs=pi, + retain_graph=True, + create_graph=False, + allow_unused=False, + )[0] + + with torch.no_grad(): + info["dq_da_abs_mean"] = dq_da.abs().mean().item() + info["dq_da_norm_mean"] = dq_da.norm(dim=1).mean().item() + info["dq_da_norm_p95"] = dq_da.norm(dim=1).quantile(0.95).item() + self.actor_net_optimiser.zero_grad() actor_loss.backward() self.actor_net_optimiser.step() @@ -296,15 +375,47 @@ def _update_actor_alpha( # type: ignore[override] beta_loss.backward() self.log_beta_optimizer.step() - info = { - "actor_loss": actor_loss.item(), - "alpha_loss": alpha_loss.item(), - "beta_loss": beta_loss.item(), - "log_pi": log_pi.mean().item(), - "log_beta": log_beta.mean().item(), - "act_prob_mean": act_probs.mean().item(), - "b_mean": binary_mask.mean().item(), - } + with torch.no_grad(): + # --- SDAR specific diagnostics --- + # act_probs: the Bernoulli probabilities for selecting new actions vs repeating old ones. + # binary_mask: the actual sampled 0/1 mask for repetition vs new action. + # log_beta: the log of the temperature for the β regularization term; should adapt to balance repetition vs new action selection. + info["act_prob_mean"] = act_probs.mean().item() + info["log_beta_mean"] = log_beta.mean().item() + info["binary_mask_mean"] = binary_mask.mean().item() + info["beta"] = self.beta.item() + info["log_beta"] = log_beta.mean().item() + + # --- Policy entropy diagnostics (exploration health) --- + # log_pi more negative -> higher entropy (more stochastic). Less negative -> lower entropy (more deterministic). 
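# Illustrative sketch (not part of this patch): why mean(-log_pi) is read as an entropy
# estimate in the diagnostics below. For any distribution, E[-log pi(a)] equals its
# entropy, so a more negative log_pi_mean indicates a more stochastic policy (for the
# squashed policies used here, log_pi already includes the tanh correction).
import torch
from torch.distributions import Normal

dist = Normal(loc=torch.zeros(4), scale=torch.full((4,), 0.5))
samples = dist.sample((100_000,))
mc_entropy = -dist.log_prob(samples).sum(dim=-1).mean()   # Monte-Carlo estimate of E[-log pi]
print(mc_entropy.item(), dist.entropy().sum().item())     # the two should roughly agree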
+ info["log_pi_mean"] = log_pi.mean().item() + info["log_pi_std"] = log_pi.std().item() + + # --- Action magnitude/saturation (tanh policies) --- + # High saturation fraction can indicate the policy is slamming bounds; may reduce effective gradients. + info["pi_action_abs_mean"] = pi.abs().mean().item() + info["pi_action_std"] = pi.std().item() + info["pi_action_saturation_frac"] = (pi.abs() > 0.95).float().mean().item() + + # --- On-policy critic signal --- + # min_qf_pi_mean should generally increase as the policy improves (higher value actions under the policy). + info["min_qf_pi_mean"] = min_qf_pi.mean().item() + + # --- Twin critics disagreement at policy actions (more relevant than replay actions) --- + # Large gap here means critics disagree on what the current policy is doing (can destabilize actor updates). + info["qf_pi_gap_abs_mean"] = (qf_pi_one - qf_pi_two).abs().mean().item() + + # --- Entropy gap (alpha tuning health) --- + # entropy_gap ~ 0 means entropy matches target. + # > 0: entropy too low -> alpha should increase; < 0: entropy too high -> alpha should decrease. + entropy_gap = -(log_pi + self.target_entropy) + info["entropy_gap_mean"] = entropy_gap.mean().item() + + # --- Losses and temperature --- + info["actor_loss"] = actor_loss.item() + info["alpha_loss"] = alpha_loss.item() + info["alpha"] = self.alpha.item() + info["log_alpha"] = self.log_alpha.item() return info @@ -361,8 +472,6 @@ def train( weights_tensor, ) info |= actor_info - info["alpha"] = self.alpha.item() - info["beta"] = self.beta.item() if self.learn_counter % self.target_update_freq == 0: hlp.soft_update_params(self.critic_net, self.target_critic_net, self.tau) diff --git a/cares_reinforcement_learning/algorithm/policy/TD3.py b/cares_reinforcement_learning/algorithm/policy/TD3.py index 1fce58a5..698cf55a 100644 --- a/cares_reinforcement_learning/algorithm/policy/TD3.py +++ b/cares_reinforcement_learning/algorithm/policy/TD3.py @@ -70,8 +70,8 @@ SARLObservation, SARLObservationTensors, ) - from cares_reinforcement_learning.util.configurations import TD3Config +from cares_reinforcement_learning.util.helpers import ExponentialScheduler class TD3(SARLAlgorithm[np.ndarray]): @@ -103,16 +103,21 @@ def __init__( self.min_priority = config.min_priority # Policy noise - self.min_policy_noise = config.min_policy_noise - self.policy_noise = config.policy_noise - self.policy_noise_decay = config.policy_noise_decay - self.policy_noise_clip = config.policy_noise_clip + self.policy_noise_scheduler = ExponentialScheduler( + start_value=config.policy_noise_start, + end_value=config.policy_noise_end, + decay_steps=config.policy_noise_decay, + ) + self.policy_noise = self.policy_noise_scheduler.get_value(0) # Action noise - self.min_action_noise = config.min_action_noise - self.action_noise = config.action_noise - self.action_noise_decay = config.action_noise_decay + self.action_noise_scheduler = ExponentialScheduler( + start_value=config.action_noise_start, + end_value=config.action_noise_end, + decay_steps=config.action_noise_decay, + ) + self.action_noise = self.action_noise_scheduler.get_value(0) self.learn_counter = 0 self.policy_update_freq = config.policy_update_freq @@ -175,6 +180,7 @@ def _update_critic( dones: torch.Tensor, weights: torch.Tensor, ) -> tuple[dict[str, Any], np.ndarray]: + info: dict[str, Any] = {} with torch.no_grad(): with hlp.evaluating(self.actor_net): next_actions = self.target_actor_net(next_states) @@ -222,11 +228,58 @@ def _update_critic( .flatten() ) - info = { - "critic_loss_one": 
critic_loss_one.item(), - "critic_loss_two": critic_loss_two.item(), - "critic_loss_total": critic_loss_total.item(), - } + with torch.no_grad(): + # --- TD3-style smoothing diagnostics --- + # Noise diagnostics + # What it tells you: + # - target_noise_abs_mean: effective smoothing magnitude. + # - target_noise_clip_frac high early: noise often clipped (clip too small or noise too large). + target_noise_abs_mean = target_noise.abs().mean().item() + target_noise_clip_frac = ( + (target_noise.abs() >= self.policy_noise_clip).float().mean().item() + ) + info["target_noise_abs_mean"] = float(target_noise_abs_mean) + info["target_noise_clip_frac"] = float(target_noise_clip_frac) + + # --- Twin critic disagreement (stability/uncertainty) --- + # If this grows over training, critics are diverging / becoming inconsistent. + info["q1_mean"] = q_values_one.mean().item() + info["q2_mean"] = q_values_two.mean().item() + info["q_twin_gap_abs_mean"] = ( + (q_values_one - q_values_two).abs().mean().item() + ) + + # --- Target critics disagreement (target stability) --- + # Large/unstable gap here often means target critics are drifting or policy is visiting OOD actions. + info["target_q1_mean"] = target_q_values_one.mean().item() + info["target_q2_mean"] = target_q_values_two.mean().item() + info["target_q_twin_gap_abs_mean"] = ( + (target_q_values_one - target_q_values_two).abs().mean().item() + ) + + # --- Bellman target scale (reward scaling / discount sanity) --- + # If q_target drifts upward without reward improvement, suspect reward_scale, gamma, or instability. + info["q_target_mean"] = q_target.mean().item() + info["q_target_std"] = q_target.std().item() + + # --- TD error diagnostics (Bellman fit quality) --- + # td_abs_mean down over time is healthy; persistent growth/spikes often indicate critic instability. + td1 = q_values_one - q_target # signed + td2 = q_values_two - q_target # signed + + info["td1_mean"] = td1.mean().item() + info["td1_std"] = td1.std().item() + info["td1_abs_mean"] = td1.abs().mean().item() + + info["td2_mean"] = td2.mean().item() + info["td2_std"] = td2.std().item() + info["td2_abs_mean"] = td2.abs().mean().item() + + # --- Losses (optimization progress; less diagnostic than TD/twin gaps) --- + info["critic_loss_one"] = critic_loss_one.item() + info["critic_loss_two"] = critic_loss_two.item() + info["critic_loss_total"] = critic_loss_total.item() + return info, priorities # Weights is set for methods like MAPERTD3 that use weights in the actor update @@ -235,6 +288,8 @@ def _update_actor( states: torch.Tensor, weights: torch.Tensor, # pylint: disable=unused-argument ) -> dict[str, Any]: + info: dict[str, Any] = {} + actions = self.actor_net(states) with hlp.evaluating(self.critic_net): @@ -242,15 +297,49 @@ def _update_actor( actor_loss = -actor_q_values.mean() + # --------------------------------------------------------- + # Deterministic Policy Gradient Strength (∇a Q(s,a)) + # --------------------------------------------------------- + # Measures how steep the critic surface is w.r.t. actions. + # ~0 early -> critic flat, actor receives no learning signal. + # Very large -> critic overly sharp, can cause unstable actor updates. 
+ dq_da = torch.autograd.grad( + outputs=actor_loss, + inputs=actions, + retain_graph=True, # because we do backward(actor_loss) next + create_graph=False, # diagnostic only + allow_unused=False, + )[0] + with torch.no_grad(): + # - ~0 early: critic surface flat around actor actions (weak learning signal) + # - very large: critic surface sharp -> unstable / exploitative actor updates + info["dq_da_abs_mean"] = dq_da.abs().mean().item() + info["dq_da_norm_mean"] = dq_da.norm(dim=1).mean().item() + info["dq_da_norm_p95"] = dq_da.norm(dim=1).quantile(0.95).item() + self.actor_net_optimiser.zero_grad() actor_loss.backward() self.actor_net_optimiser.step() - actor_info = { - "actor_loss": actor_loss.item(), - } + with torch.no_grad(): + # Policy Action Health (tanh policies in [-1, 1]) + # pi_action_saturation_frac: + # High values (>0.8 early) often mean the actor is slamming bounds, + # reducing effective gradient flow through tanh. + info["pi_action_mean"] = actions.mean().item() + info["pi_action_std"] = actions.std().item() + info["pi_action_abs_mean"] = actions.abs().mean().item() + info["pi_action_saturation_frac"] = ( + (actions.abs() > 0.95).float().mean().item() + ) - return actor_info + # actor_q_mean should generally increase over training. + # actor_q_std large + unstable may indicate critic inconsistency. + info["actor_loss"] = actor_loss.item() + info["actor_q_mean"] = actor_q_values.mean().item() + info["actor_q_std"] = actor_q_values.std().item() + + return info def update_from_batch( self, @@ -262,17 +351,20 @@ def update_from_batch( dones_tensor: torch.Tensor, weights_tensor: torch.Tensor, ) -> tuple[dict[str, Any], np.ndarray]: + info: dict[str, Any] = {} + self.learn_counter += 1 - # TODO replace with training_step based approach to avoid having to save this value - self.policy_noise *= self.policy_noise_decay - self.policy_noise = max(self.min_policy_noise, self.policy_noise) + self.policy_noise = self.policy_noise_scheduler.get_value( + episode_context.training_step + ) - # TODO replace with training_step based approach to avoid having to save this value - self.action_noise *= self.action_noise_decay - self.action_noise = max(self.min_action_noise, self.action_noise) + self.action_noise = self.action_noise_scheduler.get_value( + episode_context.training_step + ) - info: dict[str, Any] = {} + info["policy_noise"] = float(self.policy_noise) + info["action_noise"] = float(self.action_noise) # Update the Critic critic_info, priorities = self._update_critic( diff --git a/cares_reinforcement_learning/algorithm/policy/TD3AE.py b/cares_reinforcement_learning/algorithm/policy/TD3AE.py index b695f3f6..c4921cf1 100644 --- a/cares_reinforcement_learning/algorithm/policy/TD3AE.py +++ b/cares_reinforcement_learning/algorithm/policy/TD3AE.py @@ -79,6 +79,7 @@ SARLObservationTensors, ) from cares_reinforcement_learning.util.configurations import TD3AEConfig +from cares_reinforcement_learning.util.helpers import ExponentialScheduler class TD3AE(SARLAlgorithm[np.ndarray]): @@ -120,16 +121,21 @@ def __init__( self.min_priority = config.min_priority # Policy noise - self.min_policy_noise = config.min_policy_noise - self.policy_noise = config.policy_noise - self.policy_noise_decay = config.policy_noise_decay - self.policy_noise_clip = config.policy_noise_clip + self.policy_noise_scheduler = ExponentialScheduler( + start_value=config.policy_noise_start, + end_value=config.policy_noise_end, + decay_steps=config.policy_noise_decay, + ) + self.policy_noise = 
self.policy_noise_scheduler.get_value(0) # Action noise - self.min_action_noise = config.min_action_noise - self.action_noise = config.action_noise - self.action_noise_decay = config.action_noise_decay + self.action_noise_scheduler = ExponentialScheduler( + start_value=config.action_noise_start, + end_value=config.action_noise_end, + decay_steps=config.action_noise_decay, + ) + self.action_noise = self.action_noise_scheduler.get_value(0) self.learn_counter = 0 self.policy_update_freq = config.policy_update_freq @@ -188,6 +194,8 @@ def _update_critic( dones: torch.Tensor, weights: torch.Tensor, ) -> tuple[dict[str, Any], np.ndarray]: + info: dict[str, Any] = {} + with torch.no_grad(): next_actions = self.target_actor_net(next_states) @@ -235,15 +243,63 @@ def _update_critic( .flatten() ) - info = { - "critic_loss_one": critic_loss_one.item(), - "critic_loss_two": critic_loss_two.item(), - "critic_loss_total": critic_loss_total.item(), - } + with torch.no_grad(): + # --- TD3-style smoothing diagnostics --- + # Noise diagnostics + # What it tells you: + # - target_noise_abs_mean: effective smoothing magnitude. + # - target_noise_clip_frac high early: noise often clipped (clip too small or noise too large). + target_noise_abs_mean = target_noise.abs().mean().item() + target_noise_clip_frac = ( + (target_noise.abs() >= self.policy_noise_clip).float().mean().item() + ) + info["target_noise_abs_mean"] = float(target_noise_abs_mean) + info["target_noise_clip_frac"] = float(target_noise_clip_frac) + + # --- Twin critic disagreement (stability/uncertainty) --- + # If this grows over training, critics are diverging / becoming inconsistent. + info["q1_mean"] = q_values_one.mean().item() + info["q2_mean"] = q_values_two.mean().item() + info["q_twin_gap_abs_mean"] = ( + (q_values_one - q_values_two).abs().mean().item() + ) + + # --- Target critics disagreement (target stability) --- + # Large/unstable gap here often means target critics are drifting or policy is visiting OOD actions. + info["target_q1_mean"] = target_q_values_one.mean().item() + info["target_q2_mean"] = target_q_values_two.mean().item() + info["target_q_twin_gap_abs_mean"] = ( + (target_q_values_one - target_q_values_two).abs().mean().item() + ) + + # --- Bellman target scale (reward scaling / discount sanity) --- + # If q_target drifts upward without reward improvement, suspect reward_scale, gamma, or instability. + info["q_target_mean"] = q_target.mean().item() + info["q_target_std"] = q_target.std().item() + + # --- TD error diagnostics (Bellman fit quality) --- + # td_abs_mean down over time is healthy; persistent growth/spikes often indicate critic instability. 
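+            # Sign convention: td = Q - q_target, so a persistently positive td1_mean/td2_mean
+            # suggests the critic over-predicts its bootstrapped targets (under-predicts when negative).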
+ td1 = q_values_one - q_target # signed + td2 = q_values_two - q_target # signed + + info["td1_mean"] = td1.mean().item() + info["td1_std"] = td1.std().item() + info["td1_abs_mean"] = td1.abs().mean().item() + + info["td2_mean"] = td2.mean().item() + info["td2_std"] = td2.std().item() + info["td2_abs_mean"] = td2.abs().mean().item() + + # --- Losses (optimization progress; less diagnostic than TD/twin gaps) --- + info["critic_loss_one"] = critic_loss_one.item() + info["critic_loss_two"] = critic_loss_two.item() + info["critic_loss_total"] = critic_loss_total.item() return info, priorities def _update_actor(self, states: SARLObservationTensors) -> dict[str, Any]: + info: dict[str, Any] = {} + actions = self.actor_net(states, detach_encoder=True) with hlp.evaluating(self.critic_net): @@ -251,13 +307,48 @@ def _update_actor(self, states: SARLObservationTensors) -> dict[str, Any]: actor_loss = -actor_q_values.mean() + # --------------------------------------------------------- + # Deterministic Policy Gradient Strength (∇a Q(s,a)) + # --------------------------------------------------------- + # Measures how steep the critic surface is w.r.t. actions. + # ~0 early -> critic flat, actor receives no learning signal. + # Very large -> critic overly sharp, can cause unstable actor updates. + dq_da = torch.autograd.grad( + outputs=actor_loss, + inputs=actions, + retain_graph=True, # because we do backward(actor_loss) next + create_graph=False, # diagnostic only + allow_unused=False, + )[0] + with torch.no_grad(): + # - ~0 early: critic surface flat around actor actions (weak learning signal) + # - very large: critic surface sharp -> unstable / exploitative actor updates + info["dq_da_abs_mean"] = dq_da.abs().mean().item() + info["dq_da_norm_mean"] = dq_da.norm(dim=1).mean().item() + info["dq_da_norm_p95"] = dq_da.norm(dim=1).quantile(0.95).item() + self.actor_net_optimiser.zero_grad() actor_loss.backward() self.actor_net_optimiser.step() - info = { - "actor_loss": actor_loss.item(), - } + with torch.no_grad(): + # Policy Action Health (tanh policies in [-1, 1]) + # pi_action_saturation_frac: + # High values (>0.8 early) often mean the actor is slamming bounds, + # reducing effective gradient flow through tanh. + info["pi_action_mean"] = actions.mean().item() + info["pi_action_std"] = actions.std().item() + info["pi_action_abs_mean"] = actions.abs().mean().item() + info["pi_action_saturation_frac"] = ( + (actions.abs() > 0.95).float().mean().item() + ) + + # actor_q_mean should generally increase over training. + # actor_q_std large + unstable may indicate critic inconsistency. 
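+            # Note: actor_loss here is exactly -actor_q_mean, so the two curves mirror
+            # each other; both are kept purely for plotting convenience.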
+ info["actor_loss"] = actor_loss.item() + info["actor_q_mean"] = actor_q_values.mean().item() + info["actor_q_std"] = actor_q_values.std().item() + return info def _update_autoencoder(self, states: torch.Tensor) -> dict[str, Any]: @@ -288,11 +379,13 @@ def train( ) -> dict[str, Any]: self.learn_counter += 1 - self.policy_noise *= self.policy_noise_decay - self.policy_noise = max(self.min_policy_noise, self.policy_noise) + self.policy_noise = self.policy_noise_scheduler.get_value( + episode_context.training_step + ) - self.action_noise *= self.action_noise_decay - self.action_noise = max(self.min_action_noise, self.action_noise) + self.action_noise = self.action_noise_scheduler.get_value( + episode_context.training_step + ) # Sample and convert to tensors using multimodal sampling ( diff --git a/cares_reinforcement_learning/algorithm/policy/TD7.py b/cares_reinforcement_learning/algorithm/policy/TD7.py index 09ad3776..8a69734d 100644 --- a/cares_reinforcement_learning/algorithm/policy/TD7.py +++ b/cares_reinforcement_learning/algorithm/policy/TD7.py @@ -117,6 +117,7 @@ from cares_reinforcement_learning.types.episode import EpisodeContext from cares_reinforcement_learning.types.observation import SARLObservation from cares_reinforcement_learning.util.configurations import TD7Config +from cares_reinforcement_learning.util.helpers import ExponentialScheduler class TD7(SARLAlgorithm[np.ndarray]): @@ -175,16 +176,21 @@ def __init__( self.min_priority = config.min_priority # Policy noise - self.min_policy_noise = config.min_policy_noise - self.policy_noise = config.policy_noise - self.policy_noise_decay = config.policy_noise_decay - self.policy_noise_clip = config.policy_noise_clip + self.policy_noise_scheduler = ExponentialScheduler( + start_value=config.policy_noise_start, + end_value=config.policy_noise_end, + decay_steps=config.policy_noise_decay, + ) + self.policy_noise = self.policy_noise_scheduler.get_value(0) # Action noise - self.min_action_noise = config.min_action_noise - self.action_noise = config.action_noise - self.action_noise_decay = config.action_noise_decay + self.action_noise_scheduler = ExponentialScheduler( + start_value=config.action_noise_start, + end_value=config.action_noise_end, + decay_steps=config.action_noise_decay, + ) + self.action_noise = self.action_noise_scheduler.get_value(0) self.learn_counter = 0 self.policy_update_freq = config.policy_update_freq @@ -264,6 +270,7 @@ def _calculate_value(self, state: SARLObservation, action: np.ndarray) -> float: def _update_encoder( self, states: torch.Tensor, actions: torch.Tensor, next_states: torch.Tensor ) -> dict[str, Any]: + info: dict[str, Any] = {} with torch.no_grad(): next_zs = self.encoder_net.zs(next_states) @@ -277,7 +284,29 @@ def _update_encoder( encoder_loss.backward() self.encoder_net_optimiser.step() - info = {"encoder_loss": encoder_loss.item()} + with torch.no_grad(): + # --- SALE / representation health --- + # If these collapse to ~0 (or explode), your representations are unhealthy even if RL loss looks fine. + info["encoder_loss"] = encoder_loss.item() + + info["zs_abs_mean"] = zs.abs().mean().item() + info["zs_std"] = zs.std(unbiased=False).item() + info["pred_zs_abs_mean"] = pred_zs.abs().mean().item() + info["pred_zs_std"] = pred_zs.std(unbiased=False).item() + + # Cosine similarity: are we predicting the right *direction* in latent space? 
+ eps = 1e-12 + cos = (pred_zs * next_zs).sum(dim=1) / ( + pred_zs.norm(dim=1) * next_zs.norm(dim=1) + eps + ) + info["sale_cos_mean"] = cos.mean().item() + info["sale_cos_p05"] = cos.quantile(0.05).item() + + # Tail risk: max per-sample latent prediction error (spots “bad transitions” / distribution shifts) + per_sample_mse = (pred_zs - next_zs).pow(2).mean(dim=1) + info["sale_mse_mean"] = per_sample_mse.mean().item() + info["sale_mse_p95"] = per_sample_mse.quantile(0.95).item() + return info def _update_critic( @@ -290,6 +319,8 @@ def _update_critic( weights: torch.Tensor, ) -> tuple[dict[str, Any], np.ndarray]: + info: dict[str, Any] = {} + with torch.no_grad(): fixed_target_zs = self.target_fixed_encoder_net.zs(next_states) @@ -313,10 +344,14 @@ def _update_critic( target_q_values = torch.minimum(target_q_values_one, target_q_values_two) - target_q_values = target_q_values.clamp(self.min_target, self.max_target) + # TD7 value clipping + target_q_values_clipped = target_q_values.clamp( + self.min_target, self.max_target + ) - q_target = rewards + self.gamma * (1 - dones) * target_q_values + q_target = rewards + self.gamma * (1 - dones) * target_q_values_clipped + # tracked range (used to set next min_target/max_target later) self.max = max(self.max, float(q_target.max())) self.min = min(self.min, float(q_target.min())) @@ -358,11 +393,58 @@ def _update_critic( .flatten() ) - info = { - "critic_loss_one": huber_loss_one.mean().item(), - "critic_loss_two": huber_loss_two.mean().item(), - "critic_loss_total": critic_loss_total.item(), - } + with torch.no_grad(): + # --- Target policy smoothing diagnostics (TD3/TD7) --- + info["target_noise_abs_mean"] = target_noise.abs().mean().item() + info["target_noise_clip_frac"] = ( + (target_noise.abs() >= self.policy_noise_clip).float().mean().item() + ) + + # --- SALE feature inputs (sanity) --- + info["fixed_zs_abs_mean"] = fixed_zs.abs().mean().item() + info["fixed_zsa_abs_mean"] = fixed_zsa.abs().mean().item() + + # --- Target critics + clipping diagnostics (TD7-specific) --- + info["target_q1_mean"] = target_q_values_one.mean().item() + info["target_q2_mean"] = target_q_values_two.mean().item() + info["target_q_min_mean"] = target_q_values.mean().item() + + # How active is clipping? + # High clamp_frac means you’re *often* forcing targets into a limited band. + hit_min = (target_q_values <= self.min_target).float() + hit_max = (target_q_values >= self.max_target).float() + info["clip_hit_min_frac"] = hit_min.mean().item() + info["clip_hit_max_frac"] = hit_max.mean().item() + info["clip_any_frac"] = torch.maximum(hit_min, hit_max).mean().item() + + # How much is clipping changing the target values? 
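+            # clip_delta is zero wherever the target already lies inside [min_target, max_target];
+            # its p95 therefore shows how hard the band is trimming the tails of the target distribution.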
+ clip_delta = (target_q_values - target_q_values_clipped).abs() + info["clip_delta_abs_mean"] = clip_delta.mean().item() + info["clip_delta_abs_p95"] = clip_delta.quantile(0.95).item() + + # Track the clip band you are using right now + info["clip_min_target"] = float(self.min_target) + info["clip_max_target"] = float(self.max_target) + info["q_target_mean"] = q_target.mean().item() + info["q_target_std"] = q_target.std(unbiased=False).item() + + # --- Current critics health --- + info["q1_mean"] = q_values_one.mean().item() + info["q2_mean"] = q_values_two.mean().item() + info["q_twin_gap_abs_mean"] = ( + (q_values_one - q_values_two).abs().mean().item() + ) + + # --- TD error diagnostics (fit quality + tail risk) --- + td_abs_max = torch.maximum(td_error_one, td_error_two).squeeze(1) # (B,) + info["td_abs_mean"] = td_abs_max.mean().item() + info["td_abs_p95"] = td_abs_max.quantile(0.95).item() + info["td_abs_max"] = td_abs_max.max().item() + + # --- Losses (LAP/Huber) --- + info["critic_loss_one"] = huber_loss_one.mean().item() + info["critic_loss_two"] = huber_loss_two.mean().item() + info["critic_loss_total"] = critic_loss_total.item() return info, priorities @@ -372,6 +454,8 @@ def _update_actor( states: torch.Tensor, weights: torch.Tensor, # pylint: disable=unused-argument ) -> dict[str, Any]: + info: dict[str, Any] = {} + with hlp.evaluating(self.encoder_net): fixed_zs = self.fixed_encoder_net.zs(states) @@ -387,18 +471,54 @@ def _update_actor( actor_q_values = torch.cat([actor_q_values_one, actor_q_values_two], dim=1) actor_loss = -actor_q_values.mean() + # --------------------------------------------------------- + # Deterministic Policy Gradient Strength (∇a Q(s,a)) [TD7 actor] + # --------------------------------------------------------- + # Same interpretation as TD3, but note Q depends on SALE features too. 
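+        # The gradient flows through every action-dependent input to the critic, which in TD7
+        # includes the fixed SALE embedding zsa(zs, a); large values can therefore reflect a
+        # sharp embedding rather than the critic head alone.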
+ dq_da = torch.autograd.grad( + outputs=actor_loss, + inputs=actions, + retain_graph=True, # because we do backward(actor_loss) next + create_graph=False, # diagnostic only + allow_unused=False, + )[0] + with torch.no_grad(): + # - ~0 early: critic surface flat around actor actions (weak learning signal) + # - very large: critic surface sharp -> unstable / exploitative actor updates + info["dq_da_abs_mean"] = dq_da.abs().mean().item() + info["dq_da_norm_mean"] = dq_da.norm(dim=1).mean().item() + info["dq_da_norm_p95"] = dq_da.norm(dim=1).quantile(0.95).item() + self.actor_net_optimiser.zero_grad() actor_loss.backward() self.actor_net_optimiser.step() - actor_info = { - "actor_loss": actor_loss.item(), - } + with torch.no_grad(): + # --- Action health --- + info["pi_action_mean"] = actions.mean().item() + info["pi_action_std"] = actions.std(unbiased=False).item() + info["pi_action_abs_mean"] = actions.abs().mean().item() + info["pi_action_saturation_frac"] = ( + (actions.abs() > 0.95).float().mean().item() + ) + + # --- SALE inputs to policy (representation scale sanity) --- + info["actor_zs_abs_mean"] = fixed_zs.abs().mean().item() + info["actor_zs_std"] = fixed_zs.std(unbiased=False).item() - return actor_info + # --- On-policy critic signal --- + info["actor_loss"] = actor_loss.item() + info["actor_q_mean"] = actor_q_values.mean().item() + info["actor_q_std"] = actor_q_values.std(unbiased=False).item() + info["qf_pi_gap_abs_mean"] = ( + (actor_q_values_one - actor_q_values_two).abs().mean().item() + ) - def update_networks( + return info + + def update_from_batch( self, + episode_context: EpisodeContext, memory: SARLMemoryBuffer, indices: np.ndarray, states_tensor: torch.Tensor, @@ -411,6 +531,17 @@ def update_networks( info: dict[str, Any] = {} + self.policy_noise = self.policy_noise_scheduler.get_value( + episode_context.training_step + ) + + self.action_noise = self.action_noise_scheduler.get_value( + episode_context.training_step + ) + + info["policy_noise"] = float(self.policy_noise) + info["action_noise"] = float(self.action_noise) + encoder_info = self._update_encoder( states_tensor, actions_tensor, next_states_tensor ) @@ -454,7 +585,6 @@ def update_networks( return info - # TODO use training_step with decay rates def _train_policy( self, memory_buffer: SARLMemoryBuffer, @@ -462,14 +592,6 @@ def _train_policy( ) -> dict[str, Any]: self.learn_counter += 1 - # TODO replace with training_step based approach to avoid having to save this value - self.policy_noise *= self.policy_noise_decay - self.policy_noise = max(self.min_policy_noise, self.policy_noise) - - # TODO replace with training_step based approach to avoid having to save this value - self.action_noise *= self.action_noise_decay - self.action_noise = max(self.min_action_noise, self.action_noise) - # Use the helper to sample and prepare tensors in one step ( observation_tensor, @@ -489,7 +611,8 @@ def _train_policy( per_weight_normalisation=self.per_weight_normalisation, ) - info = self.update_networks( + info = self.update_from_batch( + episode_context, memory_buffer, indices, observation_tensor.vector_state_tensor, @@ -509,10 +632,20 @@ def _train_and_reset( ) -> dict[str, Any]: info: dict[str, Any] = {} + # Log pre-reset state (useful for debugging / plotting) + info["pre_reset_best_min_return"] = float(self.best_min_return) + info["pre_reset_max_eps_before_update"] = int(self.max_eps_before_update) + info["pre_reset_timesteps_since_update"] = int(self.timesteps_since_update) + info["pre_reset_eps_since_update"] = 
int(self.eps_since_update) + + # Track whether we hit the checkpointing regime switch during this training burst + checkpoint_regime_switched = False + for _ in range(self.timesteps_since_update): if self.learn_counter == self.steps_before_checkpointing: self.best_min_return *= self.reset_weight self.max_eps_before_update = self.max_eps_checkpointing + checkpoint_regime_switched = True info = self._train_policy(memory_buffer, episode_context) @@ -520,6 +653,11 @@ def _train_and_reset( self.timesteps_since_update = 0 self.min_return = 1e8 + # Log post-reset state + event flag + info["checkpoint_regime_switched"] = float(checkpoint_regime_switched) + info["post_reset_best_min_return"] = float(self.best_min_return) + info["post_reset_max_eps_before_update"] = int(self.max_eps_before_update) + return info def train( @@ -536,20 +674,56 @@ def train( if not episode_done: return info + # ------------------------- + # TD7 gate bookkeeping + # ------------------------- self.eps_since_update += 1 self.timesteps_since_update += episode_steps - self.min_return = min(self.min_return, episode_return) - if self.min_return < self.best_min_return: - info = self._train_and_reset(memory_buffer, episode_context) - - elif self.eps_since_update == self.max_eps_before_update: + # ------------------------- + # Episode-level logging (always) + # ------------------------- + info["episode_return"] = float(episode_return) + info["episode_steps"] = int(episode_steps) + + # "Worst-case since last update" tracking (TD7's gate signal) + info["min_return_window"] = float(self.min_return) + info["best_min_return"] = float(self.best_min_return) + + # Cadence / counters + info["eps_since_update"] = int(self.eps_since_update) + info["timesteps_since_update"] = int(self.timesteps_since_update) + info["max_eps_before_update"] = int(self.max_eps_before_update) + info["learn_counter"] = int(self.learn_counter) + + # Decision flags (exactly one of these becomes 1.0 on episode end) + reset_triggered = self.min_return < self.best_min_return + accept_triggered = ( + self.eps_since_update == self.max_eps_before_update + ) and not reset_triggered + + info["reset_triggered"] = float(reset_triggered) + info["accept_triggered"] = float(accept_triggered) + + # How far from the gate are we? + # Negative reset_margin => will reset; positive means "safe" above best_min_return. 
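+        # reset_margin is measured in raw return units, so its scale follows the task's reward
+        # scaling; it compares the worst episode return seen since the last checkpoint update
+        # against the best such worst-case achieved so far.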
+ info["reset_margin"] = float(self.min_return - self.best_min_return) + + # ------------------------- + # TD7 decisions + # ------------------------- + if reset_triggered: + train_info = self._train_and_reset(memory_buffer, episode_context) + info.update(train_info) + + elif accept_triggered: self.best_min_return = self.min_return self.checkpoint_actor.load_state_dict(self.actor_net.state_dict()) self.checkpoint_encoder.load_state_dict(self.encoder_net.state_dict()) - info = self._train_and_reset(memory_buffer, episode_context) + train_info = self._train_and_reset(memory_buffer, episode_context) + info.update(train_info) return info diff --git a/cares_reinforcement_learning/algorithm/policy/TQC.py b/cares_reinforcement_learning/algorithm/policy/TQC.py index cbf9de90..70c9af3d 100644 --- a/cares_reinforcement_learning/algorithm/policy/TQC.py +++ b/cares_reinforcement_learning/algorithm/policy/TQC.py @@ -123,7 +123,10 @@ def _update_critic( dones: torch.Tensor, weights: torch.Tensor, # pylint: disable=unused-argument ) -> tuple[dict[str, Any], np.ndarray]: + info: dict[str, Any] = {} batch_size = len(states) + + drop_q_range = self.quantiles_total - self.top_quantiles_to_drop with torch.no_grad(): with hlp.evaluating(self.actor_net): next_actions, next_log_pi, _ = self.actor_net(next_states) @@ -134,9 +137,7 @@ def _update_critic( sorted_target_q_values, _ = torch.sort( target_q_values.reshape(batch_size, -1) ) - top_quantile_target_q_values = sorted_target_q_values[ - :, : self.quantiles_total - self.top_quantiles_to_drop - ] + top_quantile_target_q_values = sorted_target_q_values[:, :drop_q_range] # compute target q_target = rewards + (1 - dones) * self.gamma * ( @@ -147,12 +148,11 @@ def _update_critic( # Compute td_error for PER sorted_q_values, _ = torch.sort(q_values.reshape(batch_size, -1)) - top_quantile_q_values = sorted_q_values[ - :, : self.quantiles_total - self.top_quantiles_to_drop - ] + top_quantile_q_values = sorted_q_values[:, :drop_q_range] td_errors = top_quantile_q_values - q_target - td_error = td_errors.abs().mean(dim=1) # mean over quantiles + td_errors_abs = td_errors.abs() + td_error = td_errors_abs.mean(dim=1) # mean over quantiles critic_loss_total = hlp.calculate_quantile_huber_loss( q_values, @@ -177,9 +177,100 @@ def _update_critic( .flatten() ) - info = { - "critic_loss_total": critic_loss_total.item(), - } + with torch.no_grad(): + # ---- Loss ---- + info["critic_loss_total"] = critic_loss_total.item() + + # ---- Target decomposition (SAC-like, but on trimmed quantiles) ---- + # Conservative bootstrap term is the mean of kept target quantiles (pre-entropy) + info["target_min_q_mean"] = top_quantile_target_q_values.mean().item() + + # --- Soft target decomposition (SAC-specific) --- + # min_target_q_mean: the conservative bootstrap value from twin critics (pre-entropy) + # entropy_term_mean: magnitude of entropy regularization in the target (alpha * log_pi is usually negative) + # soft_target_value_mean: the exact term used inside the Bellman target before reward/discount + # alpha_log_pi is typically negative; entropy_bonus is typically positive + alpha_log_pi = self.alpha * next_log_pi + # this is what gets ADDED to minQ in the target + entropy_bonus = -self.alpha * next_log_pi + + soft_target_value = top_quantile_target_q_values + entropy_bonus + + info["alpha_log_pi_mean"] = alpha_log_pi.mean().item() + info["entropy_bonus_mean"] = entropy_bonus.mean().item() + info["soft_target_value_mean"] = soft_target_value.mean().item() + + # ---- Bellman target scale 
---- + info["q_target_mean"] = q_target.mean().item() + info["q_target_std"] = q_target.std(unbiased=False).item() + + # ---- Current quantiles scale ---- + q_flat = q_values.reshape(batch_size, -1) # (B, Q_total) + info["q_mean"] = q_flat.mean().item() + info["q_std"] = q_flat.std(unbiased=False).item() + + # ---- Quantile trimming diagnostics (TQC-specific) ---- + # If dropped quantiles are very far above kept ones, trimming is doing real work + # - kept_mean: mean of the quantiles we *keep* (conservative estimate used for learning) + # - dropped_mean: mean of the high quantiles we *drop* (optimistic tail) + # - drop_gap_mean: how far the dropped tail sits above the kept mass + # Large gap => trimming is actively removing optimistic tail pressure. + info["tqc_kept_mean"] = top_quantile_q_values.mean().item() + dropped_q_values = sorted_q_values[:, drop_q_range:] + kept_q_mean_per_sample = top_quantile_q_values.mean(dim=1) + dropped_q_mean_per_sample = dropped_q_values.mean(dim=1) + + info["tqc_dropped_mean"] = dropped_q_values.mean().item() + info["tqc_drop_gap_mean"] = ( + (dropped_q_mean_per_sample - kept_q_mean_per_sample).mean().item() + ) + + # Same idea on target side + info["tqc_target_kept_mean"] = top_quantile_target_q_values.mean().item() + dropped_target_q_values = sorted_target_q_values[:, drop_q_range:] + kept_target_mean_per_sample = top_quantile_target_q_values.mean(dim=1) + dropped_target_mean_per_sample = dropped_target_q_values.mean(dim=1) + + info["tqc_target_dropped_mean"] = dropped_target_q_values.mean().item() + info["tqc_target_drop_gap_mean"] = ( + (dropped_target_mean_per_sample - kept_target_mean_per_sample) + .mean() + .item() + ) + + # ---- Quantile spread (uncertainty/sharpness proxies) ---- + # IQR across kept quantiles: (q75 - q25) per sample + q25 = top_quantile_q_values.quantile(0.25, dim=1) + q50 = top_quantile_q_values.quantile(0.50, dim=1) + q75 = top_quantile_q_values.quantile(0.75, dim=1) + iqr = q75 - q25 + + info["q_iqr_mean"] = iqr.mean().item() + info["q_iqr_p95"] = iqr.quantile(0.95).item() + info["q_median_mean"] = q50.mean().item() + + # Same IQR on target kept quantiles + tq25 = top_quantile_target_q_values.quantile(0.25, dim=1) + tq50 = top_quantile_target_q_values.quantile(0.50, dim=1) + tq75 = top_quantile_target_q_values.quantile(0.75, dim=1) + tiqr = tq75 - tq25 + + info["target_q_iqr_mean"] = tiqr.mean().item() + info["target_q_iqr_p95"] = tiqr.quantile(0.95).item() + info["target_q_median_mean"] = tq50.mean().item() + + # ---- TD-error diagnostics (fit quality + tails) ---- + info["td_abs_mean"] = td_error.mean().item() + info["td_abs_std"] = td_error.std(unbiased=False).item() + info["td_abs_p95"] = td_error.quantile(0.95).item() + info["td_abs_max"] = td_error.max().item() + + # Quantile-level TD tail (more sensitive than mean-over-quantiles) + # Per sample: max |TD| among kept quantiles + td_abs_qmax = td_errors_abs.max(dim=1).values + info["td_abs_qmax_mean"] = td_abs_qmax.mean().item() + info["td_abs_qmax_p95"] = td_abs_qmax.quantile(0.95).item() + info["td_abs_qmax_max"] = td_abs_qmax.max().item() return info, priorities @@ -188,13 +279,43 @@ def _update_actor_alpha( states: torch.Tensor, weights: torch.Tensor, # pylint: disable=unused-argument ) -> dict[str, Any]: + info: dict[str, Any] = {} + pi, log_pi, _ = self.actor_net(states) with hlp.evaluating(self.critic_net): - mean_qf_pi = self.critic_net(states, pi).mean(2).mean(1, keepdim=True) + q_quant = self.critic_net(states, pi) + + q_mean_per_critic = q_quant.mean(dim=2) # (B, 
C) + mean_qf_pi = q_mean_per_critic.mean(dim=1, keepdim=True) # (B, 1) actor_loss = (self.alpha * log_pi - mean_qf_pi).mean() + # --------------------------------------------------------- + # Stochastic Policy Gradient Strength (∇a [α log π(a|s) − Q(s,a)]) + # --------------------------------------------------------- + # Measures how steep the entropy-regularized critic objective is + # w.r.t. the sampled policy actions. + # + # ~0 early -> critic surface and entropy term nearly flat; + # actor receives weak learning signal. + # + # Very large -> critic or entropy term is very sharp around policy + # actions; can lead to unstable or overly aggressive + # actor updates. + dq_da = torch.autograd.grad( + outputs=actor_loss, + inputs=pi, + retain_graph=True, + create_graph=False, + allow_unused=False, + )[0] + + with torch.no_grad(): + info["dq_da_abs_mean"] = dq_da.abs().mean().item() + info["dq_da_norm_mean"] = dq_da.norm(dim=1).mean().item() + info["dq_da_norm_p95"] = dq_da.norm(dim=1).quantile(0.95).item() + self.actor_net_optimiser.zero_grad() actor_loss.backward() self.actor_net_optimiser.step() @@ -206,8 +327,57 @@ def _update_actor_alpha( alpha_loss.backward() self.log_alpha_optimizer.step() - info = { - "actor_loss": actor_loss.item(), - "alpha_loss": alpha_loss.item(), - } + with torch.no_grad(): + # --- Policy entropy diagnostics --- + info["log_pi_mean"] = log_pi.mean().item() + info["log_pi_std"] = log_pi.std(unbiased=False).item() + + # --- Action magnitude/saturation --- + info["pi_action_abs_mean"] = pi.abs().mean().item() + info["pi_action_std"] = pi.std(unbiased=False).item() + info["pi_action_saturation_frac"] = (pi.abs() > 0.95).float().mean().item() + + # --- On-policy critic signal (TQC uses mean over critics+quantiles) --- + info["mean_qf_pi_mean"] = mean_qf_pi.mean().item() + info["mean_qf_pi_std"] = mean_qf_pi.std(unbiased=False).item() + + # --- Critic disagreement at policy actions (TQC analogue of twin-gap) --- + # Useful dispersion metrics (TQC analogue of "twin gap") + q_std_across_critics = q_mean_per_critic.std(dim=1, unbiased=False) # (B,) + q_range_across_critics = ( + q_mean_per_critic.max(dim=1).values + - q_mean_per_critic.min(dim=1).values + ) # (B,) + + # Quantile spread (distributional uncertainty / sharpness) + # IQR across quantiles after averaging critics + q_mean_across_critics = q_quant.mean(dim=1) # (B, N) + q_q25 = q_mean_across_critics.quantile(0.25, dim=1) # (B,) + q_q50 = q_mean_across_critics.quantile(0.50, dim=1) # (B,) + q_q75 = q_mean_across_critics.quantile(0.75, dim=1) # (B,) + q_iqr = q_q75 - q_q25 # (B,) + + info["q_pi_critics_std_mean"] = q_std_across_critics.mean().item() + info["q_pi_critics_std_p95"] = q_std_across_critics.quantile(0.95).item() + info["q_pi_critics_range_mean"] = q_range_across_critics.mean().item() + info["q_pi_critics_range_p95"] = q_range_across_critics.quantile( + 0.95 + ).item() + + # --- Quantile spread (distributional sharpness / uncertainty proxy) --- + info["q_pi_quantile_iqr_mean"] = q_iqr.mean().item() + info["q_pi_quantile_iqr_p95"] = q_iqr.quantile(0.95).item() + info["q_pi_quantile_median_mean"] = q_q50.mean().item() + + # --- Entropy gap (alpha tuning health) --- + entropy_gap = -(log_pi + self.target_entropy) + info["entropy_gap_mean"] = entropy_gap.mean().item() + info["entropy_gap_std"] = entropy_gap.std(unbiased=False).item() + + # --- Losses / temperature --- + info["actor_loss"] = actor_loss.item() + info["alpha_loss"] = alpha_loss.item() + info["alpha"] = self.alpha.item() + info["log_alpha"] = 
self.log_alpha.item() + return info diff --git a/cares_reinforcement_learning/algorithm/value/C51.py b/cares_reinforcement_learning/algorithm/value/C51.py index 25293353..8746d7db 100644 --- a/cares_reinforcement_learning/algorithm/value/C51.py +++ b/cares_reinforcement_learning/algorithm/value/C51.py @@ -49,6 +49,8 @@ with fixed support and projection. """ +from typing import Any + import torch from cares_reinforcement_learning.algorithm.value import DQN @@ -83,7 +85,7 @@ def _compute_loss( next_states_tensor: torch.Tensor, dones_tensor: torch.Tensor, batch_size: int, - ) -> torch.Tensor: + ) -> tuple[torch.Tensor, dict[str, Any]]: with torch.no_grad(): if self.use_double_dqn: # Double DQN @@ -129,4 +131,55 @@ def _compute_loss( log_p = torch.log(dist[range(batch_size), actions_tensor]) elementwise_loss = -(proj_dist * log_p).sum(1) - return elementwise_loss + info: dict[str, Any] = {} + + # ----------------------- + # Logging / diagnostics (C51) + # ----------------------- + with torch.no_grad(): + # dist: [B, A, N] + # proj_dist: [B, N] + # chosen pred distribution: [B, N] + pred_dist = dist[range(batch_size), actions_tensor] # [B, N] + + # --- Expected values (scalar Q) --- + # support: [N] + q_pred = (pred_dist * self.support).sum(dim=1) # [B] + q_targ = (proj_dist * self.support).sum(dim=1) # [B] + + td = q_pred - q_targ + info["q_pred_mean"] = q_pred.mean().item() + info["q_pred_std"] = q_pred.std().item() + info["q_target_mean"] = q_targ.mean().item() + info["q_target_std"] = q_targ.std().item() + + info["td_mean"] = td.mean().item() + info["td_std"] = td.std().item() + info["td_abs_mean"] = td.abs().mean().item() + + # --- Greedy action distribution under expected value --- + q_all = (dist * self.support.view(1, 1, -1)).sum(dim=2) # [B, A] + greedy_actions = q_all.argmax(dim=1) # [B] + num_actions = self.network.num_actions + counts = torch.bincount(greedy_actions, minlength=num_actions).float() + probs = counts / counts.sum().clamp(min=1.0) + entropy_actions = -(probs * (probs + 1e-12).log()).sum() + info["greedy_action_entropy"] = entropy_actions.item() + info["greedy_action_max_prob"] = probs.max().item() + info["greedy_action_probs"] = probs.cpu().tolist() + + # --- Distribution entropy (collapse / over-uncertainty) --- + pred_ent = -(pred_dist * (pred_dist + 1e-12).log()).sum(dim=1) # [B] + targ_ent = -(proj_dist * (proj_dist + 1e-12).log()).sum(dim=1) # [B] + info["pred_dist_entropy_mean"] = pred_ent.mean().item() + info["pred_dist_entropy_std"] = pred_ent.std().item() + info["target_dist_entropy_mean"] = targ_ent.mean().item() + + # --- Support saturation / projection health --- + info["proj_mass_vmin"] = proj_dist[:, 0].mean().item() + info["proj_mass_vmax"] = proj_dist[:, -1].mean().item() + info["proj_mass_on_bounds"] = ( + (proj_dist[:, 0] + proj_dist[:, -1]).mean().item() + ) + + return elementwise_loss, info diff --git a/cares_reinforcement_learning/algorithm/value/DQN.py b/cares_reinforcement_learning/algorithm/value/DQN.py index fb943456..2e2ee20a 100644 --- a/cares_reinforcement_learning/algorithm/value/DQN.py +++ b/cares_reinforcement_learning/algorithm/value/DQN.py @@ -65,7 +65,7 @@ from cares_reinforcement_learning.types.episode import EpisodeContext from cares_reinforcement_learning.types.observation import SARLObservation from cares_reinforcement_learning.util.configurations import DQNConfig -from cares_reinforcement_learning.util.helpers import EpsilonScheduler +from cares_reinforcement_learning.util.helpers import LinearScheduler class 
DQN(SARLAlgorithm[int]): @@ -88,12 +88,12 @@ def __init__( self.max_grad_norm = config.max_grad_norm # Epsilon - self.epsilon_scheduler = EpsilonScheduler( - start_epsilon=config.start_epsilon, - end_epsilon=config.end_epsilon, + self.epsilon_scheduler = LinearScheduler( + start_value=config.start_epsilon, + end_value=config.end_epsilon, decay_steps=config.decay_steps, ) - self.epsilon = self.epsilon_scheduler.get_epsilon(0) + self.epsilon = self.epsilon_scheduler.get_value(0) # Double DQN self.use_double_dqn = config.use_double_dqn @@ -165,8 +165,9 @@ def _compute_loss( next_states_tensor: torch.Tensor, dones_tensor: torch.Tensor, batch_size: int, # pylint: disable=unused-argument - ) -> torch.Tensor: + ) -> tuple[torch.Tensor, dict[str, Any]]: """Computes the elementwise loss for DQN. If use_double_dqn=True, applies Double DQN logic.""" + q_values = self.network(states_tensor) next_q_values_target = self.target_network(next_states_tensor) @@ -189,7 +190,48 @@ def _compute_loss( ) elementwise_loss = F.mse_loss(best_q_values, q_target, reduction="none") - return elementwise_loss + # ----------------------- + # Logging / diagnostics (DQN) + # ----------------------- + with torch.no_grad(): + # Action histogram (batch-based) + greedy_actions = q_values.argmax(dim=1) # [B] + num_actions = self.network.num_actions + counts = torch.bincount(greedy_actions, minlength=num_actions).float() + probs = counts / counts.sum().clamp(min=1.0) + + # Entropy: 0 = totally collapsed, higher = more spread + entropy = -(probs * (probs + 1e-12).log()).sum() + + td_error = best_q_values - q_target # signed, shape [B] + + # Logging Statistics + info: dict[str, Any] = {} + info["greedy_action_entropy"] = entropy.item() + info["greedy_action_max_prob"] = probs.max().item() + # Optional: full distribution (can be logged as list) + info["greedy_action_probs"] = probs.cpu().tolist() + + info["td_error_mean"] = td_error.mean().item() + info["td_error_std"] = td_error.std().item() + info["td_error_abs_mean"] = td_error.abs().mean().item() + + info["q_value_mean"] = best_q_values.mean().item() + info["q_value_max"] = best_q_values.max().item() + info["q_value_std"] = best_q_values.std().item() + info["q_value_next_mean"] = best_next_q_values.mean().item() + info["q_value_next_max"] = best_next_q_values.max().item() + info["q_value_next_std"] = best_next_q_values.std().item() + info["q_target_mean"] = q_target.mean().item() + info["q_target_max"] = q_target.max().item() + info["q_target_std"] = q_target.std().item() + info["reward_mean"] = rewards_tensor.mean().item() + info["reward_std"] = rewards_tensor.std().item() + info["overestimation_gap"] = ( + (q_values.max(dim=1).values - best_next_q_values).mean().item() + ) + + return elementwise_loss, info def train( self, @@ -202,7 +244,7 @@ def train( training_step = episode_context.training_step - self.epsilon = self.epsilon_scheduler.get_epsilon(training_step) + self.epsilon = self.epsilon_scheduler.get_value(training_step) if len(memory_buffer) < self.batch_size: return {} @@ -235,7 +277,7 @@ def train( weights_tensor = weights_tensor.view(-1) # Calculate loss - overriden by C51 - elementwise_loss = self._compute_loss( + elementwise_loss, train_info = self._compute_loss( observation_tensor.vector_state_tensor, actions_tensor, rewards_tensor, @@ -243,6 +285,7 @@ def train( dones_tensor, sample_size, ) + info |= train_info if self.use_per_buffer: # Update the Priorities @@ -254,6 +297,11 @@ def train( .flatten() ) + info["per_priority_mean"] = priorities.mean() + 
info["per_priority_max"] = priorities.max() + info["per_priority_min"] = priorities.min() + info["per_priority_std"] = priorities.std() + memory_buffer.update_priorities(indices, priorities) loss = torch.mean(elementwise_loss * weights_tensor) diff --git a/cares_reinforcement_learning/algorithm/value/NoisyNet.py b/cares_reinforcement_learning/algorithm/value/NoisyNet.py index ac2fc734..b77b5071 100644 --- a/cares_reinforcement_learning/algorithm/value/NoisyNet.py +++ b/cares_reinforcement_learning/algorithm/value/NoisyNet.py @@ -57,15 +57,18 @@ from cares_reinforcement_learning.algorithm.value import DQN from cares_reinforcement_learning.memory.memory_buffer import SARLMemoryBuffer -from cares_reinforcement_learning.networks.NoisyNet import Network +from cares_reinforcement_learning.networks.NoisyNet import BaseNoisyNetwork from cares_reinforcement_learning.types.episode import EpisodeContext from cares_reinforcement_learning.util.configurations import NoisyNetConfig class NoisyNet(DQN): + network: BaseNoisyNetwork + target_network: BaseNoisyNetwork + def __init__( self, - network: Network, + network: BaseNoisyNetwork, config: NoisyNetConfig, device: torch.device, ): @@ -81,5 +84,6 @@ def train( episode_context: EpisodeContext, ) -> dict: info = super().train(memory_buffer, episode_context) + info.update(self.network.noise_stats()) self._reset_noise() return info diff --git a/cares_reinforcement_learning/algorithm/value/QMIX.py b/cares_reinforcement_learning/algorithm/value/QMIX.py index 9862fa4a..d65494f9 100644 --- a/cares_reinforcement_learning/algorithm/value/QMIX.py +++ b/cares_reinforcement_learning/algorithm/value/QMIX.py @@ -66,7 +66,7 @@ from cares_reinforcement_learning.types.episode import EpisodeContext from cares_reinforcement_learning.types.observation import MARLObservation from cares_reinforcement_learning.util.configurations import QMIXConfig -from cares_reinforcement_learning.util.helpers import EpsilonScheduler +from cares_reinforcement_learning.util.helpers import LinearScheduler class QMIX(MARLAlgorithm[list[int]]): @@ -97,12 +97,12 @@ def __init__( self.num_actions = network.num_actions # Epsilon - self.epsilon_scheduler = EpsilonScheduler( - start_epsilon=config.start_epsilon, - end_epsilon=config.end_epsilon, + self.epsilon_scheduler = LinearScheduler( + start_value=config.start_epsilon, + end_value=config.end_epsilon, decay_steps=config.decay_steps, ) - self.epsilon = self.epsilon_scheduler.get_epsilon(0) + self.epsilon = self.epsilon_scheduler.get_value(0) # Double DQN self.use_double_dqn = config.use_double_dqn @@ -173,16 +173,6 @@ def act( return ActionSample(action=actions, source="policy") - # def _calculate_value(self, state: np.ndarray, action: int) -> float: # type: ignore[override] - # state_tensor = torch.tensor(state, dtype=torch.float32, device=self.device) - # state_tensor = state_tensor.unsqueeze(0) - - # with torch.no_grad(): - # q_values = self.network(state_tensor) - # q_value = q_values[0][action].item() - - # return q_value - def _compute_loss( self, obs_tensors: torch.Tensor, @@ -190,14 +180,13 @@ def _compute_loss( states_tensors: torch.Tensor, next_states_tensors: torch.Tensor, actions_tensors: torch.Tensor, + avail_actions: torch.Tensor, next_avail_actions_tensors: torch.Tensor, rewards_tensors: torch.Tensor, dones_tensors: torch.Tensor, ) -> tuple[torch.Tensor, dict[str, float]]: """Computes the elementwise loss for QMIX. 
If use_double_dqn=True, applies Double DQN logic.""" - loss_info: dict[str, float] = {} - q_values = self.network(obs_tensors) next_q_values_target = self.target_network(next_obs_tensors) @@ -239,11 +228,99 @@ def _compute_loss( elementwise_loss = F.mse_loss(q_total, q_target, reduction="none") - loss_info["td_error_mean"] = elementwise_loss.mean().item() - loss_info["td_error_std"] = elementwise_loss.std().item() - loss_info["q_total_mean"] = q_total.mean().item() - loss_info["q_target_mean"] = q_target.mean().item() - loss_info["q_next_mean"] = best_next_q_values.mean().item() + # ---------------------------- + # Logging / diagnostics (QMIX) + # ---------------------------- + loss_info: dict[str, float] = {} + with torch.no_grad(): + td_total = (q_total - q_target).view(-1) # [B] + + loss_info["td_total_mean"] = td_total.mean().item() + loss_info["td_total_std"] = td_total.std().item() + loss_info["td_total_abs_mean"] = td_total.abs().mean().item() + loss_info["mse_total_mean"] = elementwise_loss.mean().item() + + loss_info["q_total_mean"] = q_total.mean().item() + loss_info["q_target_mean"] = q_target.mean().item() + + # Per-agent utilities (chosen + bootstrap) + loss_info["q_i_chosen_mean"] = best_q_values.mean().item() + loss_info["q_i_chosen_std"] = best_q_values.std().item() + loss_info["q_i_next_mean"] = best_next_q_values.mean().item() + loss_info["q_i_next_std"] = best_next_q_values.std().item() + + # --- Current-state action masking for meaningful greedy/diversity metrics --- + avail = avail_actions.to(q_values.device) # [B, n_agents, n_actions] + masked_q_values = q_values.masked_fill(avail == 0, -1e9) + + # Max feasible utility per agent + q_i_max = masked_q_values.max(dim=2).values # [B, n_agents] + loss_info["q_i_max_mean"] = q_i_max.mean().item() + loss_info["q_i_max_std"] = q_i_max.std().item() + + # Mixer vs sum baseline + sum_q_i = best_q_values.sum(dim=1, keepdim=True) # [B, 1] + diff = q_total - sum_q_i # [B, 1] + + loss_info["sum_q_i_mean"] = sum_q_i.mean().item() + loss_info["sum_q_i_abs_mean"] = sum_q_i.abs().mean().item() + + loss_info["q_total_minus_sum_q_i_mean"] = diff.mean().item() + loss_info["q_total_minus_sum_q_i_std"] = diff.std().item() + loss_info["q_total_minus_sum_q_i_abs_mean"] = diff.abs().mean().item() + + # Scale-stable "how big is mixer output vs sum" using absolute means + loss_info["q_total_abs_mean"] = q_total.abs().mean().item() + loss_info["q_total_abs_over_sum_q_i_abs_mean"] = ( + q_total.abs().mean() / (sum_q_i.abs().mean() + 1e-6) + ).item() + + # Correlation between q_total and sum_q_i (are they at least monotonic-ish?) + sum_centered = sum_q_i - sum_q_i.mean() + qt_centered = q_total - q_total.mean() + corr = (sum_centered * qt_centered).mean() / ( + (sum_centered.pow(2).mean().sqrt() * qt_centered.pow(2).mean().sqrt()) + + 1e-6 + ) + loss_info["q_total_sum_q_i_corr"] = corr.item() + + # Availability constraints (next-state) + loss_info["next_avail_action_frac"] = ( + next_avail_actions_tensors.float().mean().item() + ) + loss_info["next_avail_actions_per_agent"] = ( + next_avail_actions_tensors.float().sum(dim=2).mean().item() + ) + + # Sanity: are actions in replay valid under current avail_actions? 
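+            # A persistently nonzero invalid_action_frac usually points at a mismatch between
+            # the stored availability masks and the replayed actions (e.g. masks recorded for
+            # the wrong timestep), rather than a learning problem.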
+ chosen_is_valid = avail.gather(2, actions_tensors.unsqueeze(-1)).squeeze( + -1 + ) # [B, n_agents] + loss_info["invalid_action_frac"] = ( + (chosen_is_valid == 0).float().mean().item() + ) + loss_info["no_action_available_frac"] = ( + (avail.float().sum(dim=2) == 0).float().mean().item() + ) + + # Greedy action diversity (MASKED, feasible actions only) + greedy_actions = masked_q_values.argmax(dim=2) # [B, n_agents] + entropies = [] + max_probs = [] + for ag in range(self.num_agents): + counts = torch.bincount( + greedy_actions[:, ag], minlength=self.num_actions + ).float() + probs = counts / counts.sum().clamp(min=1.0) + entropies.append(-(probs * (probs + 1e-12).log()).sum()) + max_probs.append(probs.max()) + + loss_info["greedy_action_entropy_mean_agents"] = ( + torch.stack(entropies).mean().item() + ) + loss_info["greedy_action_max_prob_mean_agents"] = ( + torch.stack(max_probs).mean().item() + ) return elementwise_loss, loss_info @@ -258,7 +335,7 @@ def train( training_step = episode_context.training_step - self.epsilon = self.epsilon_scheduler.get_epsilon(training_step) + self.epsilon = self.epsilon_scheduler.get_value(training_step) if len(memory_buffer) < self.batch_size: return {} @@ -303,6 +380,7 @@ def train( next_states_tensors=next_observation_tensor.global_state_tensor, actions_tensors=actions_tensor, rewards_tensors=rewards_tensor, + avail_actions=observation_tensor.avail_actions_tensor, next_avail_actions_tensors=next_observation_tensor.avail_actions_tensor, dones_tensors=dones_tensor, ) @@ -318,6 +396,11 @@ def train( .flatten() ) + info["per_priority_mean"] = priorities.mean() + info["per_priority_max"] = priorities.max() + info["per_priority_min"] = priorities.min() + info["per_priority_std"] = priorities.std() + memory_buffer.update_priorities(indices, priorities) loss = torch.mean(elementwise_loss * weights_tensor) diff --git a/cares_reinforcement_learning/algorithm/value/QRDQN.py b/cares_reinforcement_learning/algorithm/value/QRDQN.py index c96307cc..21455b83 100644 --- a/cares_reinforcement_learning/algorithm/value/QRDQN.py +++ b/cares_reinforcement_learning/algorithm/value/QRDQN.py @@ -45,6 +45,8 @@ QR-DQN = DQN + quantile-based distributional value learning. 
""" +from typing import Any + import torch from cares_reinforcement_learning.algorithm.value import DQN @@ -97,7 +99,7 @@ def _compute_loss( next_states_tensor: torch.Tensor, dones_tensor: torch.Tensor, batch_size: int, - ) -> torch.Tensor: + ) -> tuple[torch.Tensor, dict[str, Any]]: # Predicted Q-value quantiles for current state current_quantile_values = self.network.calculate_quantiles(states_tensor) @@ -156,4 +158,89 @@ def _compute_loss( dim=1, keepdim=True ) - return element_wise_loss + # ----------------------- + # Logging / diagnostics (QR-DQN) + # ----------------------- + with torch.no_grad(): + info: dict[str, Any] = {} + + # Shapes (typical): + # current_quantile_values: [B, A, N] + # current_action_q_values: [B, 1, N] (gathered at actions taken) + # target_q_values: [B, N] + + # Scalar Q-values via mean over quantiles (DQN-like view) + q_mean = current_quantile_values.mean(dim=-1) # [B, A] + greedy_actions = q_mean.argmax(dim=1) # [B] + + # Action histogram (batch-based, greedy under mean-Q) + num_actions = self.network.num_actions + counts = torch.bincount(greedy_actions, minlength=num_actions).float() + probs = counts / counts.sum().clamp(min=1.0) + entropy = -(probs * (probs + 1e-12).log()).sum() + + info["greedy_action_entropy"] = entropy.item() + info["greedy_action_max_prob"] = probs.max().item() + info["greedy_action_probs"] = probs.cpu().tolist() + + # Chosen-action distribution (predicted and target) + pred_quantiles = current_action_q_values.squeeze(1) # [B, N] + targ_quantiles = target_q_values # [B, N] + + pred_mean = pred_quantiles.mean(dim=1) # [B] + targ_mean = targ_quantiles.mean(dim=1) # [B] + + # Scalar Q estimate from QR (mean over quantiles) + info["pred_mean"] = pred_mean.mean().item() + info["pred_mean_std"] = pred_mean.std().item() + info["pred_mean_max"] = pred_mean.max().item() + info["pred_mean_min"] = pred_mean.min().item() + + info["target_mean"] = targ_mean.mean().item() + info["target_mean_std"] = targ_mean.std().item() + info["target_mean_max"] = targ_mean.max().item() + info["target_mean_min"] = targ_mean.min().item() + + # Scalar TD error (mean-return TD) + td_mean = pred_mean - targ_mean # [B] + info["td_mean"] = td_mean.mean().item() + info["td_std"] = td_mean.std().item() + info["td_abs_mean"] = td_mean.abs().mean().item() + + # Distributional TD error (quantile-wise) + td_q = pred_quantiles - targ_quantiles # [B, N] + info["td_q_abs_mean"] = td_q.abs().mean().item() + info["td_q_abs_p95"] = td_q.abs().quantile(0.95).item() + info["td_q_mean"] = td_q.mean().item() + info["td_q_std"] = td_q.std().item() + + # Distribution spread / uncertainty (QR-specific) + # Std over quantiles for the chosen action distribution + pred_spread = pred_quantiles.std(dim=1) # [B] + targ_spread = targ_quantiles.std(dim=1) # [B] + info["pred_quantile_std_mean"] = pred_spread.mean().item() + info["pred_quantile_std_p95"] = pred_spread.quantile(0.95).item() + info["targ_quantile_std_mean"] = targ_spread.mean().item() + info["targ_quantile_std_p95"] = targ_spread.quantile(0.95).item() + + # IQR (more robust than std) + q25 = pred_quantiles.quantile(0.25, dim=1) + q75 = pred_quantiles.quantile(0.75, dim=1) + info["pred_quantile_iqr_mean"] = (q75 - q25).mean().item() + + # Tail means (risk-sensitive view; useful to see if distribution shifts correctly) + # choose a small tail fraction + tail_k = max(1, self.quantiles // 10) # ~10% tail + info["pred_cvar_low_mean"] = pred_quantiles[:, :tail_k].mean().item() + info["pred_cvar_high_mean"] = pred_quantiles[:, 
-tail_k:].mean().item() + info["targ_cvar_low_mean"] = targ_quantiles[:, :tail_k].mean().item() + info["targ_cvar_high_mean"] = targ_quantiles[:, -tail_k:].mean().item() + + # Quantile huber loss stats (your actual objective) + # element_wise_quantile_huber_loss is returned by your helper. + # Log its mean/std/max for stability monitoring. + info["quantile_huber_mean"] = element_wise_quantile_huber_loss.mean().item() + info["quantile_huber_std"] = element_wise_quantile_huber_loss.std().item() + info["quantile_huber_max"] = element_wise_quantile_huber_loss.max().item() + + return element_wise_loss, info diff --git a/cares_reinforcement_learning/algorithm/value/Rainbow.py b/cares_reinforcement_learning/algorithm/value/Rainbow.py index 8bf2866a..a24e0da2 100644 --- a/cares_reinforcement_learning/algorithm/value/Rainbow.py +++ b/cares_reinforcement_learning/algorithm/value/Rainbow.py @@ -93,5 +93,6 @@ def train( episode_context: EpisodeContext, ) -> dict[str, Any]: info = super().train(memory_buffer, episode_context) + info.update(self.network.noise_stats()) self._reset_noise() return info diff --git a/cares_reinforcement_learning/networks/DQN/network.py b/cares_reinforcement_learning/networks/DQN/network.py index 05ffe31c..f1609b9a 100644 --- a/cares_reinforcement_learning/networks/DQN/network.py +++ b/cares_reinforcement_learning/networks/DQN/network.py @@ -1,7 +1,10 @@ +from typing import Any + import torch from torch import nn from cares_reinforcement_learning.networks.mlp_architecture import MLP +from cares_reinforcement_learning.networks.noisy_linear import NoisyLinear from cares_reinforcement_learning.util.configurations import DQNConfig @@ -21,6 +24,77 @@ def forward(self, state: torch.Tensor) -> torch.Tensor: "BaseDQN is an abstract class and cannot be instantiated directly." 
) + @torch.no_grad() + def _module_noise_stats(self, key: str, layer: nn.Module) -> dict[str, Any]: + sigma_vals = [] + sigma_max_vals = [] + sigma_mu_ratios = [] + weight_noise_rms = [] + bias_noise_rms = [] + + layer_stats: dict[str, Any] = {} + layer_i = 0 + + for m in layer.modules(): + if not isinstance(m, NoisyLinear): + continue + + # sigma stats + w_sigma = m.weight_sigma.detach().abs() + b_sigma = m.bias_sigma.detach().abs() + sigma_all = torch.cat([w_sigma.flatten(), b_sigma.flatten()]) + sigma_vals.append(sigma_all) + + sigma_max_vals.append(float(sigma_all.max().item())) + + # sigma / mu ratio (scale-invariant-ish) + w_mu = m.weight_mu.detach().abs() + b_mu = m.bias_mu.detach().abs() + mu_all = torch.cat([w_mu.flatten(), b_mu.flatten()]).clamp(min=1e-12) + sigma_mu_ratios.append((sigma_all.mean() / mu_all.mean()).item()) + + # actual injected noise magnitude right now (depends on epsilon buffers) + w_noise = (m.weight_sigma * m.weight_epsilon).detach() + b_noise = (m.bias_sigma * m.bias_epsilon).detach() + weight_noise_rms.append(w_noise.pow(2).mean().sqrt().item()) + bias_noise_rms.append(b_noise.pow(2).mean().sqrt().item()) + + # optional per-layer (sometimes too noisy for logs) + layer_stats[f"{key}_noisy_layer_{layer_i}_sigma_mean"] = float( + sigma_all.mean().item() + ) + layer_stats[f"{key}_noisy_layer_{layer_i}_sigma_mu_ratio"] = float( + sigma_mu_ratios[-1] + ) + layer_stats[f"{key}_noisy_layer_{layer_i}_weight_noise_rms"] = float( + weight_noise_rms[-1] + ) + layer_i += 1 + + if not sigma_vals: + return {} + + sigma_all = torch.cat([v.flatten() for v in sigma_vals]) + out: dict[str, Any] = { + f"{key}_noisy_sigma_mean": float(sigma_all.mean().item()), + f"{key}_noisy_sigma_std": float(sigma_all.std().item()), + f"{key}_noisy_sigma_max": float(max(sigma_max_vals)), + f"{key}_noisy_sigma_mu_ratio_mean": float( + sum(sigma_mu_ratios) / len(sigma_mu_ratios) + ), + f"{key}_noisy_weight_noise_rms_mean": float( + sum(weight_noise_rms) / len(weight_noise_rms) + ), + f"{key}_noisy_bias_noise_rms_mean": float( + sum(bias_noise_rms) / len(bias_noise_rms) + ), + } + + # enable if you want layer-by-layer drilldown + out.update(layer_stats) + + return out + class BaseDQN(BaseNetwork): def __init__( diff --git a/cares_reinforcement_learning/networks/NoisyNet/network.py b/cares_reinforcement_learning/networks/NoisyNet/network.py index 5cb2bb6d..42375e33 100644 --- a/cares_reinforcement_learning/networks/NoisyNet/network.py +++ b/cares_reinforcement_learning/networks/NoisyNet/network.py @@ -1,3 +1,5 @@ +from typing import Any + import torch from torch import nn @@ -17,9 +19,12 @@ def forward(self, state: torch.Tensor) -> torch.Tensor: def reset_noise(self): for module in self.network.modules(): - if hasattr(module, "reset_noise"): + if isinstance(module, NoisyLinear): module.reset_noise() + def noise_stats(self) -> dict[str, Any]: + return self._module_noise_stats("network", self.network) + class DefaultNetwork(BaseNoisyNetwork): def __init__(self, observation_size: int, num_actions: int): diff --git a/cares_reinforcement_learning/networks/Rainbow/network.py b/cares_reinforcement_learning/networks/Rainbow/network.py index 88b6bad1..8527184d 100644 --- a/cares_reinforcement_learning/networks/Rainbow/network.py +++ b/cares_reinforcement_learning/networks/Rainbow/network.py @@ -1,3 +1,5 @@ +from typing import Any + import torch from torch import nn @@ -61,17 +63,26 @@ def dist(self, state: torch.Tensor) -> torch.Tensor: def reset_noise(self): for module in self.feature_layer.modules(): - 
diff --git a/cares_reinforcement_learning/networks/Rainbow/network.py b/cares_reinforcement_learning/networks/Rainbow/network.py
index 88b6bad1..8527184d 100644
--- a/cares_reinforcement_learning/networks/Rainbow/network.py
+++ b/cares_reinforcement_learning/networks/Rainbow/network.py
@@ -1,3 +1,5 @@
+from typing import Any
+
 import torch
 from torch import nn
 
@@ -61,17 +63,26 @@ def dist(self, state: torch.Tensor) -> torch.Tensor:
 
     def reset_noise(self):
         for module in self.feature_layer.modules():
-            if hasattr(module, "reset_noise"):
+            if isinstance(module, NoisyLinear):
                 module.reset_noise()
 
         for module in self.value_stream.modules():
-            if hasattr(module, "reset_noise"):
+            if isinstance(module, NoisyLinear):
                 module.reset_noise()
 
         for module in self.advantage_stream.modules():
-            if hasattr(module, "reset_noise"):
+            if isinstance(module, NoisyLinear):
                 module.reset_noise()
 
+    def noise_stats(self) -> dict[str, Any]:
+        stats = {}
+        stats.update(self._module_noise_stats("feature_layer", self.feature_layer))
+        stats.update(self._module_noise_stats("value_stream", self.value_stream))
+        stats.update(
+            self._module_noise_stats("advantage_stream", self.advantage_stream)
+        )
+        return stats
+
 
 # This is the default base network for DQN for reference and testing of default network configurations
 class DefaultNetwork(BaseRainbowNetwork):
diff --git a/cares_reinforcement_learning/util/configurations.py b/cares_reinforcement_learning/util/configurations.py
index f9d76495..d0de7221 100644
--- a/cares_reinforcement_learning/util/configurations.py
+++ b/cares_reinforcement_learning/util/configurations.py
@@ -447,7 +447,7 @@ class MAPPOConfig(PPOConfig):
     target_kl: float | None = 0.05
 
     entropy_start: float = 0.01
-    entropy_end: float = 0.0
+    entropy_end: float = 1e-4
     entropy_decay: int = 500000
 
     max_grad_norm: float | None = 0.5
@@ -935,6 +935,11 @@ class DDPGConfig(AlgorithmConfig):
     actor_lr: float = 1e-4
     critic_lr: float = 1e-3
 
+    # Exploration noise
+    action_noise_start: float = 0.2
+    action_noise_end: float = 0.05
+    action_noise_decay: int = 1000000
+
     gamma: float = 0.99
     tau: float = 0.005
 
@@ -1046,16 +1051,16 @@ class TD3Config(AlgorithmConfig):
     tau: float = 0.005
 
     # Exploration noise
-    min_action_noise: float = 0.1
-    action_noise: float = 0.1
-    action_noise_decay: float = 1.0
+    action_noise_start: float = 0.1
+    action_noise_end: float = 0.1
+    action_noise_decay: int = 1
 
     # Target policy smoothing
-    policy_noise_clip: float = 0.5
+    policy_noise_start: float = 0.2
+    policy_noise_end: float = 0.2
+    policy_noise_decay: int = 1
 
-    min_policy_noise: float = 0.2
-    policy_noise: float = 0.2
-    policy_noise_decay: float = 1.0
+    policy_noise_clip: float = 0.5
 
     policy_update_freq: int = 2
 
@@ -1343,17 +1348,27 @@ class CTD4Config(TD3Config):
     tau: float = 0.005
     ensemble_size: int = 3
 
-    min_action_noise: float = 0.0
-    action_noise: float = 0.1
-    action_noise_decay: float = 0.999999
+    # Exploration noise
+    action_noise_start: float = 0.1
+    action_noise_end: float = 0.02
+    action_noise_decay: int = 1000000
 
-    min_policy_noise: float = 0.0
-    policy_noise: float = 0.2
-    policy_noise_decay: float = 0.999999
+    # Policy smoothing
+    policy_noise_start: float = 0.2
+    policy_noise_end: float = 0.07
+    policy_noise_decay: int = 1000000
 
     policy_update_freq: int = 2
 
-    fusion_method: str = "kalman"  # kalman, minimum, average
+    fusion_method: str = (
+        "interpolated"  # precision, interpolated, covariance, correlated, minimum, average
+    )
+
+    kalman_beta_start: float = 0.1
+    kalman_beta_end: float = 1.0
+    kalman_beta_decay: int = 800000
+
+    kalman_rho: float = 0.8
 
 
 class TD7Config(TD3Config):
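For context (not part of the patch): the renamed noise fields (start/end/decay triples) are shaped for the LinearScheduler added below in util/helpers.py. A minimal sketch of that pairing, using the CTD4Config exploration-noise defaults from this hunk; how each agent actually consumes these fields is outside this hunk, so the wiring shown is an assumption.

    # Illustrative sketch: driving exploration noise from the new config fields.
    from cares_reinforcement_learning.util.helpers import LinearScheduler

    action_noise_scheduler = LinearScheduler(
        start_value=0.1,        # action_noise_start (CTD4Config default)
        end_value=0.02,         # action_noise_end (CTD4Config default)
        decay_steps=1_000_000,  # action_noise_decay (CTD4Config default)
    )

    # Noise scale at selected training steps: 0.1 at step 0, linear decay, clamped at 0.02.
    for step in (0, 500_000, 1_000_000, 2_000_000):
        print(step, action_noise_scheduler.get_value(step))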
diff --git a/cares_reinforcement_learning/util/helpers.py b/cares_reinforcement_learning/util/helpers.py
index 2de4cda0..69b49d13 100644
--- a/cares_reinforcement_learning/util/helpers.py
+++ b/cares_reinforcement_learning/util/helpers.py
@@ -6,21 +6,35 @@
 import torch
 
 
-class EpsilonScheduler:
-    def __init__(self, start_epsilon: float, end_epsilon: float, decay_steps: int):
-        self.start_epsilon = start_epsilon
-        self.end_epsilon = end_epsilon
+class LinearScheduler:
+    def __init__(self, start_value: float, end_value: float, decay_steps: int):
+        self.start_value = start_value
+        self.end_value = end_value
         self.decay_steps = decay_steps
-        self.epsilon = start_epsilon
+        self.value = start_value
 
-    def get_epsilon(self, step: int) -> float:
+    def get_value(self, step: int) -> float:
         if step < self.decay_steps:
-            self.epsilon = self.start_epsilon - (
-                self.start_epsilon - self.end_epsilon
-            ) * (step / self.decay_steps)
+            self.value = self.start_value - (self.start_value - self.end_value) * (
+                step / self.decay_steps
+            )
         else:
-            self.epsilon = self.end_epsilon
-        return self.epsilon
+            self.value = self.end_value
+        return self.value
+
+
+class ExponentialScheduler:
+    def __init__(self, start_value: float, end_value: float, decay_steps: int):
+        self.start = start_value
+        self.end = end_value
+        self.decay_steps = decay_steps
+        if decay_steps == 0:
+            self.gamma = 0.0
+        else:
+            self.gamma = (end_value / start_value) ** (1.0 / decay_steps)
+
+    def get_value(self, step: int) -> float:
+        return max(self.end, self.start * (self.gamma**step))
 
 
 def get_device() -> torch.device:
diff --git a/cares_reinforcement_learning/util/plotter.py b/cares_reinforcement_learning/util/plotter.py
index 19a79dbd..c525dabc 100644
--- a/cares_reinforcement_learning/util/plotter.py
+++ b/cares_reinforcement_learning/util/plotter.py
@@ -530,8 +530,6 @@ def _read_data(
 def plot_evaluations():
     args = parse_args()
 
-    file_name = args["file_name"] if args["file_name"] is not None else args["title"]
-
     x_train = args["x_train"]
     y_train = args["y_train"]
     y_train_two = args["y_train_two"]
@@ -588,6 +586,7 @@ def plot_evaluations():
 
         result_directories = [x for x in directory if x.is_dir()]
 
         title, algorithm, task, label = generate_labels(args, title, model_path)
+        file_name = args["file_name"] if args["file_name"] is not None else title
 
         labels.append(label)
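For context (not part of the patch): a minimal sketch contrasting the two schedulers added in util/helpers.py on the same settings; the start/end/decay values below are chosen arbitrarily for the example.

    # Illustrative comparison of the two scheduler classes from util/helpers.py.
    from cares_reinforcement_learning.util.helpers import ExponentialScheduler, LinearScheduler

    linear = LinearScheduler(start_value=0.2, end_value=0.05, decay_steps=100)
    exponential = ExponentialScheduler(start_value=0.2, end_value=0.05, decay_steps=100)

    for step in (0, 25, 50, 100, 200):
        print(step, round(linear.get_value(step), 4), round(exponential.get_value(step), 4))

    # LinearScheduler drops by a constant amount per step and clamps at end_value once
    # step >= decay_steps; ExponentialScheduler drops by a constant factor per step
    # (gamma = (end/start) ** (1/decay_steps)) and is floored at end_value by the max().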