Skip to content

Commit ab6e543

Browse files
authored
Permit AgentIDs other than str (#1071)
1 parent 4628eef commit ab6e543

24 files changed

+501
-95
lines changed

pettingzoo/test/api_test.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from __future__ import annotations
2+
13
import re
24
import warnings
35
from collections import defaultdict
@@ -249,7 +251,7 @@ def test_observation_action_spaces(env, agent_0):
249251
)
250252
if (not isinstance(agent, str)) and agent != "env":
251253
warnings.warn(
252-
"Agent's are recommended to have numbered string names, like player_0"
254+
"Agents are recommended to have numbered string names, like player_0"
253255
)
254256
if not isinstance(agent, str) or not re.match(
255257
"[a-z]+_[0-9]+", agent
Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,161 @@
1+
from typing import Tuple, Union
2+
3+
import gymnasium
4+
import numpy as np
5+
6+
from pettingzoo import AECEnv
7+
from pettingzoo.utils import wrappers
8+
from pettingzoo.utils.agent_selector import agent_selector
9+
10+
11+
def env():
    """Build the standard wrapped AEC version of this test environment.

    Wrapping order matters: bounds are asserted on the raw env before the
    order-enforcing wrapper is applied on the outside.
    """
    wrapped = raw_env()
    wrapped = wrappers.AssertOutOfBoundsWrapper(wrapped)
    wrapped = wrappers.OrderEnforcingWrapper(wrapped)
    return wrapped
16+
17+
18+
def get_type(agent: Tuple[str, int]):
    """Return the type-name component of a ``(type_name, index)`` agent ID."""
    type_name, *_ = agent
    return type_name
20+
21+
22+
class raw_env(AECEnv[Tuple[str, int], np.ndarray, Union[int, None]]):
    """AEC test environment whose agent IDs are ``(type_name, index)`` tuples.

    Exercises support for non-string AgentIDs: agent types (and agents of
    those types) are generated randomly at runtime, and every type gets its
    own randomly sized observation and action spaces.
    """

    metadata = {"render_modes": ["human"], "name": "generated_agents_env_v0"}

    def __init__(self, max_cycles=100, render_mode=None):
        super().__init__()
        # Observation/action spaces are shared per *type*, keyed by type name.
        self._obs_spaces = {}
        self._act_spaces = {}

        # dummy state space, not actually used
        self.state_space = gymnasium.spaces.MultiDiscrete([10, 10])
        self._state = self.state_space.sample()

        self.types = []
        self._agent_counters = {}
        self.max_cycles = max_cycles
        self._seed()
        self.render_mode = render_mode
        for i in range(3):
            self.add_type()

    def observation_space(self, agent):
        # All agents of the same type share one observation space.
        return self._obs_spaces[get_type(agent)]

    def action_space(self, agent):
        # All agents of the same type share one action space.
        return self._act_spaces[get_type(agent)]

    def state(self) -> np.ndarray:
        return self._state

    def observe(self, agent):
        # Observations are random samples; their content is irrelevant here.
        return self.observation_space(agent).sample()

    def add_type(self) -> str:
        """Create a new agent type with random spaces; return its name."""
        type_id = len(self.types)
        num_actions = self.np_random.integers(3, 10)
        obs_size = self.np_random.integers(10, 50)
        obs_space = gymnasium.spaces.Box(low=0, high=1, shape=(obs_size,))
        act_space = gymnasium.spaces.Discrete(num_actions)
        new_type = f"type{type_id}"
        self.types.append(new_type)
        self._obs_spaces[new_type] = obs_space
        self._act_spaces[new_type] = act_space
        self._agent_counters[new_type] = 0
        return new_type

    def add_agent(self, type):
        """Spawn a new agent of the given type and register its bookkeeping.

        The parameter name ``type`` (shadowing the builtin) is kept for
        backward compatibility with existing callers.
        """
        agent_id = self._agent_counters[type]
        self._agent_counters[type] += 1
        agent = (type, agent_id)
        self.agents.append(agent)
        self.terminations[agent] = False
        self.truncations[agent] = False
        self.rewards[agent] = 0
        self._cumulative_rewards[agent] = 0
        self.infos[agent] = {}
        return agent

    def reset(self, seed=None, options=None):
        """Re-generate 3 types and an initial population of 5 agents."""
        if seed is not None:
            self._seed(seed=seed)
        self.agents = []
        self.rewards = {}
        self._cumulative_rewards = {}
        self.terminations = {}
        self.truncations = {}
        self.infos = {}
        self.num_steps = 0

        self._obs_spaces = {}
        self._act_spaces = {}
        self.state_space = gymnasium.spaces.MultiDiscrete([10, 10])
        self._state = self.state_space.sample()

        self.types = []
        self._agent_counters = {}
        for i in range(3):
            self.add_type()
        for i in range(5):
            self.add_agent(self.np_random.choice(self.types))

        self._agent_selector = agent_selector(self.agents)
        self.agent_selection = self._agent_selector.reset()

        # seed observation and action spaces
        for i, agent in enumerate(self.agents):
            self.observation_space(agent).seed(seed)
        for i, agent in enumerate(self.agents):
            self.action_space(agent).seed(seed)

    def _seed(self, seed=None):
        self.np_random, _ = gymnasium.utils.seeding.np_random(seed)

    def step(self, action):
        if (
            self.terminations[self.agent_selection]
            or self.truncations[self.agent_selection]
        ):
            return self._was_dead_step(action)

        self._clear_rewards()
        self._cumulative_rewards[self.agent_selection] = 0

        if self._agent_selector.is_last():
            # Randomly spawn agents (and occasionally whole new types); once
            # the population reaches 20, terminate a random agent.
            for i in range(5):
                if self.np_random.random() < 0.1:
                    if self.np_random.random() < 0.1:
                        agent_type = self.add_type()
                    else:
                        agent_type = self.np_random.choice(self.types)

                    agent = self.add_agent(agent_type)
                    if len(self.agents) >= 20:
                        # Fix: Generator.choice cannot sample directly from a
                        # list of (str, int) tuples (non-1-D input raises
                        # ValueError) -- pick by index instead, matching the
                        # reward-assignment line below.
                        self.terminations[
                            self.agents[self.np_random.choice(len(self.agents))]
                        ] = True

        if self._agent_selector.is_last():
            self.num_steps += 1

            if self.num_steps > self.max_cycles:
                for agent in self.agents:
                    self.truncations[agent] = True

        # One randomly chosen agent receives a reward of 1 each step.
        self.rewards[self.agents[self.np_random.choice(len(self.agents))]] = 1

        self._state = self.state_space.sample()

        self._accumulate_rewards()
        self._deads_step_first()
        if self.render_mode == "human":
            self.render()

    def render(self):
        if self.render_mode is None:
            gymnasium.logger.warn(
                "You are calling render method without specifying any render mode."
            )
        else:
            print(self.agents)

    def close(self):
        pass

pettingzoo/test/example_envs/generated_agents_env_v0.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from typing import Union
2+
13
import gymnasium
24
import numpy as np
35

@@ -17,7 +19,7 @@ def get_type(agent):
1719
return agent[: agent.rfind("_")]
1820

1921

20-
class raw_env(AECEnv):
22+
class raw_env(AECEnv[str, np.ndarray, Union[int, None]]):
2123
metadata = {"render_modes": ["human"], "name": "generated_agents_env_v0"}
2224

2325
def __init__(self, max_cycles=100, render_mode=None):
@@ -26,7 +28,7 @@ def __init__(self, max_cycles=100, render_mode=None):
2628
self._act_spaces = {}
2729

2830
# dummy state space, not actually used
29-
self.state_space = gymnasium.spaces.MultiDiscrete((10, 10))
31+
self.state_space = gymnasium.spaces.MultiDiscrete([10, 10])
3032
self._state = self.state_space.sample()
3133

3234
self.types = []
@@ -87,7 +89,7 @@ def reset(self, seed=None, options=None):
8789

8890
self._obs_spaces = {}
8991
self._act_spaces = {}
90-
self.state_space = gymnasium.spaces.MultiDiscrete((10, 10))
92+
self.state_space = gymnasium.spaces.MultiDiscrete([10, 10])
9193
self._state = self.state_space.sample()
9294

9395
self.types = []
Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,164 @@
1+
from typing import Tuple, Union
2+
3+
import gymnasium
4+
import numpy as np
5+
from gymnasium.utils import seeding
6+
7+
from pettingzoo import ParallelEnv
8+
from pettingzoo.utils import conversions, wrappers
9+
10+
11+
def env(**kwargs):
    """Build the wrapped AEC interface over the parallel test environment.

    All keyword arguments are forwarded to the underlying ``parallel_env``.
    """
    aec = raw_env(**kwargs)
    aec = wrappers.AssertOutOfBoundsWrapper(aec)
    aec = wrappers.OrderEnforcingWrapper(aec)
    return aec
16+
17+
18+
def raw_env(**kwargs):
    """Construct the parallel environment and convert it to the AEC API."""
    par = parallel_env(**kwargs)
    return conversions.parallel_to_aec(par)
20+
21+
22+
def get_type(agent: Tuple[str, int]):
    """Return the type-name component of a ``(type_name, index)`` agent ID."""
    type_name, *_ = agent
    return type_name
24+
25+
26+
class parallel_env(ParallelEnv[Tuple[str, int], np.ndarray, Union[int, None]]):
    """Parallel test environment whose agent IDs are ``(type_name, index)`` tuples.

    Types and agents are generated randomly at runtime; each type owns its
    own randomly sized observation and action spaces. Used to exercise
    support for non-string AgentIDs through the parallel API.
    """

    metadata = {"render_modes": ["human"], "name": "generated_agents_parallel_v0"}

    def __init__(self, max_cycles=100, render_mode=None):
        super().__init__()
        # Spaces are shared per *type*, keyed by type name.
        self._obs_spaces = {}
        self._act_spaces = {}

        # dummy state space, not actually used
        self.state_space = gymnasium.spaces.MultiDiscrete([10, 10])
        self._state = self.state_space.sample()

        self.types = []
        self._agent_counters = {}
        self.max_cycles = max_cycles
        # Remembered so spaces created mid-episode can be seeded consistently.
        self.rng_seed = None
        self._seed()
        self.render_mode = render_mode
        for i in range(3):
            self.add_type()

    def observation_space(self, agent):
        # All agents of one type share a single observation space.
        return self._obs_spaces[get_type(agent)]

    def action_space(self, agent):
        # All agents of one type share a single action space.
        return self._act_spaces[get_type(agent)]

    def state(self) -> np.ndarray:
        return self._state

    def observe(self, agent):
        # Observations are random samples; their content is irrelevant here.
        return self.observation_space(agent).sample()

    def add_type(self) -> str:
        """Create a new agent type with random, seeded spaces; return its name."""
        type_id = len(self.types)
        n_actions = self.np_random.integers(3, 10)
        obs_len = self.np_random.integers(10, 50)
        obs_space = gymnasium.spaces.Box(low=0, high=1, shape=(obs_len,))
        act_space = gymnasium.spaces.Discrete(n_actions)
        # Seed with the episode's seed so mid-episode types stay reproducible.
        obs_space.seed(self.rng_seed)
        act_space.seed(self.rng_seed)
        fresh_type = f"type{type_id}"
        self.types.append(fresh_type)
        self._obs_spaces[fresh_type] = obs_space
        self._act_spaces[fresh_type] = act_space
        self._agent_counters[fresh_type] = 0
        return fresh_type

    def add_agent(self, type: str):
        """Spawn a new agent of the given type and append it to ``self.agents``."""
        idx = self._agent_counters[type]
        self._agent_counters[type] += 1
        agent_name = (type, idx)
        self.agents.append(agent_name)
        return agent_name

    def reset(self, seed=None, options=None):
        """Re-generate 3 types and 5 agents; return (observations, infos)."""
        self.rng_seed = seed

        if seed is not None:
            self._seed(seed=seed)
        self.num_steps = 0

        # Reset spaces and types
        self._obs_spaces = {}
        self._act_spaces = {}
        self.state_space = gymnasium.spaces.MultiDiscrete([10, 10])
        self._state = self.state_space.sample()

        self.types = []
        self._agent_counters = {}
        for i in range(3):
            self.add_type()

        # Add agents
        self.agents = []
        for i in range(5):
            self.add_agent(self.np_random.choice(self.types))

        # seed observation and action spaces
        for i, agent in enumerate(self.agents):
            self.observation_space(agent).seed(seed)
        for i, agent in enumerate(self.agents):
            self.action_space(agent).seed(seed)

        initial_obs = {agent: self.observe(agent) for agent in self.agents}
        initial_infos = {agent: {} for agent in self.agents}
        return initial_obs, initial_infos

    def _seed(self, seed=None):
        self.np_random, _ = seeding.np_random(seed)

    def step(self, actions):
        # NOTE(review): num_steps is set in reset but never incremented here,
        # so truncation may never trigger -- confirm against upstream intent.
        truncated = self.num_steps >= self.max_cycles
        for agent in self.agents:
            assert agent in actions
        all_truncations = {agent: truncated for agent in self.agents}
        all_terminations = {agent: False for agent in self.agents}
        if not truncated:
            # Randomly terminate agents once the population reaches 10.
            for i in range(6):
                if self.np_random.random() < 0.1 and len(self.agents) >= 10:
                    all_terminations[
                        self.agents[self.np_random.choice(len(self.agents))]
                    ] = True

            # Randomly spawn new agents (and occasionally whole new types).
            for i in range(3):
                if self.np_random.random() < 0.1:
                    if self.np_random.random() < 0.1:
                        agent_type = self.add_type()
                    else:
                        agent_type = self.np_random.choice(self.types)

                    new_agent = self.add_agent(agent_type)
                    all_terminations[new_agent] = False
                    all_truncations[new_agent] = False

        all_infos = {agent: {} for agent in self.agents}
        all_rewards = {agent: 0 for agent in self.agents}
        # One randomly chosen agent receives a reward of 1 each step.
        all_rewards[self.agents[self.np_random.choice(len(self.agents))]] = 1
        all_observes = {agent: self.observe(agent) for agent in self.agents}
        # Drop agents that terminated or truncated this step.
        self.agents = [
            agent
            for agent in self.agents
            if not (all_truncations[agent] or all_terminations[agent])
        ]

        if self.render_mode == "human":
            self.render()
        return all_observes, all_rewards, all_terminations, all_truncations, all_infos

    def render(self):
        if self.render_mode is None:
            gymnasium.logger.warn(
                "You are calling render method without specifying any render mode."
            )
        else:
            print(self.agents)

    def close(self):
        pass

0 commit comments

Comments
 (0)