
Commit 72835e8

Hotfix 0.3.0b (#519)
* Fixes internal brain for Banana Imitation.
* Fixes Discrete Control training for Imitation Learning.
* Fixes Visual Observations in internal brain with non-square inputs.
1 parent: 862543e

File tree

9 files changed (+155 lines, −169 lines)


python/unitytrainers/bc/models.py

Lines changed: 3 additions & 1 deletion
@@ -22,7 +22,9 @@ def __init__(self, brain, h_size=128, lr=1e-4, n_layers=2, m_size=128,
 
         if brain.vector_action_space_type == "discrete":
             self.action_probs = tf.nn.softmax(self.policy)
-            self.sample_action = tf.cast(tf.multinomial(self.policy, 1, name="action"), tf.int32)
+            self.sample_action_float = tf.multinomial(self.policy, 1)
+            self.sample_action_float = tf.identity(self.sample_action_float, name="action")
+            self.sample_action = tf.cast(self.sample_action_float, tf.int32)
             self.true_action = tf.placeholder(shape=[None], dtype=tf.int32, name="teacher_action")
             self.action_oh = tf.one_hot(self.true_action, self.a_size)
             self.loss = tf.reduce_sum(-tf.log(self.action_probs + 1e-10) * self.action_oh)
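Why the split matters: the internal brain fetches outputs from the exported graph by node name, so the node named "action" must be the tensor that actually holds the sampled action. In the old code, name="action" was attached to the raw tf.multinomial op and the cast output went unnamed. A minimal sketch of the fixed pattern (illustrative only; the 4-action policy and session code are assumptions, not the ML-Agents source):

import tensorflow as tf

logits = tf.placeholder(tf.float32, shape=[None, 4])   # policy logits, 4 discrete actions
sampled = tf.multinomial(logits, 1)                    # int64 sample, one per batch row
action = tf.identity(sampled, name="action")           # pin the graph node name "action"
action_int = tf.cast(action, tf.int32)                 # int32 view for the Python trainer

with tf.Session() as sess:
    # A name-based lookup like this is how an embedding host (such as the
    # internal brain) retrieves the output after the graph is frozen:
    fetched = sess.graph.get_tensor_by_name("action:0")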

python/unitytrainers/bc/trainer.py

Lines changed: 10 additions & 5 deletions
@@ -54,7 +54,8 @@ def __init__(self, sess, env, brain_name, trainer_parameters, training, seed):
         self.stats = {'losses': [], 'episode_length': [], 'cumulative_reward': []}
 
         self.training_buffer = Buffer()
-        self.is_continuous = (env.brains[brain_name].vector_action_space_type == "continuous")
+        self.is_continuous_action = (env.brains[brain_name].vector_action_space_type == "continuous")
+        self.is_continuous_observation = (env.brains[brain_name].vector_observation_space_type == "continuous")
         self.use_observations = (env.brains[brain_name].number_visual_observations > 0)
         if self.use_observations:
             logger.info('Cannot use observations with imitation learning')
@@ -286,12 +287,16 @@ def update_model(self):
                 end = (j + 1) * self.n_sequences
                 batch_states = np.array(_buffer['vector_observations'][start:end])
                 batch_actions = np.array(_buffer['actions'][start:end])
-                feed_dict = {self.model.true_action: batch_actions.reshape([-1, self.brain.vector_action_space_size]),
-                             self.model.dropout_rate: 0.5,
+
+                feed_dict = {self.model.dropout_rate: 0.5,
                              self.model.batch_size: self.n_sequences,
                              self.model.sequence_length: self.sequence_length}
-                if not self.is_continuous:
-                    feed_dict[self.model.vector_in] = batch_states.reshape([-1, 1])
+                if self.is_continuous_action:
+                    feed_dict[self.model.true_action] = batch_actions.reshape([-1, self.brain.vector_action_space_size])
+                else:
+                    feed_dict[self.model.true_action] = batch_actions.reshape([-1])
+                if not self.is_continuous_observation:
+                    feed_dict[self.model.vector_in] = batch_states.reshape([-1, self.brain.num_stacked_vector_observations])
                 else:
                     feed_dict[self.model.vector_in] = batch_states.reshape([-1, self.brain.vector_observation_space_size *
                                                                             self.brain.num_stacked_vector_observations])
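The bug this fixes: the old code keyed both reshapes on a single is_continuous flag, so a brain with discrete actions but continuous observations reshaped its teacher actions as if they were continuous and fed the wrong shape into true_action, a [None] int32 placeholder. A small numpy sketch of the two reshapes (the batch sizes and action size are assumed values, not from the source):

import numpy as np

# Discrete actions: one integer index per step -> flat vector for shape [None]
batch_actions = np.array([[0, 1, 2], [3, 0, 1]])       # 2 sequences of length 3
print(batch_actions.reshape([-1]))                     # [0 1 2 3 0 1]

# Continuous actions: action_size values per step -> matrix for [None, action_size]
action_size = 2
cont_actions = np.zeros((2, 3, action_size))
print(cont_actions.reshape([-1, action_size]).shape)   # (6, 2)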

unity-environment/Assets/ML-Agents/Examples/BananaCollectors/BananaImitation.unity

Lines changed: 109 additions & 116 deletions
Large diffs are not rendered by default.

unity-environment/Assets/ML-Agents/Examples/BananaCollectors/Scripts/BananaAgent.cs

Lines changed: 30 additions & 6 deletions
@@ -83,11 +83,35 @@ public void MoveAgent(float[] act)
         Vector3 dirToGo = Vector3.zero;
         Vector3 rotateDir = Vector3.zero;
 
+
         if (!frozen)
         {
-            dirToGo = transform.forward * Mathf.Clamp(act[0], -1f, 1f);
-            rotateDir = transform.up * Mathf.Clamp(act[1], -1f, 1f);
-            if (Mathf.Clamp(act[2], 0f, 1f) > 0.5f)
+            bool shootCommand = false;
+            if (brain.brainParameters.vectorActionSpaceType == SpaceType.continuous)
+            {
+                dirToGo = transform.forward * Mathf.Clamp(act[0], -1f, 1f);
+                rotateDir = transform.up * Mathf.Clamp(act[1], -1f, 1f);
+                shootCommand = Mathf.Clamp(act[2], 0f, 1f) > 0.5f;
+            }
+            else
+            {
+                switch ((int)(act[0]))
+                {
+                    case 1:
+                        dirToGo = transform.forward;
+                        break;
+                    case 2:
+                        shootCommand = true;
+                        break;
+                    case 3:
+                        rotateDir = -transform.up;
+                        break;
+                    case 4:
+                        rotateDir = transform.up;
+                        break;
+                }
+            }
+            if (shootCommand)
             {
                 shoot = true;
                 dirToGo *= 0.5f;
@@ -121,9 +145,9 @@ public void MoveAgent(float[] act)
             myLaser.transform.localScale = new Vector3(0f, 0f, 0f);
 
         }
-
     }
 
+
     void Freeze()
     {
         gameObject.tag = "frozenAgent";
@@ -182,8 +206,8 @@ public override void AgentReset()
         agentRB.velocity = Vector3.zero;
         bananas = 0;
         myLaser.transform.localScale = new Vector3(0f, 0f, 0f);
-        transform.position = new Vector3(Random.Range(-myArea.range, myArea.range),
-                                         2f, Random.Range(-myArea.range, myArea.range))
+        transform.position = new Vector3(Random.Range(-myArea.range, myArea.range),
+                                         2f, Random.Range(-myArea.range, myArea.range))
             + area.transform.position;
         transform.rotation = Quaternion.Euler(new Vector3(0f, Random.Range(0, 360)));
     }
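With a discrete brain, act[0] carries a single action index rather than three continuous values, which is what the new switch statement decodes. A sketch of that mapping (the table and its names are illustrative, not part of the source):

# act[0] (discrete index) -> agent behavior in MoveAgent
DISCRETE_ACTION_TABLE = {
    0: "no-op (falls through the switch: no movement)",
    1: "move forward (dirToGo = transform.forward)",
    2: "shoot (shootCommand = true)",
    3: "rotate about the -up axis",
    4: "rotate about the +up axis",
}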

unity-environment/Assets/ML-Agents/Examples/BananaCollectors/TFModels/BananaI.bytes.meta renamed to unity-environment/Assets/ML-Agents/Examples/BananaCollectors/TFModels/BananaIL.bytes.meta

Lines changed: 1 addition & 3 deletions
Some generated files are not rendered by default.

unity-environment/Assets/ML-Agents/Scripts/CoreBrainInternal.cs

Lines changed: 1 addition & 1 deletion
@@ -513,7 +513,7 @@ public void OnInspector()
                 pixels = 1;
             else
                 pixels = 3;
-            float[,,,] result = new float[batchSize, width, height, pixels];
+            float[,,,] result = new float[batchSize, height, width, pixels];
 
             for (int b = 0; b < batchSize; b++)
             {
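TensorFlow convolutional inputs use NHWC ordering, [batch, height, width, channels], so the old [batchSize, width, height, pixels] array had its middle dimensions swapped; the mismatch only surfaced with non-square camera resolutions, where height != width. A numpy illustration (the 84x128 resolution is an assumed example):

import numpy as np

batch, height, width, channels = 1, 84, 128, 3
obs = np.zeros((batch, height, width, channels), dtype=np.float32)       # NHWC, correct
swapped = np.zeros((batch, width, height, channels), dtype=np.float32)   # old ordering
print(obs.shape)       # (1, 84, 128, 3)  matches a graph input of (?, 84, 128, 3)
print(swapped.shape)   # (1, 128, 84, 3)  shape mismatch for non-square inputs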

unity-environment/ProjectSettings/EditorBuildSettings.asset

Lines changed: 1 addition & 37 deletions
@@ -4,40 +4,4 @@
 EditorBuildSettings:
   m_ObjectHideFlags: 0
   serializedVersion: 2
-  m_Scenes:
-  - enabled: 0
-    path: Assets/ML-Agents/Examples/GridWorld/GridWorld.unity
-    guid: 7c777442467e245108558a5155153927
-  - enabled: 0
-    path: Assets/ML-Agents/Examples/Tennis/Tennis.unity
-    guid: 25c0c9e81e55c4e129e1a5c0ac254100
-  - enabled: 0
-    path: Assets/ML-Agents/Examples/Banana/BananaImitation.unity
-    guid: 3ae10073cde7641f488ef7c87862333a
-  - enabled: 0
-    path: Assets/ML-Agents/Examples/PushBlock/Scenes/PushBlock.unity
-    guid: ae8cc75939e3e4d07a79c8c6a08b54f4
-  - enabled: 0
-    path: Assets/ML-Agents/Examples/3DBall/3DScene.unity
-    guid: 6f62a2ccb3830437ea4e85a617e856b3
-  - enabled: 0
-    path: Assets/ML-Agents/Examples/3DBall/3DHardScene.unity
-    guid: 35c41099ceec44889bdbe95ed86c97ac
-  - enabled: 0
-    path: Assets/ML-Agents/Examples/Banana/BananaRL.unity
-    guid: 11583205ab5b74bb4bb1b9951cf9e437
-  - enabled: 0
-    path: Assets/ML-Agents/Examples/Basic/Scene.unity
-    guid: cf1d119a8748d406e90ecb623b45f92f
-  - enabled: 0
-    path: Assets/ML-Agents/Examples/Bouncer/Bouncer.unity
-    guid: 2c29359d4c9fe49219b21cd83e246596
-  - enabled: 0
-    path: Assets/ML-Agents/Examples/Crawler/Crawler.unity
-    guid: 4cf841b0478fb4b33971627b40c6420b
-  - enabled: 0
-    path: Assets/ML-Agents/Examples/Hallway/Scenes/Hallway.unity
-    guid: d6d6a33ed0e18459a8d61817d600978a
-  - enabled: 0
-    path: Assets/ML-Agents/Examples/Reacher/Scene.unity
-    guid: e58a3c10c43de4b6b91b7149838d1dfb
+  m_configObjects: {}
