@@ -150,55 +150,6 @@ def convert_flat_trace_to_idata(
150150 return idata
151151
152152
153- def _get_delta_x_delta_g (x , g ):
154- # x or g: (L - 1, N)
155- return pt .diff (x , axis = 0 ), pt .diff (g , axis = 0 )
156-
157-
158- def _get_chi_matrix (diff , update_mask , J ):
159- _ , N = diff .shape
160- j_last = pt .as_tensor (J - 1 ) # since indexing starts at 0
161-
162- def chi_update (chi_lm1 , diff_l ):
163- chi_l = pt .roll (chi_lm1 , - 1 , axis = 0 )
164- # z_xi_l = pt.set_subtensor(z_xi_l[j_last], z_l)
165- # z_xi_l[j_last] = z_l
166- return pt .set_subtensor (chi_l [j_last ], diff_l )
167-
168- def no_op (chi_lm1 , diff_l ):
169- return chi_lm1
170-
171- def scan_body (update_mask_l , diff_l , chi_lm1 ):
172- return pt .switch (update_mask_l , chi_update (chi_lm1 , diff_l ), no_op (chi_lm1 , diff_l ))
173-
174- update_mask = pt .concatenate ([pt .as_tensor ([False ], dtype = "bool" ), update_mask ], axis = - 1 )
175- diff = pt .concatenate ([pt .zeros ((1 , N ), dtype = "float64" ), diff ], axis = 0 )
176-
177- chi_init = pt .zeros ((J , N ))
178- chi_mat , _ = pytensor .scan (
179- fn = scan_body ,
180- outputs_info = chi_init ,
181- sequences = [
182- update_mask ,
183- diff ,
184- ],
185- )
186-
187- chi_mat = chi_mat .dimshuffle (0 , 2 , 1 )
188-
189- return chi_mat
190-
191-
192- def _get_s_xi_z_xi (x , g , update_mask , J ):
193- L , N = x .shape
194- S , Z = _get_delta_x_delta_g (x , g )
195-
196- s_xi = _get_chi_matrix (S , update_mask , J )
197- z_xi = _get_chi_matrix (Z , update_mask , J )
198-
199- return s_xi , z_xi
200-
201-
202153def alpha_recover (x , g , epsilon : float = 1e-11 ):
203154 """
204155 epsilon: float
@@ -229,8 +180,9 @@ def scan_body(update_mask_l, s_l, z_l, alpha_lm1):
229180 return_alpha_lm1 (alpha_lm1 , s_l , z_l ),
230181 )
231182
232- L , N = x .shape
233- S , Z = _get_delta_x_delta_g (x , g )
183+ Lp1 , N = x .shape
184+ S = pt .diff (x , axis = 0 )
185+ Z = pt .diff (g , axis = 0 )
234186 alpha_l_init = pt .ones (N )
235187 SZ = (S * Z ).sum (axis = - 1 )
236188
@@ -241,20 +193,54 @@ def scan_body(update_mask_l, s_l, z_l, alpha_lm1):
241193 fn = scan_body ,
242194 outputs_info = alpha_l_init ,
243195 sequences = [update_mask , S , Z ],
244- n_steps = L - 1 ,
196+ n_steps = Lp1 - 1 ,
245197 strict = True ,
246198 )
247199
248- # alpha: (L, N), update_mask: (L-1 , N)
249- alpha = pt .concatenate ([pt .ones (N )[None , :], alpha ], axis = 0 )
200+ # alpha: (L, N), update_mask: (L,) -- one flag per row of S / Z
201+ # alpha = pt.concatenate([pt.ones(N)[None, :], alpha], axis=0)
250202 # assert np.all(alpha.eval() > 0), "alpha cannot be negative"
251- return alpha , update_mask
203+ return alpha , S , Z , update_mask
204+
205+
206+ def inverse_hessian_factors (alpha , S , Z , update_mask , J ):
207+ def get_chi_matrix (diff , update_mask , J ):
208+ L , N = diff .shape
209+ j_last = pt .as_tensor (J - 1 ) # 0-based index of the last row of the (J, N) window
210+
211+ def chi_update (chi_lm1 , diff_l ):
212+ chi_l = pt .roll (chi_lm1 , - 1 , axis = 0 ) # shift rows up by one; row 0 wraps around to the end
213+ # after the roll the oldest row sits at index j_last;
214+ # overwrite it with diff_l so the newest entry is always the last row
215+ return pt .set_subtensor (chi_l [j_last ], diff_l )
216+
217+ def no_op (chi_lm1 , diff_l ): # selected when update_mask_l is False: the window is carried over unchanged
218+ return chi_lm1
219+
220+ def scan_body (update_mask_l , diff_l , chi_lm1 ):
221+ return pt .switch (update_mask_l , chi_update (chi_lm1 , diff_l ), no_op (chi_lm1 , diff_l ))
222+
223+ # NOTE: no leading padding step is prepended any more (see the disabled lines below), so the scan emits one chi matrix per row of diff
224+ # update_mask = pt.concatenate([pt.as_tensor([False], dtype="bool"), update_mask], axis=-1)
225+ # diff = pt.concatenate([pt.zeros((1, N), dtype="float64"), diff], axis=0)
226+
227+ chi_init = pt .zeros ((J , N ))
228+ chi_mat , _ = pytensor .scan (
229+ fn = scan_body ,
230+ outputs_info = chi_init ,
231+ sequences = [
232+ update_mask ,
233+ diff ,
234+ ],
235+ )
252236
237+ chi_mat = chi_mat .dimshuffle (0 , 2 , 1 ) # (steps, J, N) -> (steps, N, J)
238+
239+ return chi_mat
253240
254- def inverse_hessian_factors (alpha , x , g , update_mask , J ):
255241 L , N = alpha .shape
256- # s_xi, z_xi = get_s_xi_z_xi(x, g , update_mask, J)
257- s_xi , z_xi = _get_s_xi_z_xi ( x , g , update_mask , J )
242+ s_xi = get_chi_matrix ( S , update_mask , J )
243+ z_xi = get_chi_matrix ( Z , update_mask , J )
258244
259245 # (L, J, J)
260246 sz_xi = pt .matrix_transpose (s_xi ) @ z_xi
@@ -414,7 +400,7 @@ def neg_dlogp_func(x):
414400 # TODO: apply the above excerpt to the Pathfinder algorithm.
415401 """
416402
417- history = lbfgs (
403+ lbfgs_history = lbfgs (
418404 fn = neg_logp_func ,
419405 grad_fn = neg_dlogp_func ,
420406 x0 = ip_map .data ,
@@ -425,14 +411,21 @@ def neg_dlogp_func(x):
425411 maxls = maxls ,
426412 )
427413
428- alpha , update_mask = alpha_recover (history .x , history .g , epsilon = epsilon )
414+ # x_full, g_full: (L+1, N)
415+ x_full = pt .as_tensor (lbfgs_history .x , dtype = "float64" )
416+ g_full = pt .as_tensor (lbfgs_history .g , dtype = "float64" )
417+
418+ # ignore initial point - x, g: (L, N)
419+ x = x_full [1 :]
420+ g = g_full [1 :]
429421
430- beta , gamma = inverse_hessian_factors (alpha , history .x , history .g , update_mask , J = maxcor )
422+ alpha , S , Z , update_mask = alpha_recover (x_full , g_full , epsilon = epsilon )
423+ beta , gamma = inverse_hessian_factors (alpha , S , Z , update_mask , J = maxcor )
431424
432425 phi , logQ_phi = bfgs_sample (
433426 num_samples = num_elbo_draws ,
434- x = history . x ,
435- g = history . g ,
427+ x = x ,
428+ g = g ,
436429 alpha = alpha ,
437430 beta = beta ,
438431 gamma = gamma ,
@@ -450,8 +443,8 @@ def neg_dlogp_func(x):
450443
451444 psi , logQ_psi = bfgs_sample (
452445 num_samples = num_draws ,
453- x = history . x [lstar ],
454- g = history . g [lstar ],
446+ x = x [lstar ],
447+ g = g [lstar ],
455448 alpha = alpha [lstar ],
456449 beta = beta [lstar ],
457450 gamma = gamma [lstar ],
0 commit comments