Skip to content

Commit d8a5ce2

Browse files
committed
fix: handle missing flowjax correctly
1 parent 42b4bcc commit d8a5ce2

File tree

4 files changed

+49
-17
lines changed

4 files changed

+49
-17
lines changed

python/nutpie/compile_stan.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import json
22
import tempfile
33
from dataclasses import dataclass, replace
4-
from functools import partial
54
from importlib.util import find_spec
65
from pathlib import Path
76
from typing import Any, Optional
@@ -12,7 +11,6 @@
1211

1312
from nutpie import _lib
1413
from nutpie.sample import CompiledModel
15-
from nutpie.transform_adapter import make_transform_adapter
1614

1715

1816
class _NumpyArrayEncoder(json.JSONEncoder):
@@ -45,13 +43,14 @@ def with_data(self, *, seed=None, **updates):
4543
else:
4644
data_json = None
4745

48-
kwargs = self._transform_adapt_args
49-
if kwargs is None:
50-
kwargs = {}
51-
make_adapter = partial(
52-
make_transform_adapter(**kwargs),
53-
logp_fn=None,
54-
)
46+
outer_kwargs = self._transform_adapt_args
47+
if outer_kwargs is None:
48+
outer_kwargs = {}
49+
50+
def make_adapter(*args, **kwargs):
51+
from nutpie.transform_adapter import make_transform_adapter
52+
53+
return make_transform_adapter(**outer_kwargs)(*args, **kwargs, logp_fn=None)
5554

5655
model = _lib.StanModel(self.library, seed, data_json, make_adapter)
5756
coords = self._coords

python/nutpie/compiled_pyfunc.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77

88
from nutpie import _lib # type: ignore
99
from nutpie.sample import CompiledModel
10-
from nutpie.transform_adapter import make_transform_adapter
1110

1211
SeedType = int
1312

@@ -67,13 +66,17 @@ def make_expand_func(seed1, seed2, chain):
6766
return partial(expand_fn, **self._shared_data)
6867

6968
if self._raw_logp_fn is not None:
70-
kwargs = self._transform_adapt_args
71-
if kwargs is None:
72-
kwargs = {}
73-
make_adapter = partial(
74-
make_transform_adapter(**kwargs),
75-
logp_fn=self._raw_logp_fn,
76-
)
69+
outer_kwargs = self._transform_adapt_args
70+
if outer_kwargs is None:
71+
outer_kwargs = {}
72+
73+
def make_adapter(*args, **kwargs):
74+
from nutpie.transform_adapter import make_transform_adapter
75+
76+
return make_transform_adapter(**outer_kwargs)(
77+
*args, **kwargs, logp_fn=self._raw_logp_fn
78+
)
79+
7780
else:
7881
make_adapter = None
7982

python/nutpie/transform_adapter.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,13 @@
11
from functools import partial
2+
from importlib.util import find_spec
23
from typing import Callable
34
import time
45

6+
if find_spec("flowjax") is None:
7+
raise ImportError(
8+
"The 'flowjax' package is required to use normalizing flow adaptation."
9+
)
10+
511
from flowjax import bijections
612
from jaxtyping import ArrayLike, PyTree
713
import numpy as np

tests/test_stan.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,3 +38,27 @@ def test_stan_model_data():
3838
trace = nutpie.sample(compiled_model)
3939
trace = nutpie.sample(compiled_model.with_data(x=np.array(3.0)))
4040
trace.posterior.a # noqa: B018
41+
42+
43+
@pytest.mark.slow
44+
def test_stan_flow():
45+
model = """
46+
parameters {
47+
real a;
48+
real<lower=0> b;
49+
}
50+
model {
51+
a ~ normal(0, 1);
52+
b ~ normal(0, 1);
53+
}
54+
"""
55+
56+
compiled_model = nutpie.compile_stan_model(code=model).with_transform_adapt(
57+
num_layers=2,
58+
nn_width=4,
59+
num_diag_windows=6,
60+
)
61+
trace = nutpie.sample(
62+
compiled_model, transform_adapt=True, window_switch_freq=150, tune=600, chains=1
63+
)
64+
trace.posterior.a # noqa: B018

0 commit comments

Comments
 (0)