Skip to content

Commit a57d2e2

Browse files
authored
fix mock_brain (#2377)
fix mock_brain
1 parent e4d43a0 commit a57d2e2

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

ml-agents/mlagents/trainers/tests/mock_brain.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,12 @@ def create_buffer(brain_infos, brain_params, sequence_length):
129129
buffer[0]["prev_action"].append(current_brain_info.previous_vector_actions[0])
130130
buffer[0]["masks"].append(1.0)
131131
buffer[0]["advantages"].append(1.0)
132-
buffer[0]["action_probs"].append(np.ones(buffer[0]["actions"][0].shape))
132+
if brain_params.vector_action_space_type == "discrete":
133+
buffer[0]["action_probs"].append(
134+
np.ones(sum(brain_params.vector_action_space_size))
135+
)
136+
else:
137+
buffer[0]["action_probs"].append(np.ones(buffer[0]["actions"][0].shape))
133138
buffer[0]["actions_pre"].append(np.ones(buffer[0]["actions"][0].shape))
134139
buffer[0]["random_normal_epsilon"].append(
135140
np.ones(buffer[0]["actions"][0].shape)

0 commit comments

Comments
 (0)