
Commit 556b2ec

polish(pu): optimize the implementation of transformation from action_mask to legal_actions (#466)
1 parent eeca7d4 commit 556b2ec

20 files changed: +35 −35 lines changed
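
Every file below receives the same one-line change: the Python-level scan of each action_mask row is replaced by a vectorized np.nonzero lookup. A minimal sketch (not from the repository; the mask values are made up for illustration) shows the two forms and why they are interchangeable for the binary 0/1 masks used here. Note that np.nonzero treats any nonzero entry as legal, whereas the old comprehension tested x == 1 explicitly.

import numpy as np

# Hypothetical batch of binary action masks, for illustration only.
action_mask = [
    np.array([1, 0, 1, 1, 0]),
    np.array([0, 0, 1, 0, 1]),
]
transition_batch_size = len(action_mask)

# Old form: per-element Python loop over each mask.
legal_actions_old = [
    [i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(transition_batch_size)
]

# New form: vectorized index lookup, converted back to a plain Python list.
legal_actions_new = [
    np.nonzero(action_mask[j])[0].tolist() for j in range(transition_batch_size)
]

assert legal_actions_old == legal_actions_new  # [[0, 2, 3], [2, 4]] in both cases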

lzero/mcts/buffer/game_buffer_efficientzero.py

Lines changed: 2 additions & 2 deletions

@@ -185,7 +185,7 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A
             game_segment_batch_size, to_play_segment, action_mask_segment, pos_in_game_segment_list
         )

-        legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(transition_batch_size)]
+        legal_actions = [np.nonzero(action_mask[j])[0].tolist() for j in range(transition_batch_size)]

         # ==============================================================
         # EfficientZero related core code
@@ -344,7 +344,7 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model:
             game_segment_batch_size, to_play_segment, action_mask_segment, pos_in_game_segment_list
         )

-        legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(transition_batch_size)]
+        legal_actions = [np.nonzero(action_mask[j])[0].tolist() for j in range(transition_batch_size)]
         with torch.no_grad():
             policy_obs_list = prepare_observation(policy_obs_list, self._cfg.model.model_type)
             # split a full batch into slices of mini_infer_size: to save the GPU memory for more GPU actors

lzero/mcts/buffer/game_buffer_muzero.py

Lines changed: 2 additions & 2 deletions

@@ -597,7 +597,7 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model:
                 [-1 for _ in range(self._cfg.model.num_of_sampled_actions)] for _ in range(transition_batch_size)
             ]
         else:
-            legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(transition_batch_size)]
+            legal_actions = [np.nonzero(action_mask[j])[0].tolist() for j in range(transition_batch_size)]

         with torch.no_grad():
             policy_obs_list = prepare_observation(policy_obs_list, self._cfg.model.model_type)
@@ -755,7 +755,7 @@ def _compute_target_policy_non_reanalyzed(
                 [-1 for _ in range(self.action_space_size)] for _ in range(transition_batch_size)
             ]
         else:
-            legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(transition_batch_size)]
+            legal_actions = [np.nonzero(action_mask[j])[0].tolist() for j in range(transition_batch_size)]

         with torch.no_grad():
             policy_index = 0

lzero/mcts/buffer/game_buffer_rezero_ez.py

Lines changed: 1 addition & 1 deletion

@@ -158,7 +158,7 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model:
                 [-1 for _ in range(self._cfg.model.action_space_size)] for _ in range(transition_batch_size)
             ]
         else:
-            legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(transition_batch_size)]
+            legal_actions = [np.nonzero(action_mask[j])[0].tolist() for j in range(transition_batch_size)]

         with torch.no_grad():
             policy_obs_list = prepare_observation(policy_obs_list, self._cfg.model.model_type)

lzero/mcts/buffer/game_buffer_rezero_mz.py

Lines changed: 1 addition & 1 deletion

@@ -230,7 +230,7 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model:
                 [-1 for _ in range(self._cfg.model.action_space_size)] for _ in range(transition_batch_size)
             ]
         else:
-            legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(transition_batch_size)]
+            legal_actions = [np.nonzero(action_mask[j])[0].tolist() for j in range(transition_batch_size)]

         with torch.no_grad():
             policy_obs_list = prepare_observation(policy_obs_list, self._cfg.model.model_type)

lzero/mcts/buffer/game_buffer_sampled_efficientzero.py

Lines changed: 2 additions & 2 deletions

@@ -272,7 +272,7 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A
                 [-1 for _ in range(self._cfg.model.action_space_size)] for _ in range(transition_batch_size)
             ]
         else:
-            legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(transition_batch_size)]
+            legal_actions = [np.nonzero(action_mask[j])[0].tolist() for j in range(transition_batch_size)]

         batch_target_values, batch_value_prefixs = [], []
         with torch.no_grad():
@@ -452,7 +452,7 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model:
                 [-1 for _ in range(self._cfg.model.action_space_size)] for _ in range(transition_batch_size)
             ]
         else:
-            legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(transition_batch_size)]
+            legal_actions = [np.nonzero(action_mask[j])[0].tolist() for j in range(transition_batch_size)]

         with torch.no_grad():
             policy_obs_list = prepare_observation(policy_obs_list, self._cfg.model.model_type)

lzero/mcts/buffer/game_buffer_sampled_muzero.py

Lines changed: 2 additions & 2 deletions

@@ -272,7 +272,7 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A
                 [-1 for _ in range(self._cfg.model.action_space_size)] for _ in range(transition_batch_size)
             ]
         else:
-            legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(transition_batch_size)]
+            legal_actions = [np.nonzero(action_mask[j])[0].tolist() for j in range(transition_batch_size)]

         batch_target_values, batch_rewards = [], []
         with torch.no_grad():
@@ -437,7 +437,7 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model:
                 [-1 for _ in range(self._cfg.model.action_space_size)] for _ in range(transition_batch_size)
             ]
         else:
-            legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(transition_batch_size)]
+            legal_actions = [np.nonzero(action_mask[j])[0].tolist() for j in range(transition_batch_size)]

         with torch.no_grad():
             policy_obs_list = prepare_observation(policy_obs_list, self._cfg.model.model_type)

lzero/mcts/buffer/game_buffer_sampled_unizero.py

Lines changed: 2 additions & 2 deletions

@@ -485,7 +485,7 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model:
                 [-1 for _ in range(self._cfg.model.num_of_sampled_actions)] for _ in range(transition_batch_size)
             ]
         else:
-            legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(transition_batch_size)]
+            legal_actions = [np.nonzero(action_mask[j])[0].tolist() for j in range(transition_batch_size)]

         # NOTE: TODO
         model.world_model.reanalyze_phase = True
@@ -658,7 +658,7 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A
                 [-1 for _ in range(self._cfg.model.num_of_sampled_actions)] for _ in range(transition_batch_size)
             ]
         else:
-            legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(transition_batch_size)]
+            legal_actions = [np.nonzero(action_mask[j])[0].tolist() for j in range(transition_batch_size)]

         batch_target_values, batch_rewards = [], []
         with torch.no_grad():

lzero/mcts/buffer/game_buffer_unizero.py

Lines changed: 1 addition & 1 deletion

@@ -431,7 +431,7 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model:
                 [-1 for _ in range(self.action_space_size)] for _ in range(transition_batch_size)
             ]
         else:
-            legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(transition_batch_size)]
+            legal_actions = [np.nonzero(action_mask[j])[0].tolist() for j in range(transition_batch_size)]

         # NOTE: check the effect of reanalyze_phase
         model.world_model.reanalyze_phase = True

lzero/mcts/tests/eval_tree_speed.py

Lines changed: 2 additions & 2 deletions

@@ -116,7 +116,7 @@ def ptree_func(policy_config, num_simulations):
     assert len(action_mask[0]) == action_space_size

     action_num = [int(np.array(action_mask[i]).sum()) for i in range(env_nums)]
-    legal_actions_list = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(env_nums)]
+    legal_actions_list = [np.nonzero(action_mask[j])[0].tolist() for j in range(env_nums)]
     to_play = [np.random.randint(1, 3) for i in range(env_nums)]
     assert len(to_play) == batch_size
     # ============================================ptree=====================================#
@@ -212,7 +212,7 @@ def ctree_func(policy_config, num_simulations):
     assert len(action_mask[0]) == action_space_size

     action_num = [int(np.array(action_mask[i]).sum()) for i in range(env_nums)]
-    legal_actions_list = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(env_nums)]
+    legal_actions_list = [np.nonzero(action_mask[j])[0].tolist() for j in range(env_nums)]
     to_play = [np.random.randint(1, 3) for i in range(env_nums)]
     assert len(to_play) == batch_size
     # ============================================ctree=====================================#
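
eval_tree_speed.py times the tree search itself; as a side note, a hypothetical micro-benchmark (not part of this commit or of the test file, with a made-up mask size) of just the mask-to-indices conversion can be run with timeit to gauge the kind of speedup the np.nonzero form gives when the mask is a NumPy array:

import timeit
import numpy as np

# Illustrative setup only: one random binary mask of arbitrary size.
rng = np.random.default_rng(0)
action_mask = (rng.random(1000) < 0.5).astype(np.int64)

t_old = timeit.timeit(lambda: [i for i, x in enumerate(action_mask) if x == 1], number=1000)
t_new = timeit.timeit(lambda: np.nonzero(action_mask)[0].tolist(), number=1000)
print(f"list comprehension: {t_old:.4f}s, np.nonzero: {t_new:.4f}s")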

lzero/mcts/tests/test_mcts_ctree.py

Lines changed: 1 addition & 1 deletion

@@ -143,7 +143,7 @@ def recurrent_inference(self, latent_states, reward_hidden_states, actions=None)
 action_num = [
     int(np.array(action_mask[i]).sum()) for i in range(env_nums)
 ]  # [3, 3, 5, 4, 3, 3, 6, 6, 3, 6, 6, 5, 2, 5, 1, 4]
-legal_actions_list = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(env_nums)]
+legal_actions_list = [np.nonzero(action_mask[j])[0].tolist() for j in range(env_nums)]
 # legal_actions_list =
 # [[3, 5, 6], [0, 3, 6], [0, 1, 4, 6, 8], [0, 3, 4, 5],
 # [2, 5, 8], [1, 2, 4], [0, 2, 3, 4, 7, 8], [0, 1, 2, 3, 4, 8],
