3434from pymc .model .core import Point
3535from pymc .sampling .jax import get_jaxified_graph
3636from pymc .util import RandomSeed , _get_seeds_per_chain , get_default_varnames
37+ from pytensor .graph import Apply , Op
3738
3839from pymc_experimental .inference .lbfgs import lbfgs
3940
@@ -311,7 +312,8 @@ def bfgs_sample(
311312 alpha ,
312313 beta ,
313314 gamma ,
314- random_seed : RandomSeed | None = None ,
315+ rng ,
316+ # random_seed: RandomSeed | None = None,
315317):
316318 # batch: L = 8
317319 # alpha_l: (N,) => (L, N)
@@ -324,8 +326,6 @@ def bfgs_sample(
324326 # logdensity: (M,) => (L, M)
325327 # theta: (J, N)
326328
327- rng = pytensor .shared (np .random .default_rng (seed = random_seed ))
328-
329329 if not _batched (x , g , alpha , beta , gamma ):
330330 x = pt .atleast_2d (x )
331331 g = pt .atleast_2d (g )
@@ -371,6 +371,24 @@ def bfgs_sample(
371371 return phi , logdensity
372372
373373
class LogLike(Op):
    """PyTensor Op wrapping a Python log-probability callable.

    Maps a batch of parameter vectors to their log-densities by applying
    ``logp_func`` along the last axis of the input array at runtime.
    """

    def __init__(self, logp_func):
        # logp_func: callable taking a 1-D parameter vector and returning a
        # scalar log-density; evaluated with NumPy inside perform().
        self.logp_func = logp_func
        super().__init__()

    def make_node(self, phi):
        """Create the Apply node: one tensor input, one matrix output.

        NOTE(review): the output is declared rank-2 (shape ``(None, None)``),
        which assumes ``phi`` is rank-3 so that reducing the last axis yields
        a matrix — confirm against the caller's ``phi`` shape.
        """
        phi = pt.as_tensor(phi)
        out = pt.tensor(dtype=phi.dtype, shape=(None, None))
        return Apply(self, [phi], [out])

    def perform(self, node: Apply, inputs, outputs) -> None:
        """Evaluate ``logp_func`` along the last axis of the single input."""
        (phi,) = inputs
        # np.vectorize was measured slower than apply_along_axis for this use.
        logp = np.apply_along_axis(self.logp_func, axis=-1, arr=phi)
        # Coerce to ndarray: the Op contract requires storing an ndarray in
        # the output storage cell.
        outputs[0][0] = np.asarray(logp)
391+
374392def _pymc_pathfinder (
375393 model ,
376394 x0 : np .float64 ,
@@ -406,38 +424,43 @@ def neg_dlogp_func(x):
406424 gtol = gtol ,
407425 maxls = maxls ,
408426 )
427+ x = pytensor .shared (history .x , "x" )
428+ g = pytensor .shared (history .g , "g" )
409429
410- alpha , update_mask = alpha_recover (history .x , history .g )
411-
412- beta , gamma = inverse_hessian_factors (alpha , history .x , history .g , update_mask , J = maxcor )
430+ alpha , update_mask = alpha_recover (x , g )
413431
414- phi , logq_phi = bfgs_sample (
432+ beta , gamma = inverse_hessian_factors (alpha , x , g , update_mask , J = maxcor )
433+ rng = pytensor .shared (np .random .default_rng (seed = pathfinder_seed ))
434+ _phi , _logq_phi = bfgs_sample (
415435 num_samples = num_elbo_draws ,
416- x = history . x ,
417- g = history . g ,
436+ x = x ,
437+ g = g ,
418438 alpha = alpha ,
419439 beta = beta ,
420440 gamma = gamma ,
421- random_seed = pathfinder_seed ,
441+ rng = rng ,
422442 )
443+ sample_phi_fn = pytensor .function ([alpha , beta , gamma ], [_phi , _logq_phi ])
444+ phi , logq_phi = sample_phi_fn (alpha .eval (), beta .eval (), gamma .eval ())
423445
424446 # .vectorize is slower than apply_along_axis
425- logp_phi = np . apply_along_axis (logp_func , axis = - 1 , arr = phi . eval () )
426- logq_phi = logq_phi . eval ( )
427- elbo = (logp_phi - logq_phi ). mean ( axis = - 1 )
428- lstar = np .argmax (elbo )
447+ loglike = LogLike (logp_func )
448+ logp_phi = loglike ( phi )
449+ elbo = pt . mean (logp_phi - logq_phi , axis = - 1 )
450+ l_star = pt .argmax (elbo )
429451
452+ rng .set_value (np .random .default_rng (seed = sample_seed ))
430453 psi , logq_psi = bfgs_sample (
431454 num_samples = num_draws ,
432- x = history . x [ lstar ],
433- g = history . g [ lstar ],
434- alpha = alpha [lstar ],
435- beta = beta [lstar ],
436- gamma = gamma [lstar ],
437- random_seed = sample_seed ,
455+ x = x [ l_star ],
456+ g = g [ l_star ],
457+ alpha = alpha [l_star ],
458+ beta = beta [l_star ],
459+ gamma = gamma [l_star ],
460+ rng = rng ,
438461 )
439462
440- return psi [0 ].eval (), logq_psi , logp_func
463+ return psi [0 ].eval (), logq_psi . eval ()
441464
442465
443466def fit_pathfinder (
@@ -492,7 +515,7 @@ def fit_pathfinder(
492515
493516 # TODO: make better
494517 if inference_backend == "pymc" :
495- pathfinder_samples , logq_psi , logp_func = _pymc_pathfinder (
518+ pathfinder_samples , logq_psi = _pymc_pathfinder (
496519 model ,
497520 ip ,
498521 maxcor = maxcor ,
0 commit comments