@@ -34,29 +34,29 @@ def __init__(
34
34
map_name : str ,
35
35
seed : int = 0 ,
36
36
):
37
- self .environment = StarCraft2Env (
37
+ self ._environment = StarCraft2Env (
38
38
args = OmigaConf (map_name = map_name ),
39
39
seed = seed ,
40
40
)
41
- self .agents = [f"agent_{ n } " for n in range (self .environment .n_agents )]
41
+ self .agents = [f"agent_{ n } " for n in range (self ._environment .n_agents )]
42
42
43
43
self .num_agents = len (self .agents )
44
- self .num_actions = self .environment .n_actions
44
+ self .num_actions = self ._environment .n_actions
45
45
46
46
def reset (self ) -> ResetReturn :
47
47
"""Resets the env."""
48
48
# Reset the environment
49
- self .environment .reset ()
49
+ self ._environment .reset ()
50
50
self .done = False
51
51
52
52
# Get observation from env
53
- observations = self .environment .get_obs ()
53
+ observations = self ._environment .get_obs ()
54
54
observations = {agent : observations [i ] for i , agent in enumerate (self .agents )}
55
55
56
56
legal_actions = self ._get_legal_actions ()
57
57
legals = {agent : legal_actions [i ] for i , agent in enumerate (self .agents )}
58
58
59
- env_state = self .environment .get_state (agent_id = 0 ).astype ("float32" )
59
+ env_state = self ._environment .get_state (agent_id = 0 ).astype ("float32" )
60
60
61
61
info = {"legals" : legals , "state" : env_state }
62
62
@@ -69,7 +69,7 @@ def step(self, actions: Dict[str, np.ndarray]) -> StepReturn:
69
69
for agent in self .agents :
70
70
smac_actions .append (actions [agent ])
71
71
72
- o , g , r , d , i , ava = self .environment .step (smac_actions )
72
+ o , g , r , d , i , ava = self ._environment .step (smac_actions )
73
73
74
74
observations = {agent : o [i ] for i , agent in enumerate (self .agents )}
75
75
rewards = {
@@ -89,6 +89,6 @@ def _get_legal_actions(self) -> List[np.ndarray]:
89
89
legal_actions = []
90
90
for i , _ in enumerate (self .agents ):
91
91
legal_actions .append (
92
- np .array (self .environment .get_avail_agent_actions (i ), dtype = "float32" )
92
+ np .array (self ._environment .get_avail_agent_actions (i ), dtype = "float32" )
93
93
)
94
94
return legal_actions
0 commit comments