
Commit ab8d6e8

Merge pull request #937 from Unity-Technologies/release-v0.4-fix-curiosity-odd
Hotfix - Curiosity & ODD
2 parents db7cdf2 + f497a27 commit ab8d6e8

File tree

2 files changed: +57 -18 lines changed


python/unityagents/brain.py

Lines changed: 9 additions & 9 deletions
Both hunks appear to be whitespace-only: the removed and added lines carry identical text, so the change realigns the hanging indentation of two multi-line statements.

@@ -3,8 +3,8 @@

 class BrainInfo:
     def __init__(self, visual_observation, vector_observation, text_observations, memory=None,
-                 reward=None, agents=None, local_done=None,
-                 vector_action=None, text_action=None, max_reached=None):
+                 reward=None, agents=None, local_done=None,
+                 vector_action=None, text_action=None, max_reached=None):
         """
         Describes experience at current step of all agents linked to a brain.
         """

@@ -49,10 +49,10 @@ def __str__(self):
         Vector Action space type: {5}
         Vector Action space size (per agent): {6}
         Vector Action descriptions: {7}'''.format(self.brain_name,
-                                                   str(self.number_visual_observations),
-                                                   self.vector_observation_space_type,
-                                                   str(self.vector_observation_space_size),
-                                                   str(self.num_stacked_vector_observations),
-                                                   self.vector_action_space_type,
-                                                   str(self.vector_action_space_size),
-                                                   ', '.join(self.vector_action_descriptions))
+                                                   str(self.number_visual_observations),
+                                                   self.vector_observation_space_type,
+                                                   str(self.vector_observation_space_size),
+                                                   str(self.num_stacked_vector_observations),
+                                                   self.vector_action_space_type,
+                                                   str(self.vector_action_space_size),
+                                                   ', '.join(self.vector_action_descriptions))

python/unitytrainers/ppo/trainer.py

Lines changed: 48 additions & 9 deletions
@@ -8,7 +8,7 @@
 import numpy as np
 import tensorflow as tf

-from unityagents import AllBrainInfo
+from unityagents import AllBrainInfo, BrainInfo
 from unitytrainers.buffer import Buffer
 from unitytrainers.ppo.models import PPOModel
 from unitytrainers.trainer import UnityTrainerException, Trainer
@@ -196,22 +196,61 @@ def take_action(self, all_brain_info: AllBrainInfo):
         else:
             return run_out[self.model.output], None, None, run_out

+    def construct_curr_info(self, next_info: BrainInfo) -> BrainInfo:
+        """
+        Constructs a BrainInfo which contains the most recent previous experiences for all agents info
+        which correspond to the agents in a provided next_info.
+        :BrainInfo next_info: A t+1 BrainInfo.
+        :return: curr_info: Reconstructed BrainInfo to match agents of next_info.
+        """
+        visual_observations = [[]]
+        vector_observations = []
+        text_observations = []
+        memories = []
+        rewards = []
+        local_dones = []
+        max_reacheds = []
+        agents = []
+        prev_vector_actions = []
+        prev_text_actions = []
+        for agent_id in next_info.agents:
+            agent_brain_info = self.training_buffer[agent_id].last_brain_info
+            agent_index = agent_brain_info.agents.index(agent_id)
+            if agent_brain_info is None:
+                agent_brain_info = next_info
+            for i in range(len(next_info.visual_observations)):
+                visual_observations[i].append(agent_brain_info.visual_observations[i][agent_index])
+            vector_observations.append(agent_brain_info.vector_observations[agent_index])
+            text_observations.append(agent_brain_info.text_observations[agent_index])
+            if self.use_recurrent:
+                memories.append(agent_brain_info.memories[agent_index])
+            rewards.append(agent_brain_info.rewards[agent_index])
+            local_dones.append(agent_brain_info.local_done[agent_index])
+            max_reacheds.append(agent_brain_info.max_reached[agent_index])
+            agents.append(agent_brain_info.agents[agent_index])
+            prev_vector_actions.append(agent_brain_info.previous_vector_actions[agent_index])
+            prev_text_actions.append(agent_brain_info.previous_text_actions[agent_index])
+        curr_info = BrainInfo(visual_observations, vector_observations, text_observations, memories, rewards,
+                              agents, local_dones, prev_vector_actions, prev_text_actions, max_reacheds)
+        return curr_info
+
     def generate_intrinsic_rewards(self, curr_info, next_info):
         """
         Generates intrinsic reward used for Curiosity-based training.
-        :param curr_info: Current BrainInfo.
-        :param next_info: Next BrainInfo.
+        :BrainInfo curr_info: Current BrainInfo.
+        :BrainInfo next_info: Next BrainInfo.
         :return: Intrinsic rewards for all agents.
         """
         if self.use_curiosity:
-            if curr_info.agents != next_info.agents:
-                raise UnityTrainerException("Training with Curiosity-driven exploration"
-                                            " and On-Demand Decision making is currently not supported.")
-            feed_dict = {self.model.batch_size: len(curr_info.vector_observations), self.model.sequence_length: 1}
+            feed_dict = {self.model.batch_size: len(next_info.vector_observations), self.model.sequence_length: 1}
             if self.is_continuous_action:
                 feed_dict[self.model.output] = next_info.previous_vector_actions
             else:
                 feed_dict[self.model.action_holder] = next_info.previous_vector_actions.flatten()
+
+            if curr_info.agents != next_info.agents:
+                curr_info = self.construct_curr_info(next_info)
+
             if self.use_visual_obs:
                 for i in range(len(curr_info.visual_observations)):
                     feed_dict[self.model.visual_in[i]] = curr_info.visual_observations[i]
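For context on the hunk above: with On-Demand Decision making, agents only request decisions when they need them, so the agent list in curr_info can differ from the one in next_info. Instead of raising UnityTrainerException, the trainer now rebuilds a matching curr_info from each agent's stored last_brain_info. The snippet below is a minimal, standalone sketch of that realignment idea; the names (align_to, last_seen, fallback) are illustrative assumptions, not part of the commit or the ML-Agents API.

# Minimal sketch of the realignment behind construct_curr_info.
# Hypothetical standalone example; these names are not the ML-Agents API.
def align_to(next_ids, last_seen, fallback):
    """For each agent id present at t+1, return its most recently stored
    record, falling back to the t+1 record when nothing was stored yet."""
    return [last_seen.get(agent_id, fallback[agent_id]) for agent_id in next_ids]

# Only agents 0 and 2 requested a decision at step t, but agent 1 also
# appears at t+1, so the two agent lists differ.
last_seen = {0: "obs_t_agent0", 2: "obs_t_agent2"}
fallback = {0: "obs_t1_agent0", 1: "obs_t1_agent1", 2: "obs_t1_agent2"}
print(align_to([0, 1, 2], last_seen, fallback))
# ['obs_t_agent0', 'obs_t1_agent1', 'obs_t_agent2']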
@@ -262,12 +301,12 @@ def add_experiences(self, curr_all_info: AllBrainInfo, next_all_info: AllBrainIn
         curr_info = curr_all_info[self.brain_name]
         next_info = next_all_info[self.brain_name]

-        intrinsic_rewards = self.generate_intrinsic_rewards(curr_info, next_info)
-
         for agent_id in curr_info.agents:
             self.training_buffer[agent_id].last_brain_info = curr_info
             self.training_buffer[agent_id].last_take_action_outputs = take_action_outputs

+        intrinsic_rewards = self.generate_intrinsic_rewards(curr_info, next_info)
+
         for agent_id in next_info.agents:
             stored_info = self.training_buffer[agent_id].last_brain_info
             stored_take_action_outputs = self.training_buffer[agent_id].last_take_action_outputs
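Why the call moved: the new construct_curr_info path reads self.training_buffer[agent_id].last_brain_info, so generate_intrinsic_rewards can only realign agents after the loop that stores last_brain_info for the current step. Below is a minimal ordering sketch under assumed names (Entry, a defaultdict training_buffer, generate_rewards); it illustrates the dependency, not the actual trainer code.

# Ordering sketch: the realignment dereferences last_brain_info, so the
# intrinsic-reward call must come after the store loop. Hypothetical,
# standalone names; this is not the ML-Agents API.
from collections import defaultdict

class Entry:
    def __init__(self):
        self.last_brain_info = None

training_buffer = defaultdict(Entry)

def generate_rewards(next_agents):
    # reads last_brain_info for every agent present at t+1
    return [training_buffer[a].last_brain_info["t"] for a in next_agents]

curr_info = {"agents": [0, 1], "t": 5}
for agent_id in curr_info["agents"]:
    training_buffer[agent_id].last_brain_info = curr_info    # store first ...
print(generate_rewards(curr_info["agents"]))                 # ... then read: [5, 5]
# Calling generate_rewards before the store loop would hit a None entry.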
