-
-
Notifications
You must be signed in to change notification settings - Fork 1.1k
Fix: Auto-increment seed across batch_run iterations #2841
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 16 commits
728456c
c03c6fa
37a0839
7f456af
10136f9
904e796
7972638
529f3ac
7b6eaef
e53e16b
4113a11
94c44e3
0778ff4
5db0058
27a777d
0a4ea99
4726920
7ba67df
088425a
435399c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -28,29 +28,35 @@ | |
|
|
||
| """ | ||
|
|
||
| import inspect | ||
| import itertools | ||
| import multiprocessing | ||
| from collections.abc import Iterable, Mapping | ||
| import warnings | ||
| from collections.abc import Iterable, Mapping, Sequence | ||
| from functools import partial | ||
| from multiprocessing import Pool | ||
| from typing import Any | ||
|
|
||
| import numpy as np | ||
| from tqdm.auto import tqdm | ||
|
|
||
| from mesa.model import Model | ||
|
|
||
| multiprocessing.set_start_method("spawn", force=True) | ||
|
|
||
| SeedLike = int | np.integer | Sequence[int] | np.random.SeedSequence | ||
|
|
||
|
|
||
| def batch_run( | ||
| model_cls: type[Model], | ||
| parameters: Mapping[str, Any | Iterable[Any]], | ||
| # We still retain the Optional[int] because users may set it to None (i.e. use all CPUs) | ||
| number_processes: int | None = 1, | ||
| iterations: int = 1, | ||
| iterations: int | None = None, | ||
| data_collection_period: int = -1, | ||
| max_steps: int = 1000, | ||
| display_progress: bool = True, | ||
| rng: SeedLike | Iterable[SeedLike] | None = None, | ||
| ) -> list[dict[str, Any]]: | ||
| """Batch run a mesa model with a set of parameter values. | ||
|
|
||
|
|
@@ -62,6 +68,7 @@ def batch_run( | |
| data_collection_period (int, optional): Number of steps after which data gets collected, by default -1 (end of episode) | ||
| max_steps (int, optional): Maximum number of model steps after which the model halts, by default 1000 | ||
| display_progress (bool, optional): Display batch run process, by default True | ||
| rng : a valid value or iterable of values for seeding the random number generator in the model | ||
|
|
||
| Returns: | ||
| List[Dict[str, Any]] | ||
|
|
@@ -70,11 +77,34 @@ def batch_run( | |
| batch_run assumes the model has a `datacollector` attribute that has a DataCollector object initialized. | ||
|
|
||
| """ | ||
| if iterations is not None and rng is not None: | ||
| raise ValueError( | ||
| "you cannot use both iterations and rng at the same time. Please only use rng." | ||
| ) | ||
| if iterations is not None: | ||
| warnings.warn( | ||
| "iterations is deprecated, please use rng instead", | ||
quaquel marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| DeprecationWarning, | ||
| stacklevel=2, | ||
| ) | ||
| rng = [ | ||
| None, | ||
| ] * iterations | ||
quaquel marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| if not isinstance(rng, Iterable): | ||
| rng = [rng] | ||
|
|
||
| # establish to use seed or rng as name for parameter | ||
| model_parameters = inspect.signature(Model).parameters | ||
| rng_kwarg_name = "rng" | ||
| if "seed" in model_parameters: | ||
| rng_kwarg_name = "seed" | ||
|
Comment on lines
+95
to
+99
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This behavior should be concisely explained in the batch_run docstring and tutorial
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why? Using both seed and rng on a model already gives an error. This just ensures that the seed is set to the keyword argument specified by the user in their model class, whether it is seed or rng.
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fair, that's true. Give me a moment to get up to speed with it myself. |
||
|
|
||
| runs_list = [] | ||
| run_id = 0 | ||
| for iteration in range(iterations): | ||
| for i, rng_i in enumerate(rng): | ||
| for kwargs in _make_model_kwargs(parameters): | ||
| runs_list.append((run_id, iteration, kwargs)) | ||
| kwargs[rng_kwarg_name] = rng_i | ||
| runs_list.append((run_id, i, kwargs)) | ||
| run_id += 1 | ||
|
|
||
| process_func = partial( | ||
|
|
@@ -170,6 +200,7 @@ def _model_run_func( | |
| Return model_data, agent_data from the reporters | ||
| """ | ||
| run_id, iteration, kwargs = run | ||
|
|
||
| model = model_cls(**kwargs) | ||
| while model.running and model.steps <= max_steps: | ||
| model.step() | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.