
Commit 68ff97b

feat: use new nuts-rs storage interface
1 parent f31f47c commit 68ff97b

19 files changed: +4729 −1895 lines

Cargo.lock

Lines changed: 2427 additions & 459 deletions
(Generated file; diff not rendered.)

Cargo.toml

Lines changed: 12 additions & 7 deletions
@@ -10,7 +10,7 @@ license = "MIT"
 repository = "https://github.com/pymc-devs/nutpie"
 keywords = ["statistics", "bayes"]
 description = "Python wrapper for nuts-rs -- a NUTS sampler written in Rust."
-rust-version = "1.76"
+rust-version = "1.90"

 [features]
 extension-module = ["pyo3/extension-module"]
@@ -21,23 +21,28 @@ name = "_lib"
 crate-type = ["cdylib"]

 [dependencies]
-nuts-rs = "0.16.1"
+nuts-rs = { version = "0.16.1", features = ["zarr", "arrow"] }
 numpy = "0.26.0"
 rand = "0.9.0"
 thiserror = "2.0.3"
 rand_chacha = "0.9.0"
-rayon = "1.10.0"
-# Keep arrow in sync with nuts-rs requirements
-arrow = { version = "56.1.0", default-features = false, features = ["ffi"] }
+rayon = "1.11.0"
 anyhow = "1.0.72"
 itertools = "0.14.0"
-bridgestan = "2.6.1"
+bridgestan = "2.7.0"
 rand_distr = "0.5.0"
-smallvec = "1.14.0"
+smallvec = "1.15.0"
 upon = { version = "0.10.0", default-features = false, features = [] }
 time-humanize = { version = "0.1.3", default-features = false }
 indicatif = "0.18.0"
 tch = { version = "0.20.0", optional = true }
+pyo3-object_store = "0.6.0"
+# Keep zarrs crates in sync with nuts-rs requirements
+zarrs = { version = "0.22.2", features = ["async"] }
+zarrs_object_store = "0.5.0"
+tokio = { version = "1.47.1", features = ["rt", "rt-multi-thread"] }
+pyo3-arrow = "0.12.0"
+arrow = { version = "56.2.0", features = ["json"] }

 [dependencies.pyo3]
 version = "0.26.0"

pyproject.toml

Lines changed: 7 additions & 6 deletions
@@ -17,6 +17,7 @@ classifiers = [

 dependencies = [
 "pyarrow >= 12.0.0",
+"arro3-core >= 0.6.0",
 "pandas >= 2.0",
 "xarray >= 2025.01.2",
 "arviz >= 0.20.0",
@@ -28,12 +29,12 @@ Homepage = "https://pymc-devs.github.io/nutpie/"
 Repository = "https://github.com/pymc-devs/nutpie"

 [project.optional-dependencies]
-stan = ["bridgestan >= 2.6.1", "stanio >= 0.5.1"]
+stan = ["bridgestan >= 2.7.0", "stanio >= 0.5.1"]
 pymc = ["pymc >= 5.20.1", "numba >= 0.60.0"]
 pymc-jax = ["pymc >= 5.20.1", "jax >= 0.4.27"]
 nnflow = ["flowjax >= 17.1.0", "equinox >= 0.11.12"]
 dev = [
-"bridgestan >= 2.6.1",
+"bridgestan >= 2.7.0",
 "stanio >= 0.5.1",
 "pymc >= 5.20.1",
 "numba >= 0.60.0",
@@ -44,7 +45,7 @@ dev = [
 "pytest-arraydiff",
 ]
 all = [
-"bridgestan >= 2.6.1",
+"bridgestan >= 2.7.0",
 "stanio >= 0.5.1",
 "pymc >= 5.20.1",
 "numba >= 0.60.0",
@@ -76,7 +77,7 @@ features = ["pyo3/extension-module"]

 [tool.pytest.ini_options]
 markers = [
-"flow: tests for normalizing flows",
-"stan: tests for Stan models",
-"pymc: tests for PyMC models",
+"flow: tests for normalizing flows",
+"stan: tests for Stan models",
+"pymc: tests for PyMC models",
 ]

python/nutpie/compile_pymc.py

Lines changed: 30 additions & 6 deletions
@@ -111,6 +111,7 @@ class CompiledPyMCModel(CompiledModel):
     _n_dim: int
     _shapes: dict[str, tuple[int, ...]]
     _coords: Optional[dict[str, Any]]
+    _transform_adapt_args: dict | None = None

     @property
     def n_dim(self):
@@ -146,13 +147,14 @@ def with_data(self, **updates):
             user_data=user_data,
         )

-    def _make_sampler(self, settings, init_mean, cores, progress_type):
+    def _make_sampler(self, settings, init_mean, cores, progress_type, store):
         model = self._make_model(init_mean)
         return _lib.PySampler.from_pymc(
             settings,
             cores,
             model,
             progress_type,
+            store,
         )

     def _make_model(self, init_mean):
@@ -164,24 +166,46 @@ def _make_model(self, init_mean):
             self,
         )
         logp_fn = _lib.LogpFunc(
-            self.n_dim,
             self.compiled_logp_func.address,
             self.user_data.ctypes.data,
             self,
         )

-        var_sizes = [prod(shape) for shape in self.shape_info[2]]
         var_names = self.shape_info[0]

+        coords = self._coords.copy() if self._coords is not None else {}
+        dim_sizes = {name: len(vals) for name, vals in coords.items()}
+        dims = self.dims.copy() if self.dims is not None else {}
+        var_types = ["float64"] * len(var_names)
+        var_shapes = self.shape_info[2]
+
+        variables = _lib.PyVariable.new_variables(
+            var_names, var_types, var_shapes, dim_sizes, dims
+        )
+
+        outer_kwargs = self._transform_adapt_args
+        if outer_kwargs is None:
+            outer_kwargs = {}
+
+        def make_adapter(*args, **kwargs):
+            from nutpie.transform_adapter import make_transform_adapter
+
+            return make_transform_adapter(**outer_kwargs)(*args, **kwargs, logp_fn=None)
+
         return _lib.PyMcModel(
-            self.n_dim,
             logp_fn,
             expand_fn,
+            variables,
+            self.n_dim,
+            dim_sizes,
+            coords,
             self.initial_point_func,
-            var_sizes,
-            var_names,
+            make_adapter,
         )

+    def with_transform_adapt(self, **kwargs):
+        return dataclasses.replace(self, _transform_adapt_args=kwargs)
+

 def update_user_data(user_data, user_data_storage):
     user_data = user_data[()]
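
For orientation, a minimal sketch of the variable metadata that _make_model now assembles before calling _lib.PyVariable.new_variables, with the old flat var_sizes/var_names lists gone. The argument order (names, dtypes, shapes, dim_sizes, dims) is taken from the diff above; every model name, coord, and shape below is made up for illustration. The new with_transform_adapt simply records keyword arguments on the dataclass, which make_adapter later forwards to nutpie.transform_adapter.make_transform_adapter.

    # Illustrative only -- a hypothetical model with one scalar variable and
    # one county-indexed variable. Dimension sizes are derived from the
    # user-supplied coords, mirroring the code in the diff above.
    coords = {"county": ["a", "b", "c"]}
    dim_sizes = {name: len(vals) for name, vals in coords.items()}  # {"county": 3}
    dims = {"beta": ["county"]}                  # variable name -> dimension names
    var_names = ["intercept", "beta"]
    var_types = ["float64"] * len(var_names)     # PyMC draws are expanded as float64
    var_shapes = [(), (3,)]                      # scalar, then one entry per county

    # With the compiled extension available, the metadata objects would be
    # built as in the diff (commented out here because _lib is internal):
    # variables = _lib.PyVariable.new_variables(
    #     var_names, var_types, var_shapes, dim_sizes, dims
    # )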

python/nutpie/compile_stan.py

Lines changed: 16 additions & 3 deletions
@@ -52,13 +52,25 @@ def make_adapter(*args, **kwargs):

             return make_transform_adapter(**outer_kwargs)(*args, **kwargs, logp_fn=None)

-        model = _lib.StanModel(self.library, seed, data_json, make_adapter)
+        coords = self._coords
+        if coords is None:
+            coords = {}
+        coords = coords.copy()
+
+        dims = self.dims
+        if dims is None:
+            dims = {}
+        dims = dims.copy()
+        dim_sizes = {name: len(dim) for name, dim in coords.items()}
+
+        model = _lib.StanModel(
+            self.library, dim_sizes, dims, coords, seed, data_json, make_adapter
+        )
         coords = self._coords
         if coords is None:
             coords = {}
         else:
             coords = coords.copy()
-        coords["unconstrained_parameter"] = pd.Index(model.param_unc_names())

         return CompiledStanModel(
             _coords=coords,
@@ -93,13 +105,14 @@ def _make_model(self, init_mean):
             return self.with_data().model
         return self.model

-    def _make_sampler(self, settings, init_mean, cores, progress_type):
+    def _make_sampler(self, settings, init_mean, cores, progress_type, store):
         model = self._make_model(init_mean)
         return _lib.PySampler.from_stan(
             settings,
             cores,
             model,
             progress_type,
+            store,
         )

     @property
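
The dim_sizes bookkeeping added above is small but easy to misread in diff form, so here is a sketch with made-up coords and dims showing what the expanded _lib.StanModel constructor now receives up front; the constructor call itself is quoted from the diff and commented out.

    # Illustrative only -- made-up coords/dims for a toy hierarchical model.
    coords = {"school": ["A", "B", "C", "D"]}
    dims = {"theta": ["school"]}

    # Dimension lengths come from the user-supplied coords, exactly as above;
    # dims referring to coords that were not provided get no size entry.
    dim_sizes = {name: len(vals) for name, vals in coords.items()}
    assert dim_sizes == {"school": 4}

    # model = _lib.StanModel(
    #     self.library, dim_sizes, dims, coords, seed, data_json, make_adapter
    # )

The store argument threaded through _make_sampler is forwarded unchanged to _lib.PySampler.from_stan, matching the PyMC and pyfunc backends in this commit.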

python/nutpie/compiled_pyfunc.py

Lines changed: 18 additions & 14 deletions
@@ -19,6 +19,7 @@ class PyFuncModel(CompiledModel):
     _shared_data: dict[str, Any]
     _n_dim: int
     _variables: list[_lib.PyVariable]
+    _dim_sizes: dict[str, int]
     _coords: dict[str, Any]
     _raw_logp_fn: Callable | None
     _transform_adapt_args: dict | None = None
@@ -47,13 +48,14 @@ def with_data(self, **updates):
     def with_transform_adapt(self, **kwargs):
         return dataclasses.replace(self, _transform_adapt_args=kwargs)

-    def _make_sampler(self, settings, init_mean, cores, progress_type):
+    def _make_sampler(self, settings, init_mean, cores, progress_type, store):
         model = self._make_model(init_mean)
         return _lib.PySampler.from_pyfunc(
             settings,
             cores,
             model,
             progress_type,
+            store,
         )

     def _make_model(self, init_mean):
@@ -85,6 +87,8 @@ def make_adapter(*args, **kwargs):
             make_expand_func,
             self._variables,
             self.n_dim,
+            dim_sizes=self._dim_sizes,
+            coords=self._coords,
             init_point_func=self._make_initial_points,
             transform_adapter=make_adapter,
         )
@@ -105,30 +109,30 @@ def from_pyfunc(
     make_transform_adapter=None,
     raw_logp_fn=None,
 ):
-    variables = []
-    for name, shape, dtype in zip(
-        expanded_names, expanded_shapes, expanded_dtypes, strict=True
-    ):
-        shape = _lib.TensorShape(list(shape))
-        if dtype == np.float64:
-            dtype = _lib.ExpandDtype.float64_array(shape)
-        elif dtype == np.float32:
-            dtype = _lib.ExpandDtype.float32_array(shape)
-        elif dtype == np.int64:
-            dtype = _lib.ExpandDtype.int64_array(shape)
-        variables.append(_lib.PyVariable(name, dtype))
-
     if coords is None:
         coords = {}
     if dims is None:
         dims = {}
     if shared_data is None:
         shared_data = {}

+    coords = coords.copy()
+
+    dim_sizes = {k: len(v) for k, v in coords.items()}
+    shapes = [tuple(shape) for shape in expanded_shapes]
+    variables = _lib.PyVariable.new_variables(
+        expanded_names,
+        [str(dtype) for dtype in expanded_dtypes],
+        shapes,
+        dim_sizes,
+        dims,
+    )
+
     return PyFuncModel(
         _n_dim=ndim,
         dims=dims,
         _coords=coords,
+        _dim_sizes=dim_sizes,
         _make_logp_func=make_logp_fn,
         _make_expand_func=make_expand_fn,
         _make_initial_points=make_initial_point_fn,
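
The dtype handling in from_pyfunc changes shape here: instead of constructing a _lib.ExpandDtype per variable, plain dtype names are passed as strings. A runnable sketch of that mapping, with made-up example variables (only numpy is required; the final _lib call is quoted from the diff and left commented):

    import numpy as np

    # Hypothetical expanded outputs of a pyfunc model.
    expanded_names = ["mu", "counts"]
    expanded_shapes = [(2,), (3, 4)]
    expanded_dtypes = [np.dtype("float64"), np.dtype("int64")]

    # str() on a numpy dtype yields names such as "float64", which is what
    # the new _lib.PyVariable.new_variables call receives.
    dtype_strings = [str(dtype) for dtype in expanded_dtypes]
    assert dtype_strings == ["float64", "int64"]

    shapes = [tuple(shape) for shape in expanded_shapes]
    # variables = _lib.PyVariable.new_variables(
    #     expanded_names, dtype_strings, shapes, dim_sizes, dims
    # )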

Comments (0)