Skip to content

Commit 5728ad2

Browse files
authored
implement fix + tests (#5915)
1 parent d9197ef commit 5728ad2

File tree

2 files changed

+21
-0
lines changed

2 files changed

+21
-0
lines changed

pymc/model.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -542,7 +542,11 @@ def __init__(
542542
name="",
543543
coords=None,
544544
check_bounds=True,
545+
*,
546+
aesara_config=None,
547+
model=None,
545548
):
549+
del aesara_config, model # used in __new__
546550
self.name = self._validate_name(name)
547551
self.check_bounds = check_bounds
548552

pymc/tests/test_model.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1014,3 +1014,20 @@ def test_compile_fn():
10141014
result_expect = func(state)
10151015

10161016
np.testing.assert_allclose(result_compute, result_expect)
1017+
1018+
1019+
def test_model_aesara_config():
1020+
assert aesara.config.mode != "JAX"
1021+
with pm.Model(aesara_config=dict(mode="JAX")) as model:
1022+
assert aesara.config.mode == "JAX"
1023+
assert aesara.config.mode != "JAX"
1024+
1025+
1026+
def test_model_parent_set_programmatically():
1027+
with pm.Model() as model:
1028+
x = pm.Normal("x")
1029+
1030+
with pm.Model(model=model):
1031+
y = pm.Normal("y")
1032+
1033+
assert "y" in model.named_vars

0 commit comments

Comments
 (0)