Skip to content

Commit c6e6974

Browse files
committed
Fix bug in omiga mujoco wrapper.
1 parent 39f5428 commit c6e6974

File tree

6 files changed

+90
-9
lines changed

6 files changed

+90
-9
lines changed

og_marl/environments.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,10 @@ def get_environment(source: str, env_name: str, scenario: str, seed: int = 42) -
2121
elif env_name == "mamujoco" and source == "og_marl":
2222
from og_marl.wrapped_environments.mamujoco import MAMuJoCo
2323

24+
return MAMuJoCo(scenario, seed=seed)
25+
elif env_name == "mamujoco" and source == "omar":
26+
from og_marl.wrapped_environments.mamujoco_omar import MAMuJoCo
27+
2428
return MAMuJoCo(scenario, seed=seed)
2529
elif env_name == "gymnasium_mamujoco":
2630
from og_marl.wrapped_environments.gymnasium_mamujoco import WrappedGymnasiumMAMuJoCo

og_marl/tf2_systems/offline/configs/continuous_bc.yaml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,10 @@ wandb_project: og-marl
44
training_steps: 5e5
55

66
task:
7-
source: og_marl
7+
source: omiga
88
env: mamujoco
9-
scenario: 2halfcheetah
10-
dataset: Good
9+
scenario: 3hopper
10+
dataset: Expert
1111

1212
replay:
1313
sequence_length: 20

og_marl/tf2_systems/offline/configs/iddpg_bc.yaml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,10 @@ wandb_project: og-marl
44
training_steps: 5e5
55

66
task:
7-
source: my_datasets
8-
env: gymnasium_mamujoco
9-
scenario: 2reacher
10-
dataset: replay
7+
source: omiga
8+
env: mamujoco
9+
scenario: 3hopper
10+
dataset: Expert
1111

1212
replay:
1313
sequence_length: 20

og_marl/wrapped_environments/mamujoco.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ def get_mamujoco_args(scenario: str) -> Dict[str, Any]:
2323
env_args = {
2424
"agent_obsk": 1,
2525
"episode_limit": 1000,
26-
"global_categories": "qvel,qpos",
26+
# "global_categories": "qvel,qpos",
2727
}
2828
if scenario.lower() == "4ant":
2929
env_args["scenario"] = "Ant-v2"
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
from typing import Any, Dict
2+
3+
import numpy as np
4+
5+
from og_marl.custom_environments.multiagent_mujoco.mujoco_multi import MujocoMulti
6+
7+
from og_marl.wrapped_environments.base import BaseEnvironment, ResetReturn, StepReturn
8+
9+
10+
class MAMuJoCo(BaseEnvironment):
11+
12+
"""Environment wrapper Multi-Agent MuJoCo."""
13+
14+
def __init__(self, scenario: str, seed=None):
15+
env_args = self._get_mamujoco_args(scenario)
16+
17+
self._environment = MujocoMulti(env_args=env_args)
18+
19+
self.possible_agents = [f"agent_{n}" for n in range(self._environment.n_agents)]
20+
self._num_actions = self._environment.n_actions
21+
22+
self.max_episode_length = 1000
23+
24+
def _get_mamujoco_args(self, scenario: str) -> Dict[str, Any]:
25+
env_args = {
26+
"agent_obsk": 0,
27+
"episode_limit": 1000,
28+
}
29+
if scenario.lower() == "2halfcheetah":
30+
env_args["scenario"] = "HalfCheetah-v2"
31+
env_args["agent_conf"] = "2x3"
32+
else:
33+
raise ValueError("Not a valid omar mamujoco scenario.")
34+
return env_args
35+
36+
def reset(self) -> ResetReturn:
37+
self._environment.reset()
38+
39+
observations = self._environment.get_obs()
40+
41+
observations = {
42+
agent: observations[i].astype("float32") for i, agent in enumerate(self.possible_agents)
43+
}
44+
45+
info = {"state": self._environment.get_state()}
46+
47+
return observations, info
48+
49+
def step(self, actions: Dict[str, np.ndarray]) -> StepReturn:
50+
mujoco_actions = []
51+
for agent in self.possible_agents:
52+
mujoco_actions.append(actions[agent])
53+
54+
reward, done, info = self._environment.step(mujoco_actions)
55+
56+
terminals = {agent: done for agent in self.possible_agents}
57+
trunctations = {agent: False for agent in self.possible_agents}
58+
59+
rewards = {agent: reward for agent in self.possible_agents}
60+
61+
observations = self._environment.get_obs()
62+
63+
observations = {
64+
agent: observations[i].astype("float32") for i, agent in enumerate(self.possible_agents)
65+
}
66+
67+
info = {}
68+
info["state"] = self._environment.get_state()
69+
70+
return observations, rewards, terminals, trunctations, info # type: ignore
71+
72+
def __getattr__(self, name: str) -> Any:
73+
"""Expose any other attributes of the underlying environment."""
74+
if hasattr(self.__class__, name):
75+
return self.__getattribute__(name)
76+
else:
77+
return getattr(self._environment, name)

og_marl/wrapped_environments/mamujoco_omiga.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ def add_agent_id_and_normalise(self, observations):
9393
one_hot[i] = 1
9494
agent_obs = observations[agent].astype("float32")
9595
agent_obs = np.concatenate([agent_obs, one_hot], axis=-1)
96-
agent_obs = agent_obs - np.mean(agent_obs) / np.std(agent_obs)
96+
agent_obs = (agent_obs - np.mean(agent_obs)) / np.std(agent_obs)
9797
observations[agent] = agent_obs
9898
return observations
9999

0 commit comments

Comments
 (0)