Skip to content

Commit 9939b7e

Browse files
verisimilidude2quaquelEwoutH
authored
Fix: AgentSet initialization should not require explicit random number generator (#2789)
This PR modifies the AgentSet initialization to automatically use the random number generator from the model of the first agent when no explicit random parameter is provided, eliminating unnecessary UserWarnings in typical usage. Previously, AgentSet initialization would raise a UserWarning whenever random=None, even though in most cases the agents being added to the set already have access to a seeded random number generator through their model. This resulted in users needing to explicitly pass random=model.random in every AgentSet creation, creating unnecessary boilerplate. Co-authored-by: Jan Kwakkel <[email protected]> Co-authored-by: Ewout ter Hoeven <[email protected]>
1 parent 8bd1739 commit 9939b7e

File tree

2 files changed

+59
-34
lines changed

2 files changed

+59
-34
lines changed

mesa/agent.py

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -162,30 +162,39 @@ class AgentSet(MutableSet, Sequence):
162162
which means that agents not referenced elsewhere in the program may be automatically removed from the AgentSet.
163163
164164
Notes:
165-
A `UserWarning` is issued if `random=None`. You can resolve this warning by explicitly
166-
passing a random number generator. In most cases, this will be the seeded random number
167-
generator in the model. So, you would do `random=self.random` in a `Model` or `Agent` instance.
165+
If random is None then the random number generator in the model of the first agent is used.
166+
If the agents list is empty and random is also None a user warning is issued and the AgentSet
167+
is an empty list and a default random number generator. This can make models non-reproducible.
168+
If your code may create an AgentSet with no agents please pass a random number generator explicitly.
168169
169170
"""
170171

171-
def __init__(self, agents: Iterable[Agent], random: Random | None = None):
172+
def __init__(
173+
self,
174+
agents: Iterable[Agent],
175+
random: Random | None = None,
176+
):
172177
"""Initializes the AgentSet with a collection of agents and a reference to the model.
173178
174179
Args:
175180
agents (Iterable[Agent]): An iterable of Agent objects to be included in the set.
176-
random (Random): the random number generator
181+
random (Random | np.random.Generator | None): the random number generator
177182
"""
178-
if random is None:
183+
self._agents = weakref.WeakKeyDictionary(dict.fromkeys(agents))
184+
if (len(self._agents) == 0) and random is None:
179185
warnings.warn(
180-
"Random number generator not specified, this can make models non-reproducible. Please pass a random number generator explicitly",
186+
"No Agents specified in creation of AgentSet and no random number generator specified. "
187+
"This can make models non-reproducible. Please pass a random number generator explicitly",
181188
UserWarning,
182189
stacklevel=2,
183190
)
184-
random = (
185-
Random()
186-
) # FIXME see issue 1981, how to get the central rng from model
187-
self.random = random
188-
self._agents = weakref.WeakKeyDictionary(dict.fromkeys(agents))
191+
random = Random()
192+
193+
if random is not None:
194+
self.random = random
195+
else:
196+
# all agents in an AgentSet should share the same model, just take it from first
197+
self.random = self._agents.keys().__next__().model.random
189198

190199
def __len__(self) -> int:
191200
"""Return the number of agents in the AgentSet."""

tests/test_agent.py

Lines changed: 38 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def test_agentset():
6262
model = Model()
6363
agents = [AgentTest(model) for _ in range(10)]
6464

65-
agentset = AgentSet(agents, random=model.random)
65+
agentset = AgentSet(agents)
6666

6767
assert agents[0] in agentset
6868
assert len(agentset) == len(agents)
@@ -118,7 +118,7 @@ def test_function(agent):
118118

119119
# because AgentSet uses weakrefs, we need hard refs as well....
120120
other_agents, another_set = pickle.loads( # noqa: S301
121-
pickle.dumps([agents, AgentSet(agents, random=model.random)])
121+
pickle.dumps([agents, AgentSet(agents)])
122122
)
123123
assert all(
124124
a1.unique_id == a2.unique_id for a1, a2 in zip(another_set, other_agents)
@@ -131,17 +131,33 @@ def test_agentset_initialization():
131131
model = Model()
132132
empty_agentset = AgentSet([], random=model.random)
133133
assert len(empty_agentset) == 0
134+
with pytest.warns(UserWarning):
135+
empty_agentset2 = AgentSet([])
136+
assert len(empty_agentset2) == 0
134137

135138
agents = [AgentTest(model) for _ in range(10)]
136-
agentset = AgentSet(agents, random=model.random)
139+
agentset = AgentSet(agents)
137140
assert len(agentset) == 10
138141

139142

143+
def test_agentset_initialization_w_random():
144+
"""Test agentset initialization."""
145+
model = Model()
146+
empty_agentset = AgentSet([], random=model.random)
147+
assert len(empty_agentset) == 0
148+
assert empty_agentset.random == model.random
149+
150+
agents = [AgentTest(model) for _ in range(10)]
151+
agentset = AgentSet(agents)
152+
assert len(agentset) == 10
153+
assert agentset.random == model.random
154+
155+
140156
def test_agentset_serialization():
141157
"""Test pickleability of agentset."""
142158
model = Model()
143159
agents = [AgentTest(model) for _ in range(5)]
144-
agentset = AgentSet(agents, random=model.random)
160+
agentset = AgentSet(agents)
145161

146162
serialized = pickle.dumps(agentset)
147163
deserialized = pickle.loads(serialized) # noqa: S301
@@ -156,7 +172,7 @@ def test_agent_membership():
156172
"""Test agent membership in AgentSet."""
157173
model = Model()
158174
agents = [AgentTest(model) for _ in range(5)]
159-
agentset = AgentSet(agents, random=model.random)
175+
agentset = AgentSet(agents)
160176

161177
assert agents[0] in agentset
162178
assert AgentTest(model) not in agentset
@@ -218,7 +234,7 @@ def test_agentset_get_item():
218234
"""Test integer based access to AgentSet."""
219235
model = Model()
220236
agents = [AgentTest(model) for _ in range(10)]
221-
agentset = AgentSet(agents, random=model.random)
237+
agentset = AgentSet(agents)
222238

223239
assert agentset[0] == agents[0]
224240
assert agentset[-1] == agents[-1]
@@ -232,7 +248,7 @@ def test_agentset_do_str():
232248
"""Test AgentSet.do with str."""
233249
model = Model()
234250
agents = [AgentTest(model) for _ in range(10)]
235-
agentset = AgentSet(agents, random=model.random)
251+
agentset = AgentSet(agents)
236252

237253
with pytest.raises(AttributeError):
238254
agentset.do("non_existing_method")
@@ -245,7 +261,7 @@ def test_agentset_do_str():
245261
n = 10
246262
model = Model()
247263
agents = [AgentDoTest(model) for _ in range(n)]
248-
agentset = AgentSet(agents, random=model.random)
264+
agentset = AgentSet(agents)
249265
for agent in agents:
250266
agent.agent_set = agentset
251267

@@ -255,7 +271,7 @@ def test_agentset_do_str():
255271
# setup
256272
model = Model()
257273
agents = [AgentDoTest(model) for _ in range(10)]
258-
agentset = AgentSet(agents, random=model.random)
274+
agentset = AgentSet(agents)
259275
for agent in agents:
260276
agent.agent_set = agentset
261277

@@ -267,7 +283,7 @@ def test_agentset_do_callable():
267283
"""Test AgentSet.do with callable."""
268284
model = Model()
269285
agents = [AgentTest(model) for _ in range(10)]
270-
agentset = AgentSet(agents, random=model.random)
286+
agentset = AgentSet(agents)
271287

272288
# Test callable with non-existent function
273289
with pytest.raises(AttributeError):
@@ -281,7 +297,7 @@ def test_agentset_do_callable():
281297
n = 10
282298
model = Model()
283299
agents = [AgentDoTest(model) for _ in range(n)]
284-
agentset = AgentSet(agents, random=model.random)
300+
agentset = AgentSet(agents)
285301
for agent in agents:
286302
agent.agent_set = agentset
287303

@@ -292,7 +308,7 @@ def test_agentset_do_callable():
292308
# setup again for lambda function tests
293309
model = Model()
294310
agents = [AgentDoTest(model) for _ in range(10)]
295-
agentset = AgentSet(agents, random=model.random)
311+
agentset = AgentSet(agents)
296312
for agent in agents:
297313
agent.agent_set = agentset
298314

@@ -310,7 +326,7 @@ def remove_function(agent):
310326
# setup again for actual function tests
311327
model = Model()
312328
agents = [AgentDoTest(model) for _ in range(n)]
313-
agentset = AgentSet(agents, random=model.random)
329+
agentset = AgentSet(agents)
314330
for agent in agents:
315331
agent.agent_set = agentset
316332

@@ -321,7 +337,7 @@ def remove_function(agent):
321337
# setup again for actual function tests
322338
model = Model()
323339
agents = [AgentDoTest(model) for _ in range(10)]
324-
agentset = AgentSet(agents, random=model.random)
340+
agentset = AgentSet(agents)
325341
for agent in agents:
326342
agent.agent_set = agentset
327343

@@ -386,7 +402,7 @@ def test_agentset_agg():
386402
agent.energy = i + 1
387403
agent.wealth = 10 * (i + 1)
388404

389-
agentset = AgentSet(agents, random=model.random)
405+
agentset = AgentSet(agents)
390406

391407
# Test min aggregation
392408
min_energy = agentset.agg("energy", min)
@@ -435,7 +451,7 @@ def __init__(self, model, age=None):
435451

436452
model = Model()
437453
agents = [TestAgentWithAttribute(model, age=i) for i in range(5)]
438-
agentset = AgentSet(agents, random=model.random)
454+
agentset = AgentSet(agents)
439455

440456
# Set a new attribute "health" and an existing attribute "age" for all agents
441457
agentset.set("health", 100).set("age", 50).set("status", "active")
@@ -454,7 +470,7 @@ def test_agentset_map_str():
454470
"""Test AgentSet.map with strings."""
455471
model = Model()
456472
agents = [AgentTest(model) for _ in range(10)]
457-
agentset = AgentSet(agents, random=model.random)
473+
agentset = AgentSet(agents)
458474

459475
with pytest.raises(AttributeError):
460476
agentset.do("non_existing_method")
@@ -467,7 +483,7 @@ def test_agentset_map_callable():
467483
"""Test AgentSet.map with callable."""
468484
model = Model()
469485
agents = [AgentTest(model) for _ in range(10)]
470-
agentset = AgentSet(agents, random=model.random)
486+
agentset = AgentSet(agents)
471487

472488
# Test callable with non-existent function
473489
with pytest.raises(AttributeError):
@@ -494,7 +510,7 @@ def test_method(self):
494510
self.called = True
495511

496512
agents = [TestAgentShuffleDo(model) for _ in range(100)]
497-
agentset = AgentSet(agents, random=model.random)
513+
agentset = AgentSet(agents)
498514

499515
# Test shuffle_do with a string method name
500516
agentset.shuffle_do("test_method")
@@ -544,7 +560,7 @@ def test_agentset_get_attribute():
544560
"""Test AgentSet.get for attributes."""
545561
model = Model()
546562
agents = [AgentTest(model) for _ in range(10)]
547-
agentset = AgentSet(agents, random=model.random)
563+
agentset = AgentSet(agents)
548564

549565
unique_ids = agentset.get("unique_id")
550566
assert unique_ids == [agent.unique_id for agent in agents]
@@ -558,7 +574,7 @@ def test_agentset_get_attribute():
558574
agent = AgentTest(model)
559575
agent.i = i**2
560576
agents.append(agent)
561-
agentset = AgentSet(agents, random=model.random)
577+
agentset = AgentSet(agents)
562578

563579
values = agentset.get(["unique_id", "i"])
564580

@@ -634,7 +650,7 @@ def get_unique_identifier(self):
634650

635651
model = Model()
636652
agents = [TestAgent(model) for _ in range(10)]
637-
agentset = AgentSet(agents, random=model.random)
653+
agentset = AgentSet(agents)
638654

639655
groups = agentset.groupby("even")
640656
assert len(groups.groups[True]) == 5

0 commit comments

Comments
 (0)