@@ -571,3 +571,72 @@ def _run_inference(self, model, svi, guide_fn):
571571 mcmc = MCMC (NUTS (model ), num_warmup = 10 , num_samples = 10 )
572572 mcmc .run (PRNGKey (0 ))
573573 return {"posterior" : mcmc }
574+
575+
576+ class TestSVIWrapper :
577+ # @pytest.fixture(scope="class", params=["numpyro", "numpyro_svi"])
578+ # def data(self, request):
579+ # from numpyro.infer.svi import SVI, SVIRunResult
580+ # from numpyro.infer import Trace_ELBO
581+ # from numpyro.optim import Adam
582+ # from numpyro import distributions as dist
583+
584+ # def model():
585+ # numpyro.sample("alpha", dist.Normal(0, 10))
586+
587+ # return {
588+ # "numpyro": (numpyro.infer.SVI()
589+ # }[request.param]
590+
591+ # return Data
592+
593+ @pytest .fixture (scope = "class" , params = ["numpyro_svi" , "numpyro_svi_custom_guide" ])
594+ def data (self , request , eight_schools_params , draws , chains ):
595+ class Data :
596+ obj = load_cached_models (eight_schools_params , draws , chains , "numpyro" )[request .param ]
597+
598+ return Data
599+
600+ # def test_init_defaults(self):
601+ # wrapper = SVIWrapper(
602+
603+ # )
604+
605+ # assert wrapper.svi == "mysvi"
606+ # assert wrapper.svi_result == {"params": {}}
607+ # assert wrapper.num_samples == 1000
608+ # assert wrapper.thinning == 1
609+ # assert wrapper.num_chains == 0
610+ # assert wrapper.sample_dims == ["samples"]
611+ # assert wrapper.kind == "svi"
612+ # assert callable(wrapper.prng_key_func)
613+ # assert hasattr(wrapper, "numpyro")
614+
615+ def test_init_without_args_kwargs (self ):
616+ from numpyro .infer import Trace_ELBO
617+ from numpyro .infer .svi import SVI , SVIRunResult
618+ from numpyro .optim import Adam
619+
620+ model = guide = lambda x : x
621+ svi = SVI (model , guide , optim = Adam (0.05 ), loss = Trace_ELBO ())
622+ svi_result = SVIRunResult (params = jax .numpy .ones (5 ), state = None , losses = jax .numpy .zeros (10 ))
623+
624+ posterior = SVIWrapper (svi , svi_result = svi_result )
625+ assert isinstance (posterior ._args , tuple )
626+ assert isinstance (posterior ._kwargs , dict )
627+
628+ def test_get_samples (self , data , eight_schools_params ):
629+ svi_posterior = SVIWrapper (
630+ data .obj ["svi" ], svi_result = data .obj ["svi_result" ], model_kwargs = eight_schools_params
631+ )
632+ out = svi_posterior .get_samples (seed = 0 )
633+ assert isinstance (out , dict )
634+ for v in out .values (): # values are array-like
635+ assert isinstance (v , (jax .numpy .ndarray | np .ndarray ))
636+
637+ def test_sampler_attr (self , data , eight_schools_params ):
638+ svi_posterior = SVIWrapper (
639+ data .obj ["svi" ], svi_result = data .obj ["svi_result" ], model_kwargs = eight_schools_params
640+ )
641+ assert hasattr (svi_posterior , "sampler" )
642+ assert hasattr (svi_posterior .sampler , "model" )
0 commit comments