Merged
Changes from 10 commits
103 changes: 92 additions & 11 deletions cares_reinforcement_learning/algorithm/policy/DDPG.py
@@ -67,6 +67,7 @@
SARLObservationTensors,
)
from cares_reinforcement_learning.util.configurations import DDPGConfig
from cares_reinforcement_learning.util.helpers import ExponentialScheduler


class DDPG(SARLAlgorithm[np.ndarray]):
@@ -88,27 +89,47 @@ def __init__(
self.gamma = config.gamma
self.tau = config.tau

# Action noise
self.action_noise_scheduler = ExponentialScheduler(
start_value=config.action_noise_start,
end_value=config.action_noise_end,
decay_steps=config.action_noise_decay,
)
self.action_noise = self.action_noise_scheduler.get_value(0)

self.action_num = self.actor_net.num_actions

self.actor_net_optimiser = torch.optim.Adam(
self.actor_net.parameters(), lr=config.actor_lr
)
self.critic_net_optimiser = torch.optim.Adam(
self.critic_net.parameters(), lr=config.critic_lr
)

# TODO add action noise for exploration
self.learn_counter = 0

def act(
self, observation: SARLObservation, evaluation: bool = False
) -> ActionSample[np.ndarray]:
# pylint: disable-next=unused-argument
self.actor_net.eval()

state = observation.vector_state

self.actor_net.eval()
with torch.no_grad():
state_tensor = torch.FloatTensor(state).to(self.device)
state_tensor = state_tensor.unsqueeze(0)
action = self.actor_net(state_tensor)
action = action.cpu().data.numpy().flatten()
if not evaluation:
# this is part of DDPG too: add exploration noise to the action
noise = np.random.normal(
0, scale=self.action_noise, size=self.action_num
).astype(np.float32)
action = action + noise
action = np.clip(action, -1, 1)

self.actor_net.train()

return ActionSample(action=action, source="policy")

def _update_critic(
@@ -119,6 +140,8 @@ def _update_critic(
next_states: torch.Tensor,
dones: torch.Tensor,
) -> dict[str, Any]:
info: dict[str, Any] = {}

with torch.no_grad():
self.target_actor_net.eval()
next_actions = self.target_actor_net(next_states)
@@ -134,25 +157,79 @@
critic_loss.backward()
self.critic_net_optimiser.step()

info = {
"critic_loss": critic_loss.item(),
}
with torch.no_grad():
td = q_values - q_target

# --- Q statistics ---
info["q_mean"] = q_values.mean().item()
info["q_std"] = q_values.std().item()

# --- Bellman target scale (reward scaling / discount sanity) ---
# If q_target drifts upward without reward improvement, suspect reward_scale, gamma, or instability.
info["q_target_mean"] = q_target.mean().item()
info["q_target_std"] = q_target.std().item()

# --- TD error diagnostics (Bellman fit quality) ---
# A falling td_abs_mean over training is healthy; persistent growth or spikes often indicate critic instability.
info["td_mean"] = td.mean().item()
info["td_std"] = td.std().item()
info["td_abs_mean"] = td.abs().mean().item()

# --- Losses (optimization progress) ---
info["critic_loss"] = critic_loss.item()

return info

def _update_actor(self, states: torch.Tensor) -> dict[str, Any]:
info: dict[str, Any] = {}

self.critic_net.eval()
actions_pred = self.actor_net(states)
actor_q = self.critic_net(states, actions_pred)
actions = self.actor_net(states)
actor_q_values = self.critic_net(states, actions)
self.critic_net.train()

actor_loss = -actor_q.mean()
actor_loss = -actor_q_values.mean()

# ---------------------------------------------------------
# Deterministic Policy Gradient Strength (∇a Q(s,a))
# ---------------------------------------------------------
# Measures how steep the critic surface is w.r.t. actions.
# ~0 early -> critic flat, actor receives no learning signal.
# Very large -> critic overly sharp, can cause unstable actor updates.
dq_da = torch.autograd.grad(
outputs=-actor_q_values.mean(), # NOTE: uses Q-term only, excludes regularizers
inputs=actions,
retain_graph=True, # needed because actor_loss.backward() runs next
create_graph=False, # diagnostic only
allow_unused=False,
)[0]
with torch.no_grad():
info["dq_da_abs_mean"] = dq_da.abs().mean().item()
info["dq_da_norm_mean"] = dq_da.norm(dim=1).mean().item()
info["dq_da_norm_p95"] = dq_da.norm(dim=1).quantile(0.95).item()

self.actor_net_optimiser.zero_grad()
actor_loss.backward()
self.actor_net_optimiser.step()

info = {"actor_loss": actor_loss.item()}
with torch.no_grad():
# Policy Action Health (tanh policies in [-1, 1])
# pi_action_saturation_frac:
# High values (>0.8 early) often mean the actor is slamming bounds,
# reducing effective gradient flow through tanh.
info["pi_action_mean"] = actions.mean().item()
info["pi_action_std"] = actions.std().item()
info["pi_action_abs_mean"] = actions.abs().mean().item()
info["pi_action_saturation_frac"] = (
(actions.abs() > 0.95).float().mean().item()
)

# actor_q_mean should generally increase over training.
# A large or unstable actor_q_std may indicate critic inconsistency.
info["actor_loss"] = actor_loss.item()
info["actor_q_mean"] = actor_q_values.mean().item()
info["actor_q_std"] = actor_q_values.std().item()

return info

@@ -164,9 +241,13 @@ def update_from_batch(
next_observation_tensor: SARLObservationTensors,
dones_tensor: torch.Tensor,
) -> dict[str, Any]:
self.learn_counter += 1

info: dict[str, Any] = {}

# TODO add the action noise for exploration with episode context and some decay mechanism
self.action_noise = self.action_noise_scheduler.get_value(
episode_context.training_step
)

# Update Critic
critic_info = self._update_critic(
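The ∇a Q(s,a) diagnostic added above can be reproduced in isolation. A minimal sketch with a toy critic; the network, shapes, and values below are illustrative assumptions, not code from this repository:

import torch
import torch.nn as nn

# Toy critic: Q(s, a) -> one value per sample.
critic = nn.Sequential(nn.Linear(4 + 2, 32), nn.ReLU(), nn.Linear(32, 1))

states = torch.randn(128, 4)
actions = torch.randn(128, 2, requires_grad=True)  # differentiate w.r.t. actions

q_values = critic(torch.cat([states, actions], dim=1))

# Same pattern as the diff: gradient of the negated mean Q w.r.t. actions.
dq_da = torch.autograd.grad(
    outputs=-q_values.mean(),
    inputs=actions,
    retain_graph=False,  # standalone check, no later backward pass
)[0]

print("dq_da_abs_mean:", dq_da.abs().mean().item())
print("dq_da_norm_mean:", dq_da.norm(dim=1).mean().item())
print("dq_da_norm_p95:", dq_da.norm(dim=1).quantile(0.95).item())

Values near zero indicate a flat critic surface (no learning signal for the actor); very large norms suggest an overly sharp critic and potentially unstable actor updates, matching the inline comments above.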
6 changes: 6 additions & 0 deletions cares_reinforcement_learning/algorithm/policy/MADDPG.py
@@ -394,6 +394,12 @@ def train(
# ---------------------------------------------------------
# Update each agent
# ---------------------------------------------------------

# Update action noise for exploration (decayed over training)
current_agent.action_noise = current_agent.action_noise_scheduler.get_value(
episode_context.training_step
)

(
observation_tensor,
actions_tensor,
12 changes: 6 additions & 6 deletions cares_reinforcement_learning/algorithm/policy/MAPPO.py
@@ -71,7 +71,7 @@
SARLObservation,
)
from cares_reinforcement_learning.util.configurations import MAPPOConfig
from cares_reinforcement_learning.util.helpers import EpsilonScheduler
from cares_reinforcement_learning.util.helpers import ExponentialScheduler


class MAPPO(MARLAlgorithm[list[np.ndarray]]):
@@ -90,13 +90,13 @@ def __init__(
self.minibatch_size = config.minibatch_size
self.updates_per_iteration = config.updates_per_iteration

self.epsilon_scheduler = EpsilonScheduler(
start_epsilon=config.entropy_start,
end_epsilon=config.entropy_end,
self.entropy_scheduler = ExponentialScheduler(
start_value=config.entropy_start,
end_value=config.entropy_end,
decay_steps=config.entropy_decay,
)
# initial entropy coefficient
self.entropy_coef = self.epsilon_scheduler.get_epsilon(0)
self.entropy_coef = self.entropy_scheduler.get_value(0)
Comment on lines +93 to +99 (Copilot AI, Feb 25, 2026):

The variable naming is inconsistent. The local variable is called epsilon_scheduler, but it is now an ExponentialScheduler for entropy coefficients in MAPPO. Consider renaming it to entropy_scheduler for clarity, as the entropy coefficient is not epsilon in the traditional sense.
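For readers without the helpers module: ExponentialScheduler itself is not shown in this diff. A minimal sketch of the interface the new code calls (start_value, end_value, decay_steps, get_value(step)); the exact decay curve below is an assumption for illustration, not necessarily the repository's implementation:

import math

class ExponentialScheduler:
    """Decays a value exponentially from start_value toward end_value."""

    def __init__(self, start_value: float, end_value: float, decay_steps: int):
        self.start_value = start_value
        self.end_value = end_value
        self.decay_steps = decay_steps

    def get_value(self, step: int) -> float:
        # Fraction of the decay horizon elapsed, clamped to [0, 1].
        frac = min(max(step / self.decay_steps, 0.0), 1.0)
        # get_value(0) == start_value; get_value(decay_steps) sits within
        # ~1% of the start-end gap above end_value.
        return self.end_value + (self.start_value - self.end_value) * math.exp(-5.0 * frac)

Whatever the exact curve, the property the diff relies on is that get_value is stateless in step, so every caller can query it with episode_context.training_step and get a consistent value.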

self.target_kl = config.target_kl

@@ -150,7 +150,7 @@ def train(

info: dict[str, Any] = {}

self.entropy_coef = self.epsilon_scheduler.get_epsilon(
self.entropy_coef = self.entropy_scheduler.get_value(
episode_context.training_step
)

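For context on where the scheduled entropy_coef lands: this diff does not show the MAPPO loss itself, so the following is a generic PPO-style sketch with random placeholder tensors, not the repository's training code:

import torch

batch = 64
old_log_probs = torch.randn(batch)
new_log_probs = old_log_probs + 0.05 * torch.randn(batch)
advantages = torch.randn(batch)
entropy = torch.rand(batch)        # per-sample policy entropy
clip_eps = 0.2
entropy_coef = 0.01                # in MAPPO: self.entropy_scheduler.get_value(step)

ratio = torch.exp(new_log_probs - old_log_probs)
clipped = torch.clamp(ratio, 1.0 - clip_eps, 1.0 + clip_eps)
policy_loss = -torch.min(ratio * advantages, clipped * advantages).mean()

# Subtracting the entropy term rewards exploratory policies; decaying
# entropy_coef from entropy_start to entropy_end anneals that pressure.
loss = policy_loss - entropy_coef * entropy.mean()

The scheduler rename above matters precisely because this coefficient is an entropy weight, not an epsilon-greedy rate.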
17 changes: 12 additions & 5 deletions cares_reinforcement_learning/algorithm/policy/MATD3.py
@@ -68,7 +68,7 @@ def __init__(

self.policy_update_freq = config.policy_update_freq

self.policy_noise = config.policy_noise
self.policy_noise = config.policy_noise_end
self.policy_noise_clip = config.policy_noise_clip
Copilot AI, Feb 25, 2026:

MATD3 has a duplicate assignment: line 74 sets self.policy_noise_clip = config.policy_noise_clip, and then line 81 sets it again. Remove the duplicate assignment on line 81.

Suggested change (delete the duplicate line):
self.policy_noise_clip = config.policy_noise_clip

self.max_grad_norm = config.max_grad_norm
Comment on lines 71 to 83 (Copilot AI, Feb 17, 2026):

MATD3 updates per-agent policy_noise/action_noise via each agent’s schedulers, but target policy smoothing uses MATD3.self.policy_noise (see train(): noise = randn_like(...) * self.policy_noise). self.policy_noise is initialized from config.policy_noise_end and is never updated, so smoothing noise won’t follow the configured schedule (and will start at the end value). Consider either (1) adding a scheduler in MATD3 and updating self.policy_noise each training_step (mirroring TD3), or (2) using a value derived from the agents’ policy_noise (e.g., agent.policy_noise) so smoothing noise matches the decayed setting.
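A sketch of option (1) from the comment above: give the MATD3 wrapper its own scheduler and refresh self.policy_noise once per training step, mirroring the per-agent updates added in train() below. The config field names are assumed to follow the pattern used elsewhere in this PR:

# Hypothetical addition to MATD3.__init__:
self.policy_noise_scheduler = ExponentialScheduler(
    start_value=config.policy_noise_start,
    end_value=config.policy_noise_end,
    decay_steps=config.policy_noise_decay,
)
self.policy_noise = self.policy_noise_scheduler.get_value(0)

# Hypothetical addition to MATD3.train(), next to the per-agent noise updates:
self.policy_noise = self.policy_noise_scheduler.get_value(
    episode_context.training_step
)

This keeps target policy smoothing on the same decay schedule as the per-agent exploration noise.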
@@ -183,10 +183,7 @@ def _update_actor(
joint_actions_flat = actions_all.reshape(batch_size, -1)
q_val, _ = agent.critic_net(global_states, joint_actions_flat)

# regularization as in TF code
reg = (pred_action_i**2).mean() * 1e-3

actor_loss = -q_val.mean() + reg
actor_loss = -q_val.mean()

# ---------------------------------------------------------
# Step 5: Backprop
@@ -213,6 +210,16 @@ def train(

info: dict[str, Any] = {}

# Update action/policy noise for exploration (decayed over training)
for current_agent in self.agent_networks:
current_agent.policy_noise = current_agent.policy_noise_scheduler.get_value(
episode_context.training_step
)

current_agent.action_noise = current_agent.action_noise_scheduler.get_value(
episode_context.training_step
)

# ---------------------------------------------------------
# Sample ONCE for all agents (recommended for TD3/SAC)
# Shared minibatch: We draw one minibatch per training iteration and reuse it across agent updates.
30 changes: 19 additions & 11 deletions cares_reinforcement_learning/algorithm/policy/NaSATD3.py
@@ -106,6 +106,7 @@
SARLObservationTensors,
)
from cares_reinforcement_learning.util.configurations import NaSATD3Config
from cares_reinforcement_learning.util.helpers import ExponentialScheduler


class NaSATD3(SARLAlgorithm[np.ndarray]):
@@ -128,16 +129,21 @@ def __init__(
self.policy_update_freq = config.policy_update_freq

# Policy noise
self.min_policy_noise = config.min_policy_noise
self.policy_noise = config.policy_noise
self.policy_noise_decay = config.policy_noise_decay

self.policy_noise_clip = config.policy_noise_clip
self.policy_noise_scheduler = ExponentialScheduler(
start_value=config.policy_noise_start,
end_value=config.policy_noise_end,
decay_steps=config.policy_noise_decay,
)
self.policy_noise = self.policy_noise_scheduler.get_value(0)

# Action noise
self.min_action_noise = config.min_action_noise
self.action_noise = config.action_noise
self.action_noise_decay = config.action_noise_decay
self.action_noise_scheduler = ExponentialScheduler(
start_value=config.action_noise_start,
end_value=config.action_noise_end,
decay_steps=config.action_noise_decay,
)
self.action_noise = self.action_noise_scheduler.get_value(0)

# Doesn't matter which autoencoder is used, as long as it is the same for all networks
self.autoencoder: VanillaAutoencoder | BurgessAutoencoder = (
@@ -333,11 +339,13 @@ def train(

self.learn_counter += 1

self.policy_noise *= self.policy_noise_decay
self.policy_noise = max(self.min_policy_noise, self.policy_noise)
self.policy_noise = self.policy_noise_scheduler.get_value(
episode_context.training_step
)

self.action_noise *= self.action_noise_decay
self.action_noise = max(self.min_action_noise, self.action_noise)
self.action_noise = self.action_noise_scheduler.get_value(
episode_context.training_step
)

# Convert to tensors using multimodal batch conversion
(
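The NaSATD3 change swaps a stateful multiplicative decay (noise *= decay, floored at a minimum) for a stateless lookup keyed on training_step. A small sketch of the difference, reusing the assumed ExponentialScheduler sketched earlier; the numbers are illustrative:

# Old style: the current noise depends on how many times train() has run,
# so restored checkpoints or skipped updates drift off-schedule.
noise = 0.2
for _ in range(1_000):
    noise = max(0.05, noise * 0.995)

# New style: the value depends only on the reported training step, so any
# caller asking about step 1_000 gets the same answer.
scheduler = ExponentialScheduler(start_value=0.2, end_value=0.05, decay_steps=10_000)
noise_at_1000 = scheduler.get_value(1_000)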