Skip to content

Commit 4b0e1bf

Browse files
committed
regenerate stub files after rebase
1 parent 0d4fda6 commit 4b0e1bf

File tree

2 files changed

+38
-2
lines changed

2 files changed

+38
-2
lines changed

src/arviz_base/__init__.pyi

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ from arviz_base.datasets import (
2525
from arviz_base.io_cmdstanpy import from_cmdstanpy
2626
from arviz_base.io_dict import from_dict
2727
from arviz_base.io_emcee import from_emcee
28-
from arviz_base.io_numpyro import from_numpyro
28+
from arviz_base.io_numpyro import from_numpyro, from_numpyro_svi
2929
from arviz_base.io_pystan import from_pystan
3030
from arviz_base.rcparams import rc_context, rcParams
3131
from arviz_base.reorg import (
@@ -55,6 +55,7 @@ __all__ = [
5555
"from_dict",
5656
"from_emcee",
5757
"from_numpyro",
58+
"from_numpyro_svi",
5859
"rc_context",
5960
"rcParams",
6061
"extract",

src/arviz_base/io_numpyro.pyi

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,21 @@ from arviz_base.base import dict_to_dataset, requires
1414
from arviz_base.rcparams import rc_context, rcParams
1515
from arviz_base.utils import expand_dims
1616

17+
class SVIWrapper:
18+
def __init__(
19+
self,
20+
svi,
21+
*,
22+
svi_result,
23+
model_args=...,
24+
model_kwargs=...,
25+
num_samples: int = ...,
26+
) -> None: ...
27+
def get_samples(self, seed=..., **kwargs) -> None: ...
28+
@property
29+
def sampler(self) -> None: ...
30+
def get_extra_fields(self, **kwargs) -> None: ...
31+
1732
def _add_dims(
1833
dims_a: dict[str, list[str]], dims_b: dict[str, list[str]]
1934
) -> dict[str, list[str]]: ...
@@ -46,7 +61,7 @@ class NumPyroConverter:
4661
extra_event_dims: dict | None = ...,
4762
num_chains: int = ...,
4863
) -> None: ...
49-
def _get_model_trace(self, model, args, kwargs, key) -> None: ...
64+
def _get_model_trace(self, model, model_args, model_kwargs, key) -> None: ...
5065
def posterior_to_xarray(self) -> None: ...
5166
def sample_stats_to_xarray(self) -> None: ...
5267
def log_likelihood_to_xarray(self) -> None: ...
@@ -77,3 +92,23 @@ def from_numpyro(
7792
extra_event_dims: dict | None = ...,
7893
num_chains: int = ...,
7994
) -> DataTree: ...
95+
def from_numpyro_svi(
96+
svi: numpyro.infer.svi.SVI,
97+
*,
98+
svi_result: numpyro.infer.svi.SVIRunResult,
99+
model_args: tuple | None = ...,
100+
model_kwargs: dict | None = ...,
101+
prior: dict | None = ...,
102+
posterior_predictive: dict | None = ...,
103+
predictions: dict | None = ...,
104+
constant_data: dict | None = ...,
105+
predictions_constant_data: dict | None = ...,
106+
log_likelihood=...,
107+
index_origin: int | None = ...,
108+
coords: dict | None = ...,
109+
dims: dict[str, list[str]] | None = ...,
110+
pred_dims: dict | None = ...,
111+
extra_event_dims: dict | None = ...,
112+
model=...,
113+
num_samples: int = ...,
114+
) -> DataTree: ...

0 commit comments

Comments
 (0)