Skip to content

Commit 96f9eed

Browse files
committed
removed thinning, updated docstrings and refs
1 parent d7be978 commit 96f9eed

File tree

3 files changed

+4
-6
lines changed

3 files changed

+4
-6
lines changed

external_tests/helpers.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ def _numpyro_noncentered_guide(J, sigma, y=None):
132132

133133

134134
def numpyro_schools_model(data, draws, chains):
135-
"""Noncentered eight schools implementation in NumPyro."""
135+
"""Non-centered eight schools implementation in NumPyro."""
136136
from jax.random import PRNGKey
137137
from numpyro.infer import MCMC, NUTS
138138

@@ -156,7 +156,7 @@ def numpyro_schools_model(data, draws, chains):
156156

157157

158158
def numpyro_schools_model_svi(data, draws, chains):
159-
"""Centered eight schools implementation in NumPyro."""
159+
"""Non-centered eight schools implementation in NumPyro."""
160160
from jax.random import PRNGKey
161161
from numpyro.infer import SVI, Trace_ELBO, init_to_sample
162162
from numpyro.infer.autoguide import AutoNormal
@@ -169,7 +169,7 @@ def numpyro_schools_model_svi(data, draws, chains):
169169

170170

171171
def numpyro_schools_model_svi_custom_guide(data, draws, chains):
172-
"""Centered eight schools implementation in NumPyro."""
172+
"""Non-centered eight schools implementation in NumPyro."""
173173
from jax.random import PRNGKey
174174
from numpyro.infer import SVI, Trace_ELBO
175175
from numpyro.optim import Adam

src/arviz_base/io_numpyro.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@ def __init__(
2222
model_args=None,
2323
model_kwargs=None,
2424
num_samples: int = 1000,
25-
thinning: int = 1,
2625
):
2726
import jax
2827
import numpyro
@@ -32,7 +31,7 @@ def __init__(
3231
self._args = model_args or tuple()
3332
self._kwargs = model_kwargs or dict()
3433
self.num_samples = num_samples
35-
self.thinning = thinning
34+
self.thinning = 1
3635
self.num_chains = 0
3736
self.sample_dims = ["samples"]
3837
self.kind = "svi"

src/arviz_base/io_numpyro.pyi

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@ class SVIWrapper:
2424
model_args=...,
2525
model_kwargs=...,
2626
num_samples: int = ...,
27-
thinning: int = ...,
2827
) -> None: ...
2928
def get_samples(self, seed=..., **kwargs) -> None: ...
3029
@property

0 commit comments

Comments
 (0)