
Commit c4a6d81

updates for Merge
1 parent ade9d40 commit c4a6d81

8 files changed (+161, -86 lines changed)


mighty/mighty_agents/base_agent.py
Lines changed: 10 additions & 16 deletions

@@ -141,7 +141,6 @@ def __init__(  # noqa: PLR0915, PLR0912
         normalize_obs: bool = False,
         normalize_reward: bool = False,
         rescale_action: bool = False,
-        handle_timeout_termination: bool = False,
     ):
         """Base agent initialization.

@@ -302,8 +301,6 @@ def __init__(  # noqa: PLR0915, PLR0912
        for m in self.meta_modules.values():
            m.seed(self.seed)
        self.steps = 0
-
-       self.handle_timeout_termination = handle_timeout_termination

    def _initialize_agent(self) -> None:
        """Agent/algorithm specific initializations."""

@@ -606,24 +603,21 @@ def run(  # noqa: PLR0915
            metrics["episode_reward"] = episode_reward

            action, log_prob = self.step(curr_s, metrics)
-           # 1) step the env as usual
+           # step the env as usual
            next_s, reward, terminated, truncated, infos = self.env.step(action)

-           # 2) decide which samples are true “done”
+           # decide which samples are true “done”
            replay_dones = terminated  # physics‐failure only
-           dones = np.logical_or(terminated, truncated)
+           dones = np.logical_or(terminated, truncated)


-           # 3) optionally overwrite next_s on truncation
-           if self.handle_timeout_termination:
-               real_next_s = next_s.copy()
-               # infos["final_observation"] is a list/array of the last real obs
-               for i, tr in enumerate(truncated):
-                   if tr:
-                       real_next_s[i] = infos["final_observation"][i]
-           else:
-               real_next_s = next_s
-
+           # Overwrite next_s on truncation
+           # Based on https://github.com/DLR-RM/stable-baselines3/issues/284
+           real_next_s = next_s.copy()
+           # infos["final_observation"] is a list/array of the last real obs
+           for i, tr in enumerate(truncated):
+               if tr:
+                   real_next_s[i] = infos["final_observation"][i]
            episode_reward += reward

            # Log everything
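
With the handle_timeout_termination flag removed, run() now always recovers the true final observation on time-limit truncation, following the reasoning in stable-baselines3 issue #284: truncation should not be treated as a terminal state for value bootstrapping. A minimal standalone sketch of that pattern (not the Mighty API; the function name is illustrative), assuming a Gymnasium-style vector env that exposes the pre-reset observation under infos["final_observation"]:

```python
import numpy as np

def split_done_signals(next_s, terminated, truncated, infos):
    """Separate bootstrapping-relevant dones from episode-boundary dones."""
    # Only true terminations ("physics failures") should stop bootstrapping
    # in the replay buffer; time-limit truncations should not.
    replay_dones = terminated
    dones = np.logical_or(terminated, truncated)

    # On truncation the vector env has already auto-reset, so next_s holds the
    # first observation of the new episode; recover the real last observation.
    real_next_s = next_s.copy()
    for i, tr in enumerate(truncated):
        if tr:
            real_next_s[i] = infos["final_observation"][i]
    return real_next_s, replay_dones, dones
```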

mighty/mighty_agents/dqn.py
Lines changed: 0 additions & 2 deletions

@@ -69,7 +69,6 @@ def __init__(
         normalize_obs: bool = False,
         normalize_reward: bool = False,
         rescale_action: bool = False,  # type: ignore
-        handle_timeout_termination: bool = False,
     ):
         """DQN initialization.

@@ -155,7 +154,6 @@ def __init__(
             normalize_obs=normalize_obs,
             normalize_reward=normalize_reward,
             rescale_action=rescale_action,
-            handle_timeout_termination=handle_timeout_termination
         )

         self.loss_buffer = {

mighty/mighty_agents/ppo.py
Lines changed: 0 additions & 2 deletions

@@ -62,7 +62,6 @@ def __init__(
         normalize_reward: bool = False,
         rescale_action: bool = False,
         tanh_squash: bool = False,
-        handle_timeout_termination: bool = False,
     ):
         """Initialize the PPO agent.

@@ -144,7 +143,6 @@ def __init__(
             normalize_obs=normalize_obs,
             normalize_reward=normalize_reward,
             rescale_action=rescale_action,
-            handle_timeout_termination=handle_timeout_termination
         )

         self.loss_buffer = {

mighty/mighty_agents/sac.py
Lines changed: 1 addition & 4 deletions

@@ -57,7 +57,6 @@ def __init__(
         rescale_action: bool = False,  # ← NEW Whether to rescale actions to the environment's action space
         policy_frequency: int = 2,  # Frequency of policy updates
         target_network_frequency: int = 1,  # Frequency of target network updates
-        handle_timeout_termination: bool = True,
     ):
         """Initialize SAC agent with tunable hyperparameters and backward-compatible names."""
         if hidden_sizes is None:

@@ -117,7 +116,6 @@ def __init__(
             rescale_action=rescale_action,
             batch_size=batch_size,
             learning_rate=policy_lr,  # For compatibility with base class
-            handle_timeout_termination=handle_timeout_termination,
         )

         # Initialize loss buffer for logging

@@ -209,9 +207,8 @@ def process_transition(
         # Ensure metrics dict
         if metrics is None:
             metrics = {}
+
         # Pack transition
-        # `terminated` is used for physics failures in environments like `MightyEnv`
-        # Based on https://github.com/DLR-RM/stable-baselines3/issues/284
         terminated = metrics["transition"]["terminated"]  # physics‐failures
         transition = TransitionBatch(curr_s, action, reward, next_s, terminated.astype(int))
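
Because process_transition now packs only terminated into the TransitionBatch, time-limit truncations no longer cut off bootstrapping in the critic update. An illustrative sketch of the soft Q-target this enables (generic SAC math, not Mighty's actual update code; function and argument names are made up for the example):

```python
import torch

def sac_q_target(reward, terminated, next_q_min, next_log_prob,
                 gamma: float = 0.99, alpha: float = 0.2) -> torch.Tensor:
    """Soft Q-target whose bootstrap term is masked only by true terminations."""
    # (1 - terminated) stays 1 for time-limit truncations, so the value of the
    # truncated next state is still bootstrapped; only real environment
    # failures zero out the bootstrap term.
    not_done = 1.0 - terminated.float()
    soft_value = next_q_min - alpha * next_log_prob
    return reward + gamma * not_done * soft_value
```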

mighty/mighty_exploration/mighty_exploration_policy.py
Lines changed: 2 additions & 4 deletions

@@ -115,11 +115,9 @@ def sample_func_logits(self, state_array):

         # ─── Continuous squashed‐Gaussian (4‐tuple) ──────────────────────────
         elif isinstance(out, tuple) and len(out) == 4:
-            action = out[0]  # [batch, action_dim]
-
-            print(f'Self Model : {self.model}')
+            action = out[0]  # [batch, action_dim]
             log_prob = sample_nondeterministic_logprobs(
-                z=out[1], mean=out[2], log_std=out[3], sac=isinstance(self.model, SACModel)
+                z=out[1], mean=out[2], log_std=out[3], sac= self.ago == "sac"
             )
             return action.detach().cpu().numpy(), log_prob
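
The squashed-Gaussian branch delegates the log-probability to sample_nondeterministic_logprobs, passing the raw sample, mean, and log-std. That helper's internals are not part of this diff, so the following is only a generic sketch of the tanh change-of-variables correction such a function typically applies for SAC-style policies:

```python
import torch

def squashed_gaussian_log_prob(z, mean, log_std, eps: float = 1e-6):
    """Log-probability of a = tanh(z) where z ~ Normal(mean, exp(log_std))."""
    normal = torch.distributions.Normal(mean, log_std.exp())
    log_prob = normal.log_prob(z)
    # Change of variables for the tanh squashing: subtract log(1 - tanh(z)^2).
    log_prob -= torch.log(1.0 - torch.tanh(z).pow(2) + eps)
    return log_prob.sum(dim=-1, keepdim=True)
```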

mighty/mighty_exploration/stochastic_policy.py
Lines changed: 1 addition & 1 deletion

@@ -103,7 +103,7 @@ def explore(self, s, return_logp, metrics=None) -> Tuple[np.ndarray, torch.Tensor]:
         if return_logp:
             return action.detach().cpu().numpy(), log_prob
         else:
-            weighted_log_prob = log_prob
+            weighted_log_prob = log_prob * self.entropy_coefficient
             return action.detach().cpu().numpy(), weighted_log_prob

         # Check for model attribute-based approaches
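
explore() now scales the log-probability by the policy's entropy coefficient before returning it when return_logp is False. A minimal sketch of the assumed semantics (illustrative values, not Mighty code): the scaled quantity acts as a per-sample entropy bonus that the caller can use without knowing the coefficient itself.

```python
import torch

log_prob = torch.tensor([-1.2, -0.7, -2.3])           # log pi(a|s) for a small batch
entropy_coefficient = 0.01                             # hypothetical coefficient value
weighted_log_prob = log_prob * entropy_coefficient     # same scaling as explore()
entropy_bonus = -weighted_log_prob.mean()              # estimate of alpha * H(pi)
print(entropy_bonus)
```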

mighty/mighty_models/sac.py
Lines changed: 58 additions & 25 deletions

@@ -32,7 +32,7 @@ def __init__(
         # This model is continuous only
         self.continuous_action = True

-        # PR: register the per-dim scale and bias so we can rescale [-1,1]→[low,high].
+        # Register the per-dim scale and bias so we can rescale [-1,1]→[low,high].
         action_low = torch.as_tensor(action_low, dtype=torch.float32)
         action_high = torch.as_tensor(action_high, dtype=torch.float32)
         self.register_buffer(

@@ -67,42 +67,75 @@ def __init__(
         self.hidden_sizes = feature_extractor_kwargs.get("hidden_sizes", [256, 256])
         self.activation = feature_extractor_kwargs.get("activation", "relu")

-        # Shared feature extractor for policy
-        self.feature_extractor, out_dim = make_feature_extractor(
+        # Policy feature extractor and head
+        self.policy_feature_extractor, policy_feat_dim = make_feature_extractor(
             **feature_extractor_kwargs
         )
-
-        # Policy network outputs mean and log_std
-        # CHANGE: Create separate policy network (actor) similar to CleanRL
-        self.policy_net = make_policy_head(
-            in_size=self.obs_size,
+
+        # Policy head: just the final output layer
+        self.policy_head = make_policy_head(
+            in_size=policy_feat_dim,
             out_size=self.action_size * 2,  # mean and log_std
-            **head_kwargs
+            hidden_sizes=[],  # No hidden layers, just final linear layer
+            activation=head_kwargs["activation"]
         )

-        # Twin Q-networks
-        # — live Q-nets —
-        self.q_net1 = make_q_head(
-            in_size=self.obs_size + self.action_size, **head_kwargs
+        # Create policy_net for backward compatibility
+        self.policy_net = nn.Sequential(self.policy_feature_extractor, self.policy_head)
+
+        # Q-networks: feature extractors + heads
+        q_feature_extractor_kwargs = feature_extractor_kwargs.copy()
+        q_feature_extractor_kwargs["obs_shape"] = self.obs_size + self.action_size
+
+        # Q-network 1
+        self.q_feature_extractor1, q_feat_dim = make_feature_extractor(**q_feature_extractor_kwargs)
+        self.q_head1 = make_q_head(
+            in_size=q_feat_dim,
+            hidden_sizes=[],  # No hidden layers, just final linear layer
+            activation=head_kwargs["activation"]
         )
-        self.q_net2 = make_q_head(
-            in_size=self.obs_size + self.action_size, **head_kwargs
+        self.q_net1 = nn.Sequential(self.q_feature_extractor1, self.q_head1)
+
+        # Q-network 2
+        self.q_feature_extractor2, _ = make_feature_extractor(**q_feature_extractor_kwargs)
+        self.q_head2 = make_q_head(
+            in_size=q_feat_dim,
+            hidden_sizes=[],  # No hidden layers, just final linear layer
+            activation=head_kwargs["activation"]
         )
+        self.q_net2 = nn.Sequential(self.q_feature_extractor2, self.q_head2)

         # Target Q-networks
-        self.target_q_net1 = make_q_head(
-            in_size=self.obs_size + self.action_size, **head_kwargs
+        self.target_q_feature_extractor1, _ = make_feature_extractor(**q_feature_extractor_kwargs)
+        self.target_q_head1 = make_q_head(
+            in_size=q_feat_dim,
+            hidden_sizes=[],  # No hidden layers, just final linear layer
+            activation=head_kwargs["activation"]
         )
-        self.target_q_net1.load_state_dict(self.q_net1.state_dict())
-        self.target_q_net2 = make_q_head(
-            in_size=self.obs_size + self.action_size, **head_kwargs
+        self.target_q_net1 = nn.Sequential(self.target_q_feature_extractor1, self.target_q_head1)
+
+        self.target_q_feature_extractor2, _ = make_feature_extractor(**q_feature_extractor_kwargs)
+        self.target_q_head2 = make_q_head(
+            in_size=q_feat_dim,
+            hidden_sizes=[],  # No hidden layers, just final linear layer
+            activation=head_kwargs["activation"]
         )
-        self.target_q_net2.load_state_dict(self.q_net2.state_dict())
+        self.target_q_net2 = nn.Sequential(self.target_q_feature_extractor2, self.target_q_head2)
+
+        # Copy weights from live to target networks
+        self.target_q_feature_extractor1.load_state_dict(self.q_feature_extractor1.state_dict())
+        self.target_q_head1.load_state_dict(self.q_head1.state_dict())
+        self.target_q_feature_extractor2.load_state_dict(self.q_feature_extractor2.state_dict())
+        self.target_q_head2.load_state_dict(self.q_head2.state_dict())

         # Freeze target networks
-        for p in self.target_q_net1.parameters():
+        for p in self.target_q_feature_extractor1.parameters():
+            p.requires_grad = False
+        for p in self.target_q_head1.parameters():
+            p.requires_grad = False
+        for p in self.target_q_feature_extractor2.parameters():
             p.requires_grad = False
-        for p in self.target_q_net2.parameters():
+        for p in self.target_q_head2.parameters():
             p.requires_grad = False

         # Create a value function wrapper for compatibility

@@ -133,7 +166,7 @@ def forward(
         Forward pass for policy sampling.

         Returns:
-            action: torch.Tensor in [-1,1]
+            action: torch.Tensor in rescaled range [action_low, action_high]
             z: raw Gaussian sample before tanh
             mean: Gaussian mean
             log_std: Gaussian log std

@@ -155,7 +188,7 @@ def forward(
         # tanh→[-1,1]
         raw_action = torch.tanh(z)

-        # **HERE** we rescale into [low,high]
+        # Rescale into [action_low, action_high]
         action = raw_action * self.action_scale + self.action_bias

         return action, z, mean, log_std
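
The model now builds every network as an explicit feature extractor plus a linear head, wraps them in nn.Sequential for backward compatibility, and rescales tanh-squashed actions into the environment's range via registered buffers. A simplified standalone sketch of that actor structure, not the Mighty SACModel itself (layer sizes, the log-std clamp, and all names are illustrative):

```python
import torch
import torch.nn as nn

class TinySquashedActor(nn.Module):
    """Feature extractor + head emitting mean/log_std, with tanh rescaling."""

    def __init__(self, obs_size, action_size, action_low, action_high):
        super().__init__()
        self.extractor = nn.Sequential(
            nn.Linear(obs_size, 256), nn.ReLU(),
            nn.Linear(256, 256), nn.ReLU(),
        )
        self.head = nn.Linear(256, action_size * 2)  # mean and log_std
        low = torch.as_tensor(action_low, dtype=torch.float32)
        high = torch.as_tensor(action_high, dtype=torch.float32)
        # Buffers follow the module across devices and into state_dict,
        # but are never updated by the optimizer.
        self.register_buffer("action_scale", (high - low) / 2.0)
        self.register_buffer("action_bias", (high + low) / 2.0)

    def forward(self, obs):
        mean, log_std = self.head(self.extractor(obs)).chunk(2, dim=-1)
        log_std = log_std.clamp(-5.0, 2.0)
        z = mean + log_std.exp() * torch.randn_like(mean)  # reparameterized sample
        raw_action = torch.tanh(z)                          # squash to [-1, 1]
        action = raw_action * self.action_scale + self.action_bias
        return action, z, mean, log_std
```

Registering action_scale and action_bias as buffers rather than plain attributes is what lets the rescaling survive .to(device) calls and checkpointing without adding trainable parameters.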
