@@ -16,16 +16,15 @@ class SVIWrapper:
1616
1717 def __init__ (
1818 self ,
19- guide ,
19+ svi ,
2020 * ,
2121 svi_result ,
2222 model_args = None ,
2323 model_kwargs = None ,
2424 num_samples : int = 1000 ,
25- model = None ,
2625 thinning : int = 1 ,
2726 ):
28- self .guide = guide
27+ self .svi = svi
2928 self .svi_result = svi_result
3029 self ._args = model_args or tuple ()
3130 self ._kwargs = model_kwargs or dict ()
@@ -34,7 +33,6 @@ def __init__(
3433 self .num_chains = 0
3534 self .sample_dims = ["samples" ]
3635 self .kind = "svi"
37- self .model = model
3836
3937 def get_samples (self , seed = None , ** kwargs ):
4038 """Mimics mcmc.get_samples()."""
@@ -43,8 +41,8 @@ def get_samples(self, seed=None, **kwargs):
4341 from numpyro .infer .autoguide import AutoGuide
4442
4543 key = jax .random .PRNGKey (seed or 0 )
46- if isinstance (self .guide , AutoGuide ):
47- return self .guide .sample_posterior (
44+ if isinstance (self .svi . guide , AutoGuide ):
45+ return self .svi . guide .sample_posterior (
4846 key ,
4947 self .svi_result .params ,
5048 * self ._args ,
@@ -53,7 +51,7 @@ def get_samples(self, seed=None, **kwargs):
5351 )
5452 # if a custom guide is provided, sample by hand
5553 predictive = Predictive (
56- self .guide , params = self .svi_result .params , num_samples = self .num_samples
54+ self .svi . guide , params = self .svi_result .params , num_samples = self .num_samples
5755 )
5856 samples = predictive (key , * self ._args , ** self ._kwargs )
5957 return samples
@@ -70,7 +68,7 @@ def __init__(self, model):
7068 def model (self ):
7169 return self ._model
7270
73- return Sampler (getattr (self .guide , "model" , self .model ))
71+ return Sampler (getattr (self .svi . guide , "model" , self . svi .model ))
7472
7573 def get_extra_fields (self , ** kwargs ):
7674 """Mimics mcmc.get_extra_fields()."""
@@ -623,7 +621,7 @@ def from_numpyro(
623621
624622
625623def from_numpyro_svi (
626- guide ,
624+ svi ,
627625 svi_result ,
628626 * ,
629627 model_args = None ,
@@ -663,8 +661,8 @@ def from_numpyro_svi(
663661
664662 Parameters
665663 ----------
666- guide : numpyro.infer.autoguide.AutoGuide or callable
667- Guide function for a numpyro SVI model. Can be an autoguide or custom guide .
664+ guide : numpyro.infer.svi.SVI
665+ Numpyro SVI instance used for fitting the model .
668666 svi_result : numpyro.infer.svi.SVIRunResult
669667 SVI results from a fitted model.
670668 model_args : tuple, optional
@@ -694,20 +692,17 @@ def from_numpyro_svi(
694692 their coordinates.
695693 num_chains : int, default 1
696694 Number of chains used for sampling. Ignored if posterior is present.
697- model : callable, optional
698- Model function, only needed for a custom guide function
699695
700696 Returns
701697 -------
702698 DataTree
703699 """
704700 posterior = SVIWrapper (
705- guide ,
701+ svi ,
706702 svi_result = svi_result ,
707703 model_args = model_args ,
708704 model_kwargs = model_kwargs ,
709705 num_samples = num_samples ,
710- model = model ,
711706 )
712707 with rc_context (rc = {"data.sample_dims" : ["samples" ]}):
713708 return NumPyroConverter (
0 commit comments