Skip to content

Commit c8f45bf

Browse files
Ruo-Ping DongErvin T.
andauthored
Fix model inference issue with Barracuda v1.2.1 (#4766) (#4768)
Co-authored-by: Ervin T. <[email protected]>
1 parent bf79a69 commit c8f45bf

File tree

2 files changed

+13
-9
lines changed

2 files changed

+13
-9
lines changed

ml-agents/mlagents/trainers/torch/distributions.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -173,9 +173,11 @@ def forward(self, inputs: torch.Tensor) -> List[DistInstance]:
173173
log_sigma = torch.clamp(self.log_sigma(inputs), min=-20, max=2)
174174
else:
175175
# Expand so that entropy matches batch size. Note that we're using
176-
# torch.cat here instead of torch.expand() becuase it is not supported in the
177-
# verified version of Barracuda (1.0.2).
178-
log_sigma = torch.cat([self.log_sigma] * inputs.shape[0], axis=0)
176+
# mu*0 here to get the batch size implicitly since Barracuda 1.2.1
177+
# throws error on runtime broadcasting due to unknown reason. We
178+
# use this to replace torch.expand() becuase it is not supported in
179+
# the verified version of Barracuda (1.0.X).
180+
log_sigma = mu * 0 + self.log_sigma
179181
if self.tanh_squash:
180182
return TanhGaussianDistInstance(mu, torch.exp(log_sigma))
181183
else:

ml-agents/mlagents/trainers/torch/networks.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -258,9 +258,11 @@ def __init__(
258258
):
259259
super().__init__()
260260
self.action_spec = action_spec
261-
self.version_number = torch.nn.Parameter(torch.Tensor([2.0]))
261+
self.version_number = torch.nn.Parameter(
262+
torch.Tensor([2.0]), requires_grad=False
263+
)
262264
self.is_continuous_int_deprecated = torch.nn.Parameter(
263-
torch.Tensor([int(self.action_spec.is_continuous())])
265+
torch.Tensor([int(self.action_spec.is_continuous())]), requires_grad=False
264266
)
265267
self.continuous_act_size_vector = torch.nn.Parameter(
266268
torch.Tensor([int(self.action_spec.continuous_size)]), requires_grad=False
@@ -283,6 +285,9 @@ def __init__(
283285
self.encoding_size = network_settings.memory.memory_size // 2
284286
else:
285287
self.encoding_size = network_settings.hidden_units
288+
self.memory_size_vector = torch.nn.Parameter(
289+
torch.Tensor([int(self.network_body.memory_size)]), requires_grad=False
290+
)
286291

287292
self.action_model = ActionModel(
288293
self.encoding_size,
@@ -335,10 +340,7 @@ def forward(
335340
disc_action_out,
336341
action_out_deprecated,
337342
) = self.action_model.get_action_out(encoding, masks)
338-
export_out = [
339-
self.version_number,
340-
torch.Tensor([self.network_body.memory_size]),
341-
]
343+
export_out = [self.version_number, self.memory_size_vector]
342344
if self.action_spec.continuous_size > 0:
343345
export_out += [cont_action_out, self.continuous_act_size_vector]
344346
if self.action_spec.discrete_size > 0:

0 commit comments

Comments
 (0)