
Commit 9edb039

Merge pull request #96 from automl/sac_fix
Sac fix
2 parents 9d283a0 + a47ace9 commit 9edb039

File tree

9 files changed: +372 additions, -149 deletions


mighty/configs/algorithm/sac.yaml

Lines changed: 21 additions & 14 deletions
@@ -7,43 +7,50 @@ algorithm_kwargs:
 # Normalization
 normalize_obs: False
 normalize_reward: False
+rescale_action: True # CRITICAL: Add this! Must be True for MuJoCo

 # Network sizes
-n_policy_units: 256
-soft_update_weight: 0.005
+n_policy_units: 256
+soft_update_weight: 0.005 # tau in SAC terms

 # Replay buffer
 replay_buffer_class:
   _target_: mighty.mighty_replay.MightyReplay
 replay_buffer_kwargs:
   capacity: 1e6

+
 # Scheduling & batch-updates
-batch_size: 256
-learning_starts: 5000
-update_every: 1
-n_gradient_steps: 1
+batch_size: 256
+learning_starts: 5000 # Good, matches CleanRL
+update_every: 1 # Good, update every step
+n_gradient_steps: 1 # Good

 # Learning rates
 policy_lr: 3e-4
-q_lr: 1e-3
-alpha_lr: 1e-3
+q_lr: 1e-3 # This is correct now (was 3e-4)
+alpha_lr: 3e-4 # 3e-4 is better than 1e-3 for alpha

 # SAC hyperparameters
 gamma: 0.99
 alpha: 0.2
 auto_alpha: True
-target_entropy: -6.0 # -action_dim for HalfCheetah (6 actions)
+target_entropy: null # Let it auto-compute as -action_dim
+
+# Network architecture
+hidden_sizes: [256, 256] # Explicitly specify
+activation: relu
+log_std_min: -5
+log_std_max: 2

 # Policy configuration
 policy_class: mighty.mighty_exploration.StochasticPolicy
 policy_kwargs:
-  entropy_coefficient: 0.0
   discrete: False
-
+  # Remove entropy_coefficient - SAC handles alpha internally

 # SAC specific frequencies
-policy_frequency: 2 # Delayed policy updates
+policy_frequency: 2 # Can also try 1 for even better performance
 target_network_frequency: 1 # Update targets every step

 # Environment and training configuration
@@ -55,5 +62,5 @@ max_episode_steps: 1000 # HalfCheetah episode length
 eval_frequency: 10000 # More frequent eval for single env
 save_frequency: 50000 # Save every 50k steps

-
-# python mighty/run_mighty.py algorithm=sac env=HalfCheetah-v4 num_steps=1e6 num_envs=1
+# Command to run:
+# python mighty/run_mighty.py algorithm=sac env=HalfCheetah-v4 num_steps=1e6 num_envs=1
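The `target_entropy: null` change above relies on the common SAC convention of falling back to the negative action dimensionality. A minimal sketch of that convention (assuming a Gymnasium `Box` action space with `gymnasium[mujoco]` installed; illustrative only, not Mighty code):

```python
import numpy as np
import gymnasium as gym

# "-action_dim" convention referenced by the target_entropy comment:
# HalfCheetah-v4 has a 6-dimensional Box action space, so the auto-computed
# target entropy is -6.0 (the hard-coded value removed in this commit).
env = gym.make("HalfCheetah-v4")
action_dim = int(np.prod(env.action_space.shape))
target_entropy = -float(action_dim)
print(f"target_entropy = {target_entropy}")  # -6.0
```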

mighty/mighty_agents/base_agent.py

Lines changed: 16 additions & 4 deletions
@@ -603,9 +603,21 @@ def run( # noqa: PLR0915
 metrics["episode_reward"] = episode_reward

 action, log_prob = self.step(curr_s, metrics)
-next_s, reward, terminated, truncated, _ = self.env.step(action)  # type: ignore
-dones = np.logical_or(terminated, truncated)
+# step the env as usual
+next_s, reward, terminated, truncated, infos = self.env.step(action)

+# decide which samples are true "done"
+replay_dones = terminated  # physics-failure only
+dones = np.logical_or(terminated, truncated)
+
+
+# Overwrite next_s on truncation
+# Based on https://github.com/DLR-RM/stable-baselines3/issues/284
+real_next_s = next_s.copy()
+# infos["final_observation"] is a list/array of the last real obs
+for i, tr in enumerate(truncated):
+    if tr:
+        real_next_s[i] = infos["final_observation"][i]
 episode_reward += reward

 # Log everything
@@ -615,10 +627,10 @@ def run( # noqa: PLR0915
 "reward": reward,
 "action": action,
 "state": curr_s,
-"next_state": next_s,
+"next_state": real_next_s,
 "terminated": terminated.astype(int),
 "truncated": truncated.astype(int),
-"dones": dones.astype(int),
+"dones": replay_dones.astype(int),
 "mean_episode_reward": last_episode_reward.mean()
 .cpu()
 .numpy()
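The change above distinguishes real terminations from time-limit truncations: only `terminated` is stored as the replay `dones`, and on truncation the stored next state is taken from `infos["final_observation"]` so the critic can still bootstrap past the time limit. A rough, self-contained sketch of the same bookkeeping (assuming a Gymnasium < 1.0 vector env, which exposes `final_observation` in `infos`; the buffer call is a placeholder, not Mighty's API):

```python
import gymnasium as gym

# Toy reproduction of the terminated/truncated handling from the diff above.
# max_episode_steps=10 forces truncations quickly so the branch is exercised.
envs = gym.vector.SyncVectorEnv(
    [lambda: gym.make("CartPole-v1", max_episode_steps=10) for _ in range(2)]
)
obs, _ = envs.reset(seed=0)
for _ in range(50):
    actions = envs.action_space.sample()
    next_obs, rewards, terminated, truncated, infos = envs.step(actions)

    # Only real terminations count as "done" for the TD target.
    replay_dones = terminated
    # On truncation, replace the autoreset observation with the true last one.
    real_next_obs = next_obs.copy()
    for i, tr in enumerate(truncated):
        if tr:
            real_next_obs[i] = infos["final_observation"][i]

    # replay_buffer.add(obs, actions, rewards, real_next_obs, replay_dones)  # placeholder
    obs = next_obs
```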

mighty/mighty_agents/sac.py

Lines changed: 8 additions & 3 deletions
@@ -38,7 +38,7 @@ def __init__(
 # --- Network architecture (optional override) ---
 hidden_sizes: Optional[List[int]] = None,
 activation: str = "relu",
-log_std_min: float = -20,
+log_std_min: float = -5,
 log_std_max: float = 2,
 # --- Logging & buffer ---
 render_progress: bool = True,
@@ -145,7 +145,7 @@ def _initialize_agent(self) -> None:

 # Exploration policy wrapper
 self.policy = self.policy_class(
-    algo=self, model=self.model, **self.policy_kwargs
+    algo="sac", model=self.model, **self.policy_kwargs
 )

 # Updater
@@ -207,8 +207,13 @@ def process_transition(
 # Ensure metrics dict
 if metrics is None:
     metrics = {}
+
 # Pack transition
-transition = TransitionBatch(curr_s, action, reward, next_s, dones)
+terminated = metrics["transition"]["terminated"]  # physics-failures
+transition = TransitionBatch(
+    curr_s, action, reward, next_s, terminated.astype(int)
+)
+
 # Compute per-transition TD errors for logging
 td1, td2 = self.update_fn.calculate_td_error(transition)
 metrics["td_error1"] = td1.detach().cpu().numpy()
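Packing `terminated` rather than `terminated or truncated` into the `TransitionBatch` matters because the standard SAC critic target only zeroes the bootstrap term when the episode truly ended. A small numeric illustration with placeholder tensors (the variable names are made up, not Mighty's API):

```python
import torch

gamma, alpha = 0.99, 0.2
reward = torch.tensor([1.0, 1.0])
done = torch.tensor([0.0, 1.0])            # terminated-only flags, as stored above
q1_next = torch.tensor([10.0, 10.0])       # placeholder target-critic values
q2_next = torch.tensor([11.0, 11.0])
next_log_prob = torch.tensor([-3.0, -3.0])

# Standard SAC target: y = r + gamma * (1 - done) * (min(Q1', Q2') - alpha * log_pi)
soft_value = torch.min(q1_next, q2_next) - alpha * next_log_prob
y = reward + gamma * (1.0 - done) * soft_value
print(y)  # flagging a truncated step as done would wrongly drop the bootstrap term
```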

mighty/mighty_exploration/mighty_exploration_policy.py

Lines changed: 16 additions & 14 deletions
@@ -8,29 +8,31 @@
 import torch
 from torch.distributions import Categorical, Normal

+from mighty.mighty_models import SACModel
+

 def sample_nondeterministic_logprobs(
     z: torch.Tensor, mean: torch.Tensor, log_std: torch.Tensor, sac: bool = False
-) -> Tuple[torch.Tensor, torch.Tensor]:
+) -> torch.Tensor:
+    """
+    Compute log-prob of a Gaussian sample z ~ N(mean, exp(log_std)),
+    and if sac=True apply the tanh-squash correction to get log π(a).
+    """
     std = torch.exp(log_std)  # [batch, action_dim]
     dist = Normal(mean, std)
+    # base Gaussian log-prob of z
+    log_pz = dist.log_prob(z).sum(dim=-1, keepdim=True)  # [batch, 1]

-    # For SAC, don't apply correction
     if sac:
-        return dist.log_prob(z).sum(dim=-1, keepdim=True)  # [batch, 1]
-    # If not SAC, we need to apply the tanh correction
-    else:
-        log_pz = dist.log_prob(z).sum(dim=-1, keepdim=True)  # [batch, 1]
-
-        # 2b) tanh-correction = ∑ᵢ log(1 − tanh(zᵢ)² + ε)
-        eps = 1e-6
+        # subtract the ∑_i log(d tanh/dz_i) = ∑ log(1 - tanh(z)^2)
+        eps = 1e-4
         log_correction = torch.log(1.0 - torch.tanh(z).pow(2) + eps).sum(
             dim=-1, keepdim=True
         )  # [batch, 1]
-
-        # 2c) final log_prob of a = tanh(z)
-        log_prob = log_pz - log_correction  # [batch, 1]
-        return log_prob
+        return log_pz - log_correction
+    else:
+        # PPO-style or other: no squash correction
+        return log_pz


 class MightyExplorationPolicy:
@@ -112,7 +114,7 @@ def sample_func_logits(self, state_array):
 elif isinstance(out, tuple) and len(out) == 4:
     action = out[0]  # [batch, action_dim]
     log_prob = sample_nondeterministic_logprobs(
-        z=out[1], mean=out[2], log_std=out[3], sac=self.algo == "sac"
+        z=out[1], mean=out[2], log_std=out[3], sac=self.ago == "sac"
     )
     return action.detach().cpu().numpy(), log_prob

mighty/mighty_exploration/stochastic_policy.py

Lines changed: 47 additions & 36 deletions
@@ -27,6 +27,9 @@ def __init__(
 :param entropy_coefficient: weight on entropy term
 :param discrete: whether the action space is discrete
 """
+
+self.model = model
+
 super().__init__(algo, model, discrete)
 self.entropy_coefficient = entropy_coefficient
 self.discrete = discrete
@@ -84,33 +87,24 @@ def explore(self, s, return_logp, metrics=None) -> Tuple[np.ndarray, torch.Tenso
 # 4-tuple case (Tanh squashing): (action, z, mean, log_std)
 elif isinstance(model_output, tuple) and len(model_output) == 4:
     action, z, mean, log_std = model_output
-    log_prob = sample_nondeterministic_logprobs(
-        z=z,
-        mean=mean,
-        log_std=log_std,
-        sac=self.algo == "sac",
-    )
+
+    if not self.algo == "sac":
+
+        log_prob = sample_nondeterministic_logprobs(
+            z=z,
+            mean=mean,
+            log_std=log_std,
+            sac=False,
+        )
+    else:
+        log_prob = self.model.policy_log_prob(z, mean, log_std)

     if return_logp:
         return action.detach().cpu().numpy(), log_prob
     else:
         weighted_log_prob = log_prob * self.entropy_coefficient
         return action.detach().cpu().numpy(), weighted_log_prob

-# Legacy 2-tuple case: (mean, std)
-elif isinstance(model_output, tuple) and len(model_output) == 2:
-    mean, std = model_output
-    dist = Normal(mean, std)
-    z = dist.rsample()  # [batch, action_dim]
-    action = torch.tanh(z)  # [batch, action_dim]
-
-    log_prob = sample_nondeterministic_logprobs(
-        z=z, mean=mean, log_std=torch.log(std), sac=self.algo == "sac"
-    )
-    entropy = dist.entropy().sum(dim=-1, keepdim=True)  # [batch, 1]
-    weighted_log_prob = log_prob * entropy
-    return action.detach().cpu().numpy(), weighted_log_prob
-
 # Check for model attribute-based approaches
 elif hasattr(self.model, "continuous_action") and getattr(
     self.model, "continuous_action"
@@ -126,9 +120,16 @@ def explore(self, s, return_logp, metrics=None) -> Tuple[np.ndarray, torch.Tenso
 elif len(model_output) == 4:
     # Tanh squashing mode: (action, z, mean, log_std)
     action, z, mean, log_std = model_output
-    log_prob = sample_nondeterministic_logprobs(
-        z=z, mean=mean, log_std=log_std, sac=self.algo == "sac"
-    )
+    if not self.algo == "sac":
+
+        log_prob = sample_nondeterministic_logprobs(
+            z=z,
+            mean=mean,
+            log_std=log_std,
+            sac=False,
+        )
+    else:
+        log_prob = self.model.policy_log_prob(z, mean, log_std)
 else:
     raise ValueError(
         f"Unexpected model output length: {len(model_output)}"
@@ -145,9 +146,15 @@ def explore(self, s, return_logp, metrics=None) -> Tuple[np.ndarray, torch.Tenso
 if self.model.output_style == "squashed_gaussian":
     # Should be 4-tuple: (action, z, mean, log_std)
     action, z, mean, log_std = model_output
-    log_prob = sample_nondeterministic_logprobs(
-        z=z, mean=mean, log_std=log_std, sac=self.algo == "sac"
-    )
+    if not self.algo == "sac":
+        log_prob = sample_nondeterministic_logprobs(
+            z=z,
+            mean=mean,
+            log_std=log_std,
+            sac=False,
+        )
+    else:
+        log_prob = self.model.policy_log_prob(z, mean, log_std)

     if return_logp:
         return action.detach().cpu().numpy(), log_prob
@@ -162,9 +169,16 @@ def explore(self, s, return_logp, metrics=None) -> Tuple[np.ndarray, torch.Tenso
 z = dist.rsample()
 action = torch.tanh(z)

-log_prob = sample_nondeterministic_logprobs(
-    z=z, mean=mean, log_std=torch.log(std), sac=self.algo == "sac"
-)
+if not self.algo == "sac":
+    log_prob = sample_nondeterministic_logprobs(
+        z=z,
+        mean=mean,
+        log_std=log_std,
+        sac=False,
+    )
+else:
+    log_prob = self.model.policy_log_prob(z, mean, log_std)
+
 entropy = dist.entropy().sum(dim=-1, keepdim=True)
 weighted_log_prob = log_prob * entropy
 return action.detach().cpu().numpy(), weighted_log_prob
@@ -175,14 +189,11 @@ def explore(self, s, return_logp, metrics=None) -> Tuple[np.ndarray, torch.Tenso
 )

 # Special handling for SACModel
-elif isinstance(self.model, SACModel):
+elif self.algo == "sac" and isinstance(self.model, SACModel):
     action, z, mean, log_std = self.model(state, deterministic=False)
-    std = torch.exp(log_std)
-    dist = Normal(mean, std)
-
-    log_pz = dist.log_prob(z).sum(dim=-1, keepdim=True)
-    weighted_log_prob = log_pz * self.entropy_coefficient
-    return action.detach().cpu().numpy(), weighted_log_prob
+    # CRITICAL: Use the model's policy_log_prob which includes tanh correction
+    log_prob = self.model.policy_log_prob(z, mean, log_std)
+    return action.detach().cpu().numpy(), log_prob

 else:
     raise RuntimeError(
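Every SAC branch now defers to `SACModel.policy_log_prob`, whose body is not part of this diff. Judging from the correction in `sample_nondeterministic_logprobs`, it presumably returns the tanh-corrected log-probability of the squashed action; a hypothetical stand-alone equivalent (an assumption for illustration, not the actual Mighty implementation) could look like:

```python
import torch
from torch.distributions import Normal

def policy_log_prob(
    z: torch.Tensor, mean: torch.Tensor, log_std: torch.Tensor, eps: float = 1e-6
) -> torch.Tensor:
    """Hypothetical stand-in for SACModel.policy_log_prob (assumed behaviour)."""
    dist = Normal(mean, log_std.exp())
    # Gaussian log-prob of the pre-squash sample z
    log_pz = dist.log_prob(z).sum(dim=-1, keepdim=True)
    # Jacobian correction for a = tanh(z)
    log_correction = torch.log(1.0 - torch.tanh(z).pow(2) + eps).sum(dim=-1, keepdim=True)
    return log_pz - log_correction
```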

0 commit comments
