Skip to content

Commit 28d5505

Browse files
authored
fix(cmarlin): fix random_policy for len(ready_env_id) < collector_env_num (#335)
1 parent 456679d commit 28d5505

File tree

1 file changed

+2
-3
lines changed

1 file changed

+2
-3
lines changed

lzero/policy/random_policy.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -199,10 +199,9 @@ def _forward_collect(
199199
if self._cfg.type in ['sampled_efficientzero']:
200200
roots_sampled_actions = roots.get_sampled_actions()
201201

202-
data_id = [i for i in range(active_collect_env_num)]
203-
output = {i: None for i in data_id}
204202
if ready_env_id is None:
205203
ready_env_id = np.arange(active_collect_env_num)
204+
output = {i: None for i in ready_env_id}
206205

207206
for i, env_id in enumerate(ready_env_id):
208207
distributions, value = roots_visit_count_distributions[i], roots_values[i]
@@ -238,7 +237,7 @@ def _forward_collect(
238237
}
239238
else:
240239
# ****** sample a random action from the legal action set ********
241-
random_action = int(np.random.choice(legal_actions[env_id], 1))
240+
random_action = int(np.random.choice(legal_actions[i], 1))
242241
# all items except action are formally obtained from MCTS
243242
output[env_id] = {
244243
'action': random_action,

0 commit comments

Comments
 (0)