Skip to content

Commit e6df2dd

Browse files
committed
Small bug fix in Omiga smac wrapper.
1 parent 32dcf3a commit e6df2dd

File tree

1 file changed

+8
-8
lines changed

1 file changed

+8
-8
lines changed

og_marl/wrapped_environments/smacv1_omiga.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -34,29 +34,29 @@ def __init__(
3434
map_name: str,
3535
seed: int = 0,
3636
):
37-
self.environment = StarCraft2Env(
37+
self._environment = StarCraft2Env(
3838
args=OmigaConf(map_name=map_name),
3939
seed=seed,
4040
)
41-
self.agents = [f"agent_{n}" for n in range(self.environment.n_agents)]
41+
self.agents = [f"agent_{n}" for n in range(self._environment.n_agents)]
4242

4343
self.num_agents = len(self.agents)
44-
self.num_actions = self.environment.n_actions
44+
self.num_actions = self._environment.n_actions
4545

4646
def reset(self) -> ResetReturn:
4747
"""Resets the env."""
4848
# Reset the environment
49-
self.environment.reset()
49+
self._environment.reset()
5050
self.done = False
5151

5252
# Get observation from env
53-
observations = self.environment.get_obs()
53+
observations = self._environment.get_obs()
5454
observations = {agent: observations[i] for i, agent in enumerate(self.agents)}
5555

5656
legal_actions = self._get_legal_actions()
5757
legals = {agent: legal_actions[i] for i, agent in enumerate(self.agents)}
5858

59-
env_state = self.environment.get_state(agent_id=0).astype("float32")
59+
env_state = self._environment.get_state(agent_id=0).astype("float32")
6060

6161
info = {"legals": legals, "state": env_state}
6262

@@ -69,7 +69,7 @@ def step(self, actions: Dict[str, np.ndarray]) -> StepReturn:
6969
for agent in self.agents:
7070
smac_actions.append(actions[agent])
7171

72-
o, g, r, d, i, ava = self.environment.step(smac_actions)
72+
o, g, r, d, i, ava = self._environment.step(smac_actions)
7373

7474
observations = {agent: o[i] for i, agent in enumerate(self.agents)}
7575
rewards = {
@@ -89,6 +89,6 @@ def _get_legal_actions(self) -> List[np.ndarray]:
8989
legal_actions = []
9090
for i, _ in enumerate(self.agents):
9191
legal_actions.append(
92-
np.array(self.environment.get_avail_agent_actions(i), dtype="float32")
92+
np.array(self._environment.get_avail_agent_actions(i), dtype="float32")
9393
)
9494
return legal_actions

0 commit comments

Comments
 (0)