
Commit 06cdf39

Author: Ervin T

Cherry-pick BC fixes to Release 10 (#4668)

1 parent b0ac32e, commit 06cdf39

File tree: 5 files changed, 33 additions & 15 deletions

ml-agents/mlagents/trainers/policy/torch_policy.py

Lines changed: 10 additions & 8 deletions

@@ -124,16 +124,16 @@ def sample_actions(
        memories: Optional[torch.Tensor] = None,
        seq_len: int = 1,
        all_log_probs: bool = False,
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        :param vec_obs: List of vector observations.
        :param vis_obs: List of visual observations.
        :param masks: Loss masks for RNN, else None.
        :param memories: Input memories when using RNN, else None.
        :param seq_len: Sequence length when using RNN.
        :param all_log_probs: Returns (for discrete actions) a tensor of log probs, one for each action.
-        :return: Tuple of actions, log probabilities (dependent on all_log_probs), entropies, and
-            output memories, all as Torch Tensors.
+        :return: Tuple of actions, actions clipped to -1, 1, log probabilities (dependent on all_log_probs),
+            entropies, and output memories, all as Torch Tensors.
        """
        if memories is None:
            dists, memories = self.actor_critic.get_dists(
@@ -155,8 +155,14 @@ def sample_actions(
            actions = actions[:, 0, :]
        # Use the sum of entropy across actions, not the mean
        entropy_sum = torch.sum(entropies, dim=1)
+
+        if self._clip_action and self.use_continuous_act:
+            clipped_action = torch.clamp(actions, -3, 3) / 3
+        else:
+            clipped_action = actions
        return (
            actions,
+            clipped_action,
            all_logs if all_log_probs else log_probs,
            entropy_sum,
            memories,
@@ -201,14 +207,10 @@ def evaluate(

        run_out = {}
        with torch.no_grad():
-            action, log_probs, entropy, memories = self.sample_actions(
+            action, clipped_action, log_probs, entropy, memories = self.sample_actions(
                vec_obs, vis_obs, masks=masks, memories=memories
            )

-        if self._clip_action and self.use_continuous_act:
-            clipped_action = torch.clamp(action, -3, 3) / 3
-        else:
-            clipped_action = action
        run_out["pre_action"] = ModelUtils.to_numpy(action)
        run_out["action"] = ModelUtils.to_numpy(clipped_action)
        # Todo - make pre_action difference
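For context, the clipping that sample_actions now performs (and that evaluate previously did itself) clamps raw continuous samples to [-3, 3] and rescales them by 1/3, so the environment always receives actions in [-1, 1]. A minimal, self-contained sketch of that transform; the tensor below is illustrative data, not actual policy output:

    import torch

    # Raw continuous actions sampled from the policy's distribution; values can
    # fall outside [-1, 1]. This tensor is made up purely for illustration.
    actions = torch.tensor([[-4.2, 0.6, 2.9]])

    # Same transform as in sample_actions: clamp to [-3, 3], then divide by 3
    # so every component lands in [-1, 1].
    clipped_action = torch.clamp(actions, -3, 3) / 3

    print(clipped_action)  # tensor([[-1.0000,  0.2000,  0.9667]])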

ml-agents/mlagents/trainers/sac/optimizer_torch.py

Lines changed: 1 addition & 1 deletion

@@ -464,7 +464,7 @@ def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]:
        self.target_network.network_body.copy_normalization(
            self.policy.actor_critic.network_body
        )
-        (sampled_actions, log_probs, _, _) = self.policy.sample_actions(
+        (sampled_actions, _, log_probs, _, _) = self.policy.sample_actions(
            vec_obs,
            vis_obs,
            masks=act_masks,

ml-agents/mlagents/trainers/tests/torch/saver/test_saver.py

Lines changed: 2 additions & 2 deletions

@@ -79,10 +79,10 @@ def _compare_two_policies(policy1: TorchPolicy, policy2: TorchPolicy) -> None:
    ).unsqueeze(0)

    with torch.no_grad():
-        _, log_probs1, _, _ = policy1.sample_actions(
+        _, _, log_probs1, _, _ = policy1.sample_actions(
            vec_obs, vis_obs, masks=masks, memories=memories, all_log_probs=True
        )
-        _, log_probs2, _, _ = policy2.sample_actions(
+        _, _, log_probs2, _, _ = policy2.sample_actions(
            vec_obs, vis_obs, masks=masks, memories=memories, all_log_probs=True
        )

ml-agents/mlagents/trainers/tests/torch/test_policy.py

Lines changed: 11 additions & 1 deletion

@@ -126,7 +126,13 @@ def test_sample_actions(rnn, visual, discrete):
    if len(memories) > 0:
        memories = torch.stack(memories).unsqueeze(0)

-    (sampled_actions, log_probs, entropies, memories) = policy.sample_actions(
+    (
+        sampled_actions,
+        clipped_actions,
+        log_probs,
+        entropies,
+        memories,
+    ) = policy.sample_actions(
        vec_obs,
        vis_obs,
        masks=act_masks,
@@ -141,6 +147,10 @@ def test_sample_actions(rnn, visual, discrete):
        )
    else:
        assert log_probs.shape == (64, policy.behavior_spec.action_spec.continuous_size)
+        assert clipped_actions.shape == (
+            64,
+            policy.behavior_spec.action_spec.continuous_size,
+        )
        assert entropies.shape == (64,)

    if rnn:

ml-agents/mlagents/trainers/torch/components/bc/module.py

Lines changed: 9 additions & 3 deletions

@@ -62,7 +62,7 @@ def update(self) -> Dict[str, np.ndarray]:
        # Don't continue training if the learning rate has reached 0, to reduce training time.

        decay_lr = self.decay_learning_rate.get_value(self.policy.get_current_step())
-        if self.current_lr <= 0:
+        if self.current_lr <= 1e-10:  # Unlike in TF, this never actually reaches 0.
            return {"Losses/Pretraining Loss": 0}

        batch_losses = []
@@ -164,7 +164,13 @@ def _update_batch(
        else:
            vis_obs = []

-        selected_actions, all_log_probs, _, _ = self.policy.sample_actions(
+        (
+            selected_actions,
+            clipped_actions,
+            all_log_probs,
+            _,
+            _,
+        ) = self.policy.sample_actions(
            vec_obs,
            vis_obs,
            masks=act_masks,
@@ -173,7 +179,7 @@ def _update_batch(
            all_log_probs=True,
        )
        bc_loss = self._behavioral_cloning_loss(
-            selected_actions, all_log_probs, expert_actions
+            clipped_actions, all_log_probs, expert_actions
        )
        self.optimizer.zero_grad()
        bc_loss.backward()
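The learning-rate guard in update changes because, as the new in-code comment notes, the Torch-side decay never returns exactly 0.0, so a strict <= 0 check would never fire. A small sketch of the idea, using a hypothetical linear-decay helper rather than the trainer's actual decay_learning_rate schedule:

    # Hypothetical decay helper for illustration only; the real trainer uses its
    # own decay_learning_rate object, which may behave differently in detail.
    def decayed_lr(initial_lr: float, step: int, max_steps: int,
                   min_lr: float = 1e-10) -> float:
        # Anneal linearly toward min_lr instead of 0, so the value never hits 0.0.
        remaining = max(1.0 - step / max_steps, 0.0)
        return max(initial_lr * remaining, min_lr)

    current_lr = decayed_lr(3e-4, step=1_000_000, max_steps=500_000)
    print(current_lr)  # 1e-10, never exactly 0.0

    # The guard from the diff: skip the BC update once the rate is effectively zero.
    if current_lr <= 1e-10:
        print("skipping behavioral-cloning update")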
