Skip to content

Commit b89f793

Browse files
committed
Small fix in OMAR MAMuJoCo wrapper.
1 parent d5e79af commit b89f793

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

og_marl/wrapped_environments/mamujoco_omar.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def reset(self) -> ResetReturn:
3939
observations = self._environment.get_obs()
4040

4141
observations = {
42-
agent: observations[i].astype("float32") for i, agent in enumerate(self.possible_agents)
42+
agent: observations[i].astype("float32") for i, agent in enumerate(self.agents)
4343
}
4444

4545
info = {"state": self._environment.get_state()}
@@ -48,20 +48,20 @@ def reset(self) -> ResetReturn:
4848

4949
def step(self, actions: Dict[str, np.ndarray]) -> StepReturn:
5050
mujoco_actions = []
51-
for agent in self.possible_agents:
51+
for agent in self.agents:
5252
mujoco_actions.append(actions[agent])
5353

5454
reward, done, info = self._environment.step(mujoco_actions)
5555

56-
terminals = {agent: done for agent in self.possible_agents}
57-
trunctations = {agent: False for agent in self.possible_agents}
56+
terminals = {agent: done for agent in self.agents}
57+
trunctations = {agent: False for agent in self.agents}
5858

59-
rewards = {agent: reward for agent in self.possible_agents}
59+
rewards = {agent: reward for agent in self.agents}
6060

6161
observations = self._environment.get_obs()
6262

6363
observations = {
64-
agent: observations[i].astype("float32") for i, agent in enumerate(self.possible_agents)
64+
agent: observations[i].astype("float32") for i, agent in enumerate(self.agents)
6565
}
6666

6767
info = {}

0 commit comments

Comments
 (0)