Commit a47ace9

removed instance comparisons in stochastic and exploration policies
1 parent c4a6d81 commit a47ace9

File tree

mighty/mighty_agents/sac.py
mighty/mighty_exploration/mighty_exploration_policy.py
mighty/mighty_exploration/stochastic_policy.py

3 files changed: +21 -23 lines

mighty/mighty_agents/sac.py

Lines changed: 7 additions & 5 deletions
@@ -145,7 +145,7 @@ def _initialize_agent(self) -> None:
 
         # Exploration policy wrapper
         self.policy = self.policy_class(
-            algo=self, model=self.model, **self.policy_kwargs
+            algo="sac", model=self.model, **self.policy_kwargs
         )
 
         # Updater
@@ -207,11 +207,13 @@ def process_transition(
         # Ensure metrics dict
         if metrics is None:
             metrics = {}
-
-        # Pack transition
+
+        # Pack transition
         terminated = metrics["transition"]["terminated"] # physics-failures
-        transition = TransitionBatch(curr_s, action, reward, next_s, terminated.astype(int))
-
+        transition = TransitionBatch(
+            curr_s, action, reward, next_s, terminated.astype(int)
+        )
+
         # Compute per-transition TD errors for logging
         td1, td2 = self.update_fn.calculate_td_error(transition)
         metrics["td_error1"] = td1.detach().cpu().numpy()

mighty/mighty_exploration/mighty_exploration_policy.py

Lines changed: 3 additions & 6 deletions
@@ -12,10 +12,7 @@
 
 
 def sample_nondeterministic_logprobs(
-    z: torch.Tensor,
-    mean: torch.Tensor,
-    log_std: torch.Tensor,
-    sac: bool = False
+    z: torch.Tensor, mean: torch.Tensor, log_std: torch.Tensor, sac: bool = False
 ) -> torch.Tensor:
     """
     Compute log-prob of a Gaussian sample z ~ N(mean, exp(log_std)),
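
The helper's body is not shown in this hunk. As a reference point only, the standard log-probability the docstring describes, for a tanh-squashed Gaussian with SAC's change-of-variables correction, can be computed as below; whether the sac flag gates the correction exactly this way is an assumption:

import math

import torch
import torch.nn.functional as F

def squashed_gaussian_log_prob(z, mean, log_std, sac=False):
    # Log-density of the pre-squash sample z under N(mean, exp(log_std)),
    # summed over the action dimensions.
    log_prob = torch.distributions.Normal(mean, log_std.exp()).log_prob(z).sum(-1)
    if sac:
        # Tanh correction for a = tanh(z): sum(log(1 - tanh(z)^2)) in its
        # numerically stable softplus form.
        log_prob -= (2.0 * (math.log(2.0) - z - F.softplus(-2.0 * z))).sum(-1)
    return log_prob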
@@ -115,9 +112,9 @@ def sample_func_logits(self, state_array):
 
         # ─── Continuous squashed-Gaussian (4-tuple) ──────────────────────────
         elif isinstance(out, tuple) and len(out) == 4:
-            action = out[0] # [batch, action_dim]
+            action = out[0]  # [batch, action_dim]
             log_prob = sample_nondeterministic_logprobs(
-                z=out[1], mean=out[2], log_std=out[3], sac= self.algo == "sac"
+                z=out[1], mean=out[2], log_std=out[3], sac=self.algo == "sac"
             )
             return action.detach().cpu().numpy(), log_prob
 

mighty/mighty_exploration/stochastic_policy.py

Lines changed: 11 additions & 12 deletions
@@ -27,13 +27,12 @@ def __init__(
         :param entropy_coefficient: weight on entropy term
         :param discrete: whether the action space is discrete
         """
-
+
         self.model = model
-
+
         super().__init__(algo, model, discrete)
         self.entropy_coefficient = entropy_coefficient
         self.discrete = discrete
-
 
         # --- override sample_action only for continuous SAC ---
         if not discrete and isinstance(model, SACModel):
@@ -88,9 +87,9 @@ def explore(self, s, return_logp, metrics=None) -> Tuple[np.ndarray, torch.Tensor]:
             # 4-tuple case (Tanh squashing): (action, z, mean, log_std)
             elif isinstance(model_output, tuple) and len(model_output) == 4:
                 action, z, mean, log_std = model_output
-
-                if not isinstance(self.model, SACModel):
-
+
+                if not self.algo == "sac":
+
                     log_prob = sample_nondeterministic_logprobs(
                         z=z,
                         mean=mean,
@@ -121,8 +120,8 @@ def explore(self, s, return_logp, metrics=None) -> Tuple[np.ndarray, torch.Tensor]:
                 elif len(model_output) == 4:
                     # Tanh squashing mode: (action, z, mean, log_std)
                     action, z, mean, log_std = model_output
-                    if not isinstance(self.model, SACModel):
-
+                    if not self.algo == "sac":
+
                         log_prob = sample_nondeterministic_logprobs(
                             z=z,
                             mean=mean,
@@ -147,7 +146,7 @@ def explore(self, s, return_logp, metrics=None) -> Tuple[np.ndarray, torch.Tensor]:
             if self.model.output_style == "squashed_gaussian":
                 # Should be 4-tuple: (action, z, mean, log_std)
                 action, z, mean, log_std = model_output
-                if not isinstance(self.model, SACModel):
+                if not self.algo == "sac":
                     log_prob = sample_nondeterministic_logprobs(
                         z=z,
                         mean=mean,
@@ -170,7 +169,7 @@ def explore(self, s, return_logp, metrics=None) -> Tuple[np.ndarray, torch.Tensor]:
             z = dist.rsample()
             action = torch.tanh(z)
 
-            if not isinstance(self.model, SACModel):
+            if not self.algo == "sac":
                 log_prob = sample_nondeterministic_logprobs(
                     z=z,
                     mean=mean,
@@ -179,7 +178,7 @@ def explore(self, s, return_logp, metrics=None) -> Tuple[np.ndarray, torch.Tensor]:
                 )
             else:
                 log_prob = self.model.policy_log_prob(z, mean, log_std)
-
+
             entropy = dist.entropy().sum(dim=-1, keepdim=True)
             weighted_log_prob = log_prob * entropy
             return action.detach().cpu().numpy(), weighted_log_prob
@@ -190,7 +189,7 @@ def explore(self, s, return_logp, metrics=None) -> Tuple[np.ndarray, torch.Tensor]:
             )
 
         # Special handling for SACModel
-        elif isinstance(self.model, SACModel):
+        elif self.algo == "sac" and isinstance(self.model, SACModel):
             action, z, mean, log_std = self.model(state, deterministic=False)
             # CRITICAL: Use the model's policy_log_prob which includes tanh correction
             log_prob = self.model.policy_log_prob(z, mean, log_std)
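
Taken together, every hunk in this file swaps the same instance comparison for a string comparison at the log-prob branch, and the final hunk keeps isinstance(self.model, SACModel) only as a guard next to the tag. Condensed, and illustrative rather than a verbatim excerpt:

if not self.algo == "sac":
    # Generic path: the shared helper computes the Gaussian log-prob.
    log_prob = sample_nondeterministic_logprobs(z=z, mean=mean, log_std=log_std)
else:
    # SAC path: the model's policy_log_prob already includes the tanh correction.
    log_prob = self.model.policy_log_prob(z, mean, log_std)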
