Skip to content

Commit 430c344

Browse files
authored
Pass coords argument into model factory (#282)
1 parent c00c368 commit 430c344

File tree

2 files changed

+18
-1
lines changed

2 files changed

+18
-1
lines changed

pymc_experimental/model/model_api.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ def as_model(*model_args, **model_kwargs):
77
R"""
88
Decorator to provide context to PyMC models declared in a function.
99
This removes all need to think about context managers and lets you separate creating a generative model from using the model.
10+
Additionally, a coords argument is added to the function so coords can be changed during function invocation
1011
1112
Adapted from `Rob Zinkov's blog post <https://www.zinkov.com/posts/2023-alternative-frontends-pymc/>`_ and inspired by the `sampled <https://github.com/colcarroll/sampled>`_ decorator for PyMC3.
1213
@@ -32,12 +33,21 @@ def basic_model():
3233
m = basic_model()
3334
pm.sample(model=m)
3435
36+
# alternative way to use functional API
37+
@pmx.as_model()
38+
def basic_model():
39+
pm.Normal("x", 0., 1., dims="obs")
40+
41+
m = basic_model(coords={"obs": ["a", "b"]})
42+
pm.sample(model=m)
43+
3544
"""
3645

3746
def decorator(f):
3847
@wraps(f)
3948
def make_model(*args, **kwargs):
40-
with Model(*model_args, **model_kwargs) as m:
49+
coords = model_kwargs.pop("coords", {}) | kwargs.pop("coords", {})
50+
with Model(*model_args, coords=coords, **model_kwargs) as m:
4151
f(*args, **kwargs)
4252
return m
4353

pymc_experimental/tests/model/test_model_api.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,4 +19,11 @@ def model_wrapped():
1919

2020
mw = model_wrapped()
2121

22+
@pmx.as_model()
23+
def model_wrapped2():
24+
pm.Normal("x", 0.0, 1.0, dims="obs")
25+
26+
mw2 = model_wrapped2(coords=coords)
27+
2228
np.testing.assert_equal(model.point_logps(), mw.point_logps())
29+
np.testing.assert_equal(mw.point_logps(), mw2.point_logps())

0 commit comments

Comments
 (0)