diff --git a/pyproject.toml b/pyproject.toml index e037d90..bd268d2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,12 +28,13 @@ Homepage = "https://pymc-devs.github.io/nutpie/" Repository = "https://github.com/pymc-devs/nutpie" [project.optional-dependencies] -stan = ["bridgestan >= 2.6.1"] +stan = ["bridgestan >= 2.6.1", "stanio >= 0.5.1"] pymc = ["pymc >= 5.20.1", "numba >= 0.60.0"] pymc-jax = ["pymc >= 5.20.1", "jax >= 0.4.27"] nnflow = ["flowjax >= 17.1.0", "equinox >= 0.11.12"] dev = [ "bridgestan >= 2.6.1", + "stanio >= 0.5.1", "pymc >= 5.20.1", "numba >= 0.60.0", "jax >= 0.4.27", @@ -43,6 +44,7 @@ dev = [ ] all = [ "bridgestan >= 2.6.1", + "stanio >= 0.5.1", "pymc >= 5.20.1", "numba >= 0.60.0", "jax >= 0.4.27", diff --git a/python/nutpie/compile_stan.py b/python/nutpie/compile_stan.py index a11a780..9e65b1e 100644 --- a/python/nutpie/compile_stan.py +++ b/python/nutpie/compile_stan.py @@ -1,11 +1,9 @@ -import json import tempfile from dataclasses import dataclass, replace from importlib.util import find_spec from pathlib import Path from typing import Any, Optional -import numpy as np import pandas as pd from numpy.typing import NDArray @@ -13,13 +11,6 @@ from nutpie.sample import CompiledModel -class _NumpyArrayEncoder(json.JSONEncoder): - def default(self, obj): - if isinstance(obj, np.ndarray): - return obj.tolist() - return json.JSONEncoder.default(self, obj) - - @dataclass(frozen=True) class CompiledStanModel(CompiledModel): _coords: Optional[dict[str, Any]] @@ -39,7 +30,16 @@ def with_data(self, *, seed=None, **updates): data.update(updates) if data is not None: - data_json = json.dumps(data, cls=_NumpyArrayEncoder) + if find_spec("stanio") is None: + raise ImportError( + "stanio is not installed in the current environment. " + "Please install it with something like " + "'pip install stanio' or 'pip install nutpie[stan]'." + ) + + import stanio + + data_json = stanio.dump_stan_json(data) else: data_json = None @@ -136,7 +136,7 @@ def compile_stan_model( raise ImportError( "BridgeStan is not installed in the current environment. " "Please install it with something like " - "'pip install bridgestan'." + "'pip install bridgestan' or 'pip install nutpie[stan]'." ) import bridgestan diff --git a/tests/test_stan.py b/tests/test_stan.py index e55778e..89201c5 100644 --- a/tests/test_stan.py +++ b/tests/test_stan.py @@ -31,7 +31,7 @@ def test_stan_model(): def test_stan_model_data(): model = """ data { - real x; + complex x; } parameters { real a; @@ -44,7 +44,7 @@ def test_stan_model_data(): compiled_model = nutpie.compile_stan_model(code=model) with pytest.raises(RuntimeError): trace = nutpie.sample(compiled_model) - trace = nutpie.sample(compiled_model.with_data(x=np.array(3.0))) + trace = nutpie.sample(compiled_model.with_data(x=np.array(3.0j))) trace.posterior.a # noqa: B018