@@ -150,55 +150,6 @@ def convert_flat_trace_to_idata(
150
150
return idata
151
151
152
152
153
- def _get_delta_x_delta_g (x , g ):
154
- # x or g: (L - 1, N)
155
- return pt .diff (x , axis = 0 ), pt .diff (g , axis = 0 )
156
-
157
-
158
- def _get_chi_matrix (diff , update_mask , J ):
159
- _ , N = diff .shape
160
- j_last = pt .as_tensor (J - 1 ) # since indexing starts at 0
161
-
162
- def chi_update (chi_lm1 , diff_l ):
163
- chi_l = pt .roll (chi_lm1 , - 1 , axis = 0 )
164
- # z_xi_l = pt.set_subtensor(z_xi_l[j_last], z_l)
165
- # z_xi_l[j_last] = z_l
166
- return pt .set_subtensor (chi_l [j_last ], diff_l )
167
-
168
- def no_op (chi_lm1 , diff_l ):
169
- return chi_lm1
170
-
171
- def scan_body (update_mask_l , diff_l , chi_lm1 ):
172
- return pt .switch (update_mask_l , chi_update (chi_lm1 , diff_l ), no_op (chi_lm1 , diff_l ))
173
-
174
- update_mask = pt .concatenate ([pt .as_tensor ([False ], dtype = "bool" ), update_mask ], axis = - 1 )
175
- diff = pt .concatenate ([pt .zeros ((1 , N ), dtype = "float64" ), diff ], axis = 0 )
176
-
177
- chi_init = pt .zeros ((J , N ))
178
- chi_mat , _ = pytensor .scan (
179
- fn = scan_body ,
180
- outputs_info = chi_init ,
181
- sequences = [
182
- update_mask ,
183
- diff ,
184
- ],
185
- )
186
-
187
- chi_mat = chi_mat .dimshuffle (0 , 2 , 1 )
188
-
189
- return chi_mat
190
-
191
-
192
- def _get_s_xi_z_xi (x , g , update_mask , J ):
193
- L , N = x .shape
194
- S , Z = _get_delta_x_delta_g (x , g )
195
-
196
- s_xi = _get_chi_matrix (S , update_mask , J )
197
- z_xi = _get_chi_matrix (Z , update_mask , J )
198
-
199
- return s_xi , z_xi
200
-
201
-
202
153
def alpha_recover (x , g , epsilon : float = 1e-11 ):
203
154
"""
204
155
epsilon: float
@@ -229,8 +180,9 @@ def scan_body(update_mask_l, s_l, z_l, alpha_lm1):
229
180
return_alpha_lm1 (alpha_lm1 , s_l , z_l ),
230
181
)
231
182
232
- L , N = x .shape
233
- S , Z = _get_delta_x_delta_g (x , g )
183
+ Lp1 , N = x .shape
184
+ S = pt .diff (x , axis = 0 )
185
+ Z = pt .diff (g , axis = 0 )
234
186
alpha_l_init = pt .ones (N )
235
187
SZ = (S * Z ).sum (axis = - 1 )
236
188
@@ -241,20 +193,54 @@ def scan_body(update_mask_l, s_l, z_l, alpha_lm1):
241
193
fn = scan_body ,
242
194
outputs_info = alpha_l_init ,
243
195
sequences = [update_mask , S , Z ],
244
- n_steps = L - 1 ,
196
+ n_steps = Lp1 - 1 ,
245
197
strict = True ,
246
198
)
247
199
248
- # alpha: (L, N), update_mask: (L-1 , N)
249
- alpha = pt .concatenate ([pt .ones (N )[None , :], alpha ], axis = 0 )
200
+ # alpha: (L, N), update_mask: (L, N)
201
+ # alpha = pt.concatenate([pt.ones(N)[None, :], alpha], axis=0)
250
202
# assert np.all(alpha.eval() > 0), "alpha cannot be negative"
251
- return alpha , update_mask
203
+ return alpha , S , Z , update_mask
204
+
205
+
206
+ def inverse_hessian_factors (alpha , S , Z , update_mask , J ):
207
+ def get_chi_matrix (diff , update_mask , J ):
208
+ L , N = diff .shape
209
+ j_last = pt .as_tensor (J - 1 ) # since indexing starts at 0
210
+
211
+ def chi_update (chi_lm1 , diff_l ):
212
+ chi_l = pt .roll (chi_lm1 , - 1 , axis = 0 )
213
+ # z_xi_l = pt.set_subtensor(z_xi_l[j_last], z_l)
214
+ # z_xi_l[j_last] = z_l
215
+ return pt .set_subtensor (chi_l [j_last ], diff_l )
216
+
217
+ def no_op (chi_lm1 , diff_l ):
218
+ return chi_lm1
219
+
220
+ def scan_body (update_mask_l , diff_l , chi_lm1 ):
221
+ return pt .switch (update_mask_l , chi_update (chi_lm1 , diff_l ), no_op (chi_lm1 , diff_l ))
222
+
223
+ # NOTE: removing first index so that L starts at 1
224
+ # update_mask = pt.concatenate([pt.as_tensor([False], dtype="bool"), update_mask], axis=-1)
225
+ # diff = pt.concatenate([pt.zeros((1, N), dtype="float64"), diff], axis=0)
226
+
227
+ chi_init = pt .zeros ((J , N ))
228
+ chi_mat , _ = pytensor .scan (
229
+ fn = scan_body ,
230
+ outputs_info = chi_init ,
231
+ sequences = [
232
+ update_mask ,
233
+ diff ,
234
+ ],
235
+ )
252
236
237
+ chi_mat = chi_mat .dimshuffle (0 , 2 , 1 )
238
+
239
+ return chi_mat
253
240
254
- def inverse_hessian_factors (alpha , x , g , update_mask , J ):
255
241
L , N = alpha .shape
256
- # s_xi, z_xi = get_s_xi_z_xi(x, g , update_mask, J)
257
- s_xi , z_xi = _get_s_xi_z_xi ( x , g , update_mask , J )
242
+ s_xi = get_chi_matrix ( S , update_mask , J )
243
+ z_xi = get_chi_matrix ( Z , update_mask , J )
258
244
259
245
# (L, J, J)
260
246
sz_xi = pt .matrix_transpose (s_xi ) @ z_xi
@@ -414,7 +400,7 @@ def neg_dlogp_func(x):
414
400
# TODO: apply the above excerpt to the Pathfinder algorithm.
415
401
"""
416
402
417
- history = lbfgs (
403
+ lbfgs_history = lbfgs (
418
404
fn = neg_logp_func ,
419
405
grad_fn = neg_dlogp_func ,
420
406
x0 = ip_map .data ,
@@ -425,14 +411,21 @@ def neg_dlogp_func(x):
425
411
maxls = maxls ,
426
412
)
427
413
428
- alpha , update_mask = alpha_recover (history .x , history .g , epsilon = epsilon )
414
+ # x_full, g_full: (L+1, N)
415
+ x_full = pt .as_tensor (lbfgs_history .x , dtype = "float64" )
416
+ g_full = pt .as_tensor (lbfgs_history .g , dtype = "float64" )
417
+
418
+ # ignore initial point - x, g: (L, N)
419
+ x = x_full [1 :]
420
+ g = g_full [1 :]
429
421
430
- beta , gamma = inverse_hessian_factors (alpha , history .x , history .g , update_mask , J = maxcor )
422
+ alpha , S , Z , update_mask = alpha_recover (x_full , g_full , epsilon = epsilon )
423
+ beta , gamma = inverse_hessian_factors (alpha , S , Z , update_mask , J = maxcor )
431
424
432
425
phi , logQ_phi = bfgs_sample (
433
426
num_samples = num_elbo_draws ,
434
- x = history . x ,
435
- g = history . g ,
427
+ x = x ,
428
+ g = g ,
436
429
alpha = alpha ,
437
430
beta = beta ,
438
431
gamma = gamma ,
@@ -450,8 +443,8 @@ def neg_dlogp_func(x):
450
443
451
444
psi , logQ_psi = bfgs_sample (
452
445
num_samples = num_draws ,
453
- x = history . x [lstar ],
454
- g = history . g [lstar ],
446
+ x = x [lstar ],
447
+ g = g [lstar ],
455
448
alpha = alpha [lstar ],
456
449
beta = beta [lstar ],
457
450
gamma = gamma [lstar ],
0 commit comments