Skip to content

Commit ada3bc5

Browse files
author
Ervin T
committed
Bugfix for LSTM+BC (#2679)
* Fix LSTM+BC in discrete case * Add test for Barracuda export * Fix LSTM training for BC
1 parent 657e09b commit ada3bc5

File tree

5 files changed

+56
-8
lines changed

5 files changed

+56
-8
lines changed

ml-agents/mlagents/trainers/bc/models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def __init__(
4040
for size in self.act_size:
4141
policy_branches.append(
4242
tf.layers.dense(
43-
hidden,
43+
hidden_reg,
4444
size,
4545
activation=None,
4646
use_bias=False,

ml-agents/mlagents/trainers/bc/trainer.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -129,11 +129,12 @@ def update_policy(self):
129129
len(self.demonstration_buffer.update_buffer["actions"]) // self.n_sequences,
130130
self.batches_per_epoch,
131131
)
132-
for i in range(num_batches):
132+
133+
batch_size = self.n_sequences * self.policy.sequence_length
134+
135+
for i in range(0, num_batches * batch_size, batch_size):
133136
update_buffer = self.demonstration_buffer.update_buffer
134-
start = i * self.n_sequences
135-
end = (i + 1) * self.n_sequences
136-
mini_batch = update_buffer.make_mini_batch(start, end)
137+
mini_batch = update_buffer.make_mini_batch(i, i + batch_size)
137138
run_out = self.policy.update(mini_batch, self.n_sequences)
138139
loss = run_out["policy_loss"]
139140
batch_losses.append(loss)

ml-agents/mlagents/trainers/tests/mock_brain.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,16 @@ def create_mock_3dball_brain():
212212
return mock_brain
213213

214214

215+
def create_mock_pushblock_brain():
216+
mock_brain = create_mock_brainparams(
217+
vector_action_space_type="discrete",
218+
vector_action_space_size=[7],
219+
vector_observation_space_size=70,
220+
)
221+
mock_brain.brain_name = "PushblockLearning"
222+
return mock_brain
223+
224+
215225
def create_mock_banana_brain():
216226
mock_brain = create_mock_brainparams(
217227
number_visual_observations=1,

ml-agents/mlagents/trainers/tests/test_barracuda_converter.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
import os
2+
import yaml
3+
import pytest
24
import tempfile
35

46
import mlagents.trainers.tensorflow_to_barracuda as tf2bc
7+
from mlagents.trainers.tests.test_bc import create_bc_trainer
58

69

710
def test_barracuda_converter():
@@ -24,3 +27,29 @@ def test_barracuda_converter():
2427

2528
# cleanup
2629
os.remove(tmpfile)
30+
31+
32+
@pytest.fixture
33+
def bc_dummy_config():
34+
return yaml.safe_load(
35+
"""
36+
hidden_units: 32
37+
learning_rate: 3.0e-4
38+
num_layers: 1
39+
use_recurrent: false
40+
sequence_length: 32
41+
memory_size: 64
42+
batches_per_epoch: 1
43+
batch_size: 64
44+
summary_freq: 2000
45+
max_steps: 4000
46+
"""
47+
)
48+
49+
50+
@pytest.mark.parametrize("use_lstm", [False, True], ids=["nolstm", "lstm"])
51+
@pytest.mark.parametrize("use_discrete", [True, False], ids=["disc", "cont"])
52+
def test_bc_export(bc_dummy_config, use_lstm, use_discrete):
53+
bc_dummy_config["use_recurrent"] = use_lstm
54+
trainer, env = create_bc_trainer(bc_dummy_config, use_discrete)
55+
trainer.export_model()

ml-agents/mlagents/trainers/tests/test_bc.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,18 @@ def dummy_config():
3232
)
3333

3434

35-
def create_bc_trainer(dummy_config):
35+
def create_bc_trainer(dummy_config, is_discrete=False):
3636
mock_env = mock.Mock()
37-
mock_brain = mb.create_mock_3dball_brain()
38-
mock_braininfo = mb.create_mock_braininfo(num_agents=12, num_vector_observations=8)
37+
if is_discrete:
38+
mock_brain = mb.create_mock_pushblock_brain()
39+
mock_braininfo = mb.create_mock_braininfo(
40+
num_agents=12, num_vector_observations=70
41+
)
42+
else:
43+
mock_brain = mb.create_mock_3dball_brain()
44+
mock_braininfo = mb.create_mock_braininfo(
45+
num_agents=12, num_vector_observations=8
46+
)
3947
mb.setup_mock_unityenvironment(mock_env, mock_brain, mock_braininfo)
4048
env = mock_env()
4149

0 commit comments

Comments
 (0)