Skip to content

Commit 736ac08

Browse files
Cherry-picked #4491 (#4493)
[Bug fix] Export all branches for discrete control torch
1 parent 2a69c1f commit 736ac08

File tree

2 files changed

+8
-5
lines changed

2 files changed

+8
-5
lines changed

com.unity.ml-agents/CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/)
66
and this project adheres to
77
[Semantic Versioning](http://semver.org/spec/v2.0.0.html).
88

9+
910
## [1.4.0-preview] - 2020-09-16
1011
### Major Changes
1112
#### com.unity.ml-agents (C#)
@@ -55,6 +56,7 @@ Academy has been shut down. (#4489)
5556
- Fixed the sample code in the custom SideChannel example. (#4466)
5657
- A bug in the observation normalizer that would cause rewards to decrease
5758
when using `--resume` was fixed. (#4463)
59+
- Fixed a bug in exporting PyTorch models when using multiple discrete actions. (#4491)
5860

5961
## [1.3.0-preview] - 2020-08-12
6062

ml-agents/mlagents/trainers/torch/networks.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -267,7 +267,9 @@ def __init__(
267267
self.is_continuous_int = torch.nn.Parameter(
268268
torch.Tensor([int(act_type == ActionType.CONTINUOUS)])
269269
)
270-
self.act_size_vector = torch.nn.Parameter(torch.Tensor(act_size))
270+
self.act_size_vector = torch.nn.Parameter(
271+
torch.Tensor([sum(act_size)]), requires_grad=False
272+
)
271273
self.network_body = NetworkBody(observation_shapes, network_settings)
272274
if network_settings.memory is not None:
273275
self.encoding_size = network_settings.memory.memory_size // 2
@@ -329,12 +331,11 @@ def forward(
329331
Note: This forward() method is required for exporting to ONNX. Don't modify the inputs and outputs.
330332
"""
331333
dists, _ = self.get_dists(vec_inputs, vis_inputs, masks, memories, 1)
332-
action_list = self.sample_action(dists)
333-
sampled_actions = torch.stack(action_list, dim=-1)
334334
if self.act_type == ActionType.CONTINUOUS:
335-
action_out = sampled_actions
335+
action_list = self.sample_action(dists)
336+
action_out = torch.stack(action_list, dim=-1)
336337
else:
337-
action_out = dists[0].all_log_prob()
338+
action_out = torch.cat([dist.all_log_prob() for dist in dists], dim=1)
338339
return (
339340
action_out,
340341
self.version_number,

0 commit comments

Comments
 (0)