Skip to content

Commit e9c0530

Browse files
authored
Fix #2452 - Check model parameters on startup, remove reactive seed (#2454)
1 parent a49e40d commit e9c0530

File tree

2 files changed

+74
-20
lines changed

2 files changed

+74
-20
lines changed

mesa/visualization/solara_viz.py

Lines changed: 29 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525

2626
import asyncio
2727
import copy
28+
import inspect
2829
from collections.abc import Callable
2930
from typing import TYPE_CHECKING, Literal
3031

@@ -299,9 +300,12 @@ def ModelCreator(model, model_params, seed=1):
299300
- The component provides an interface for adjusting user-defined parameters and reseeding the model.
300301
301302
"""
302-
user_params, fixed_params = split_model_params(model_params)
303+
solara.use_effect(
304+
lambda: _check_model_params(model.value.__class__.__init__, fixed_params),
305+
[model.value],
306+
)
303307

304-
reactive_seed = solara.use_reactive(seed)
308+
user_params, fixed_params = split_model_params(model_params)
305309

306310
model_parameters, set_model_parameters = solara.use_state(
307311
{
@@ -310,29 +314,35 @@ def ModelCreator(model, model_params, seed=1):
310314
}
311315
)
312316

313-
def do_reseed():
314-
"""Update the random seed for the model."""
315-
reactive_seed.value = model.value.random.random()
316-
317317
def on_change(name, value):
318-
set_model_parameters({**model_parameters, name: value})
318+
new_model_parameters = {**model_parameters, name: value}
319+
model.value = model.value.__class__(**new_model_parameters)
320+
set_model_parameters(new_model_parameters)
319321

320-
def create_model():
321-
model.value = model.value.__class__(**model_parameters)
322-
model.value._seed = reactive_seed.value
322+
UserInputs(user_params, on_change=on_change)
323323

324-
solara.use_effect(create_model, [model_parameters, reactive_seed.value])
325324

326-
with solara.Row(justify="space-between"):
327-
solara.InputText(
328-
label="Seed",
329-
value=reactive_seed,
330-
continuous_update=True,
331-
)
325+
def _check_model_params(init_func, model_params):
326+
"""Check if model parameters are valid for the model's initialization function.
332327
333-
solara.Button(label="Reseed", color="primary", on_click=do_reseed)
328+
Args:
329+
init_func: Model initialization function
330+
model_params: Dictionary of model parameters
334331
335-
UserInputs(user_params, on_change=on_change)
332+
Raises:
333+
ValueError: If a parameter is not valid for the model's initialization function
334+
"""
335+
model_parameters = inspect.signature(init_func).parameters
336+
for name in model_parameters:
337+
if (
338+
model_parameters[name].default == inspect.Parameter.empty
339+
and name not in model_params
340+
and name != "self"
341+
):
342+
raise ValueError(f"Missing required model parameter: {name}")
343+
for name in model_params:
344+
if name not in model_parameters:
345+
raise ValueError(f"Invalid model parameter: {name}")
336346

337347

338348
@solara.component

tests/test_solara_viz.py

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,19 @@
33
import unittest
44

55
import ipyvuetify as vw
6+
import pytest
67
import solara
78

89
import mesa
910
import mesa.visualization.components.altair_components
1011
import mesa.visualization.components.matplotlib_components
1112
from mesa.visualization.components.matplotlib_components import make_mpl_space_component
12-
from mesa.visualization.solara_viz import Slider, SolaraViz, UserInputs
13+
from mesa.visualization.solara_viz import (
14+
Slider,
15+
SolaraViz,
16+
UserInputs,
17+
_check_model_params,
18+
)
1319

1420

1521
class TestMakeUserInput(unittest.TestCase): # noqa: D101
@@ -152,3 +158,41 @@ def test_slider(): # noqa: D103
152158
assert not slider_int.is_float_slider
153159
slider_dtype_float = Slider("Homophily", 3, 0, 8, 1, dtype=float)
154160
assert slider_dtype_float.is_float_slider
161+
162+
163+
def test_model_param_checks(): # noqa: D103
164+
class ModelWithOptionalParams:
165+
def __init__(self, required_param, optional_param=10):
166+
pass
167+
168+
class ModelWithOnlyRequired:
169+
def __init__(self, param1, param2):
170+
pass
171+
172+
# Test that optional params can be omitted
173+
_check_model_params(ModelWithOptionalParams.__init__, {"required_param": 1})
174+
175+
# Test that optional params can be provided
176+
_check_model_params(
177+
ModelWithOptionalParams.__init__, {"required_param": 1, "optional_param": 5}
178+
)
179+
180+
# Test invalid parameter name raises ValueError
181+
with pytest.raises(ValueError, match="Invalid model parameter: invalid_param"):
182+
_check_model_params(
183+
ModelWithOptionalParams.__init__, {"required_param": 1, "invalid_param": 2}
184+
)
185+
186+
# Test missing required parameter raises ValueError
187+
with pytest.raises(ValueError, match="Missing required model parameter: param2"):
188+
_check_model_params(ModelWithOnlyRequired.__init__, {"param1": 1})
189+
190+
# Test passing extra parameters raises ValueError
191+
with pytest.raises(ValueError, match="Invalid model parameter: extra"):
192+
_check_model_params(
193+
ModelWithOnlyRequired.__init__, {"param1": 1, "param2": 2, "extra": 3}
194+
)
195+
196+
# Test empty params dict raises ValueError if required params
197+
with pytest.raises(ValueError, match="Missing required model parameter"):
198+
_check_model_params(ModelWithOnlyRequired.__init__, {})

0 commit comments

Comments
 (0)