@@ -519,3 +519,44 @@ def guide():
     assert "guide-always" in called
     assert "model-sometimes" not in called
     assert "guide-sometimes" not in called
+
+
+def test_log_likelihood_flax_nn():
+    import numpy as np
+
+    import flax.linen as nn
+    from jax import random
+
+    from numpyro.contrib.module import random_flax_module
+
+    # Simulate data
+    rng = np.random.default_rng(99)
+    N = 1000
+
+    X = rng.normal(0, 1, size=(N, 1))
+    mu = 1 + X @ np.array([0.5])
+    y = rng.normal(mu, 0.5)
+
+    # Simple linear layer
+    class Linear(nn.Module):
+        @nn.compact
+        def __call__(self, x):
+            return nn.Dense(1, use_bias=True, name="Dense")(x)
+
+    def model(X, y=None):
+        sigma = numpyro.sample("sigma", dist.HalfNormal(0.1))
+        priors = {"Dense.bias": dist.Normal(0, 2.5), "Dense.kernel": dist.Normal(0, 1)}
+        mlp = random_flax_module(
+            "mlp", Linear(), prior=priors, input_shape=(X.shape[1],)
+        )
+        with numpyro.plate("data", X.shape[0]):
+            mu = numpyro.deterministic("mu", mlp(X).squeeze(-1))
+            y = numpyro.sample("y", dist.Normal(mu, sigma), obs=y)
+
+    # Fit model
+    kernel = NUTS(model, target_accept_prob=0.95)
+    mcmc = MCMC(kernel, num_warmup=100, num_samples=100, num_chains=1)
+    mcmc.run(random.PRNGKey(0), X=X, y=y)
+
+    # Compute log likelihood for the posterior samples
+    numpyro.infer.util.log_likelihood(model, mcmc.get_samples(), X=X, y=y)