@@ -302,7 +302,8 @@ def bfgs_sample(
302302 alpha ,
303303 beta ,
304304 gamma ,
305- random_seed : RandomSeed | None = None ,
305+ # random_seed: RandomSeed | None = None,
306+ rng ,
306307):
307308 # batch: L = 8
308309 # alpha_l: (N,) => (L, N)
@@ -315,7 +316,7 @@ def bfgs_sample(
315316 # logdensity: (M,) => (L, M)
316317 # theta: (J, N)
317318
318- rng = pytensor .shared (np .random .default_rng (seed = random_seed ))
319+ # rng = pytensor.shared(np.random.default_rng(seed=random_seed))
319320
320321 def batched (x , g , alpha , beta , gamma ):
321322 var_list = [x , g , alpha , beta , gamma ]
@@ -380,6 +381,64 @@ def compute_logp(logp_func, arr):
380381 return np .where (np .isnan (logP ), - np .inf , logP )
381382
382383
384+ _x = pt .matrix ("_x" , dtype = "float64" )  # module-level symbolic inputs used to compile the functions below once at import time
385+ _g = pt .matrix ("_g" , dtype = "float64" )
386+ _alpha = pt .matrix ("_alpha" , dtype = "float64" )  # NOTE(review): this placeholder is rebound two statements below and never used as a graph input
387+ _beta = pt .tensor3 ("_beta" , dtype = "float64" )  # NOTE(review): rebound below by inverse_hessian_factors
388+ _gamma = pt .tensor3 ("_gamma" , dtype = "float64" )  # NOTE(review): rebound below by inverse_hessian_factors
389+ _epsilon = pt .scalar ("_epsilon" , dtype = "float64" )
390+ _maxcor = pt .iscalar ("_maxcor" )
391+ _alpha , _S , _Z , _update_mask = alpha_recover (_x , _g , epsilon = _epsilon )  # from here on _alpha is the alpha_recover output, not the placeholder above
392+ _beta , _gamma = inverse_hessian_factors (_alpha , _S , _Z , _update_mask , J = _maxcor )  # from here on _beta/_gamma are computed variables
393+
394+ _num_elbo_draws = pt .iscalar ("_num_elbo_draws" )
395+ _dummy_rng = pytensor .shared (np .random .default_rng (), name = "_dummy_rng" )  # unseeded stand-in; single_pathfinder swaps in a seeded RNG via .copy(swap={_dummy_rng: rng})
396+ _phi , _logQ_phi = bfgs_sample (  # symbolic draws and their log-densities for the ELBO step
397+ num_samples = _num_elbo_draws ,
398+ x = _x ,
399+ g = _g ,
400+ alpha = _alpha ,
401+ beta = _beta ,
402+ gamma = _gamma ,
403+ rng = _dummy_rng ,
404+ )
405+
406+ _num_draws = pt .iscalar ("_num_draws" )
407+ _x_lstar = pt .dvector ("_x_lstar" )  # the *_lstar inputs have one dimension fewer than their counterparts above (single selected iteration)
408+ _g_lstar = pt .dvector ("_g_lstar" )
409+ _alpha_lstar = pt .dvector ("_alpha_lstar" )
410+ _beta_lstar = pt .dmatrix ("_beta_lstar" )
411+ _gamma_lstar = pt .dmatrix ("_gamma_lstar" )
412+
413+
414+ _psi , _logQ_psi = bfgs_sample (  # symbolic draws for the selected iteration; shares _dummy_rng with the graph above
415+ num_samples = _num_draws ,
416+ x = _x_lstar ,
417+ g = _g_lstar ,
418+ alpha = _alpha_lstar ,
419+ beta = _beta_lstar ,
420+ gamma = _gamma_lstar ,
421+ rng = _dummy_rng ,
422+ )
423+
424+ alpha_recover_compiled = pytensor .function (  # maps (x, g, epsilon) -> (alpha, S, Z, update_mask)
425+ inputs = [_x , _g , _epsilon ],
426+ outputs = [_alpha , _S , _Z , _update_mask ],
427+ )
428+ inverse_hessian_factors_compiled = pytensor .function (  # NOTE(review): inputs are alpha_recover outputs (intermediate variables), not root placeholders - confirm this is intended
429+ inputs = [_alpha , _S , _Z , _update_mask , _maxcor ],
430+ outputs = [_beta , _gamma ],
431+ )
432+ bfgs_sample_compiled = pytensor .function (  # NOTE(review): _alpha/_beta/_gamma here are the rebound computed variables - confirm this is intended
433+ inputs = [_num_elbo_draws , _x , _g , _alpha , _beta , _gamma ],
434+ outputs = [_phi , _logQ_phi ],
435+ )
436+ bfgs_sample_lstar_compiled = pytensor .function (  # sampler for the single selected iteration; takes the *_lstar placeholders
437+ inputs = [_num_draws , _x_lstar , _g_lstar , _alpha_lstar , _beta_lstar , _gamma_lstar ],
438+ outputs = [_psi , _logQ_psi ],
439+ )
440+
441+
383442def single_pathfinder (
384443 model ,
385444 num_draws : int ,
@@ -423,47 +482,46 @@ def neg_dlogp_func(x):
423482 maxls = maxls ,
424483 )
425484
426- # x_full, g_full: (L+1, N)
427- x_full = pt .as_tensor (lbfgs_history .x , dtype = "float64" )
428- g_full = pt .as_tensor (lbfgs_history .g , dtype = "float64" )
485+ # x, g: (L+1, N)
486+ x = lbfgs_history .x
487+ g = lbfgs_history .g
488+ alpha , S , Z , update_mask = alpha_recover_compiled (x , g , epsilon )
489+ beta , gamma = inverse_hessian_factors_compiled (alpha , S , Z , update_mask , maxcor )
429490
430491 # ignore initial point - x, g: (L, N)
431- x = x_full [1 :]
432- g = g_full [1 :]
433-
434- alpha , S , Z , update_mask = alpha_recover (x_full , g_full , epsilon = epsilon )
435- beta , gamma = inverse_hessian_factors (alpha , S , Z , update_mask , J = maxcor )
436-
437- phi , logQ_phi = bfgs_sample (
438- num_samples = num_elbo_draws ,
439- x = x ,
440- g = g ,
441- alpha = alpha ,
442- beta = beta ,
443- gamma = gamma ,
444- random_seed = pathfinder_seed ,
492+ x = x [1 :]
493+ g = g [1 :]
494+
495+ rng = pytensor .shared (np .random .default_rng (pathfinder_seed ), borrow = True )
496+ phi , logQ_phi = bfgs_sample_compiled .copy (swap = {_dummy_rng : rng })(
497+ num_elbo_draws ,
498+ x ,
499+ g ,
500+ alpha ,
501+ beta ,
502+ gamma ,
445503 )
446504
447505 # .vectorize is slower than apply_along_axis
448- logP_phi = compute_logp (logp_func , phi . eval () )
449- logQ_phi = logQ_phi .eval ()
506+ logP_phi = compute_logp (logp_func , phi )
507+ # logQ_phi = logQ_phi.eval()
450508 elbo = (logP_phi - logQ_phi ).mean (axis = - 1 )
451509 lstar = np .argmax (elbo )
452510
453511 # BUG: elbo may be -inf for every l in L, in which case np.argmax(elbo) returns 0, which is not a meaningful choice of lstar. This does not affect the posterior samples in the multipath Pathfinder scenario because of the PSIS/PSIR step, but it leaves the user unaware of a failed Pathfinder run.
454512 # TODO: handle this case, e.g. by warning of a failed Pathfinder run and skip the following bfgs_sample step to save time.
455513
456- psi , logQ_psi = bfgs_sample (
457- num_samples = num_draws ,
458- x = x [ lstar ] ,
459- g = g [lstar ],
460- alpha = alpha [lstar ],
461- beta = beta [lstar ],
462- gamma = gamma [lstar ],
463- random_seed = sample_seed ,
514+ rng . set_value ( np . random . default_rng ( sample_seed ), borrow = True )
515+ psi , logQ_psi = bfgs_sample_lstar_compiled . copy ( swap = { _dummy_rng : rng })(
516+ num_draws ,
517+ x [lstar ],
518+ g [lstar ],
519+ alpha [lstar ],
520+ beta [lstar ],
521+ gamma [ lstar ] ,
464522 )
465- psi = psi .eval ()
466- logQ_psi = logQ_psi .eval ()
523+ # psi = psi.eval()
524+ # logQ_psi = logQ_psi.eval()
467525 logP_psi = compute_logp (logp_func , psi )
468526 # psi: (1, M, N)
469527 # logP_psi: (1, M)
0 commit comments