@@ -1463,8 +1463,10 @@ def _sample_latent(self, *args, **kwargs):
14631463 if self .global_guide is not None :
14641464 global_latents = self .global_guide (* args , ** kwargs )
14651465 rng_key = numpyro .prng_key ()
1466- with handlers .block (), handlers .seed (rng_seed = rng_key ), handlers .substitute (
1467- data = global_latents
1466+ with (
1467+ handlers .block (),
1468+ handlers .seed (rng_seed = rng_key ),
1469+ handlers .substitute (data = global_latents ),
14681470 ):
14691471 global_outputs = self .global_guide .model (* args , ** kwargs )
14701472 local_args = (global_outputs ,)
@@ -1575,9 +1577,12 @@ def fn(x):
15751577 if self .local_guide is not None :
15761578 key = numpyro .prng_key ()
15771579 subsample_guide = partial (_subsample_model , self .local_guide )
1578- with handlers .block (), handlers .trace () as tr , handlers .seed (
1579- rng_seed = key
1580- ), handlers .substitute (data = local_guide_params ):
1580+ with (
1581+ handlers .block (),
1582+ handlers .trace () as tr ,
1583+ handlers .seed (rng_seed = key ),
1584+ handlers .substitute (data = local_guide_params ),
1585+ ):
15811586 with warnings .catch_warnings ():
15821587 warnings .simplefilter ("ignore" )
15831588 subsample_guide (* local_args , ** local_kwargs )
0 commit comments