Skip to content

Commit 1dc542a

Browse files
committed
feat: add zarr_store argument to write trace while sampling
1 parent eaa1470 commit 1dc542a

File tree

8 files changed

+103
-8
lines changed

8 files changed

+103
-8
lines changed

Cargo.lock

Lines changed: 7 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ name = "_lib"
2121
crate-type = ["cdylib"]
2222

2323
[dependencies]
24-
nuts-rs = { version = "0.16.1", features = ["zarr", "arrow"] }
24+
nuts-rs = { version = "0.17.0", features = ["zarr", "arrow"] }
2525
numpy = "0.26.0"
2626
rand = "0.9.0"
2727
thiserror = "2.0.3"

docs/sampling-options.qmd

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ trace = nutpie.sample(
2525
tune=500, # Number of warmup draws for adaptation
2626
chains=6, # Number of independent chains
2727
cores=None, # Number chains that are allowed to run simultainiously
28-
seed=12345 # Random seed for reproducibility
28+
seed=12345 # Random seed for reproducibility
2929
)
3030
```
3131

@@ -143,6 +143,60 @@ trace = nutpie.sample(
143143
)
144144
```
145145

146+
## Zarr Storage (Experimental)
147+
148+
Nutpie includes experimental support for writing traces directly to zarr storage, which can be useful for large traces that don't fit in memory or for distributed storage scenarios. The zarr format provides efficient, chunked, compressed storage for multi-dimensional arrays.
149+
150+
### Basic Usage
151+
152+
You can write traces directly to zarr storage by providing a `zarr_store` parameter:
153+
154+
```python
155+
import nutpie
156+
import pymc as pm
157+
158+
with pm.Model() as model:
159+
pm.HalfNormal("a")
160+
161+
compiled = nutpie.compile_pymc_model(model, backend="numba")
162+
163+
# Create a local zarr store
164+
path = "trace.zarr"
165+
store = nutpie.zarr_store.LocalStore(path)
166+
167+
trace = nutpie.sample(
168+
compiled,
169+
chains=2,
170+
seed=123,
171+
draws=100,
172+
tune=100,
173+
zarr_store=store
174+
)
175+
```
176+
177+
### Memory Considerations
178+
179+
When using zarr storage, the trace object supports lazy loading:
180+
181+
```python
182+
# The trace is not loaded into memory by default
183+
posterior_data = trace.posterior.a # Lazy access
184+
185+
# Explicitly load the entire trace into memory (optional)
186+
loaded_trace = trace.load()
187+
posterior_data = loaded_trace.posterior.a # In-memory access
188+
```
189+
190+
### Available Store Types
191+
192+
Nutpie supports several zarr store backends:
193+
194+
- `nutpie.zarr_store.LocalStore(path)` - Local filesystem storage
195+
- `nutpie.zarr_store.S3Store(...)` - Amazon S3 storage
196+
- `nutpie.zarr_store.GCSStore(...)` - Google Cloud Storage
197+
- `nutpie.zarr_store.AzureStore(...)` - Azure Blob Storage
198+
- `nutpie.zarr_store.HTTPStore(...)` - HTTP-based storage
199+
146200
## Progress Monitoring
147201

148202
Customize the sampling progress display:

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ dependencies = [
2121
"pandas >= 2.0",
2222
"xarray >= 2025.01.2",
2323
"arviz >= 0.20.0",
24+
"obstore >= 0.8.0",
25+
"zarr >= 3.1.0",
2426
]
2527
dynamic = ["version"]
2628

python/nutpie/__init__.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,13 @@
22
from nutpie.compile_pymc import compile_pymc_model
33
from nutpie.compile_stan import compile_stan_model
44
from nutpie.sample import sample
5+
from nutpie._lib import store as zarr_store
56

67
__version__: str = _lib.__version__
7-
__all__ = ["__version__", "compile_pymc_model", "compile_stan_model", "sample"]
8+
__all__ = [
9+
"__version__",
10+
"compile_pymc_model",
11+
"compile_stan_model",
12+
"sample",
13+
"zarr_store",
14+
]

python/nutpie/compile_stan.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
from pathlib import Path
55
from typing import Any, Optional
66

7-
import pandas as pd
87
from numpy.typing import NDArray
98

109
from nutpie import _lib

python/nutpie/sample.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -778,10 +778,12 @@ def sample(
778778
transform_adapt: bool, default=False
779779
Use the experimental transform adaptation algorithm
780780
during tuning.
781-
zarr_store: nutpie.store.Store
782-
A store created using nutpie.store to store the samples
781+
zarr_store: nutpie.zarr_store.*
782+
A store created using nutpie.zarr_store to store the samples
783783
in. If None (default), the samples will be stored in
784-
memory using an arrow table.
784+
memory using an arrow table. This can be used to write
785+
the trace directly into a zarr store, for instance
786+
on disk or to S3 or GCS.
785787
**kwargs
786788
Pass additional arguments to nutpie._lib.PySamplerArgs
787789

tests/test_pymc.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -445,3 +445,28 @@ def test_deterministic_sampling_jax():
445445
compiled = nutpie.compile_pymc_model(model, backend="jax", gradient_backend="jax")
446446
trace = nutpie.sample(compiled, chains=2, seed=123, draws=100, tune=100)
447447
return trace.posterior.a.values.ravel()
448+
449+
450+
@pytest.mark.pymc
451+
def test_zarr_store(tmp_path):
452+
with pm.Model() as model:
453+
pm.HalfNormal("a")
454+
455+
compiled = nutpie.compile_pymc_model(model, backend="numba")
456+
457+
path = tmp_path / "trace.zarr"
458+
path.mkdir()
459+
store = nutpie.zarr_store.LocalStore(str(path))
460+
trace = nutpie.sample(
461+
compiled, chains=2, seed=123, draws=100, tune=100, zarr_store=store
462+
)
463+
trace.load().posterior.a # noqa: B018
464+
465+
466+
@pytest.fixture
467+
def tmp_path():
468+
import tempfile
469+
from pathlib import Path
470+
471+
with tempfile.TemporaryDirectory() as tmpdirname:
472+
yield Path(tmpdirname)

0 commit comments

Comments
 (0)