Skip to content

Commit fdc3f38

Browse files
committed
Removed initial point values (l=0) to reduce iterations and simplified the code.
1 parent 2efb511 commit fdc3f38

File tree

3 files changed

+76
-109
lines changed

3 files changed

+76
-109
lines changed

pymc_experimental/inference/pathfinder/lbfgs.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,14 @@
22
from typing import NamedTuple
33

44
import numpy as np
5-
import pytensor.tensor as pt
65

7-
from pytensor.tensor.variable import TensorVariable
86
from scipy.optimize import minimize
97

108

119
class LBFGSHistory(NamedTuple):
12-
x: TensorVariable
13-
f: TensorVariable
14-
g: TensorVariable
10+
x: np.ndarray
11+
f: np.ndarray
12+
g: np.ndarray
1513

1614

1715
class LBFGSHistoryManager:
@@ -40,9 +38,9 @@ def get_history(self):
4038
f = self.f_history[: self.count]
4139
g = self.g_history[: self.count] if self.g_history is not None else None
4240
return LBFGSHistory(
43-
x=pt.as_tensor(x, "x", dtype="float64"),
44-
f=pt.as_tensor(f, "f", dtype="float64"),
45-
g=pt.as_tensor(g, "g", dtype="float64"),
41+
x=x,
42+
f=f,
43+
g=g,
4644
)
4745

4846
def __call__(self, x):

pymc_experimental/inference/pathfinder/pathfinder.py

Lines changed: 58 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -150,55 +150,6 @@ def convert_flat_trace_to_idata(
150150
return idata
151151

152152

153-
def _get_delta_x_delta_g(x, g):
154-
# x or g: (L - 1, N)
155-
return pt.diff(x, axis=0), pt.diff(g, axis=0)
156-
157-
158-
def _get_chi_matrix(diff, update_mask, J):
159-
_, N = diff.shape
160-
j_last = pt.as_tensor(J - 1) # since indexing starts at 0
161-
162-
def chi_update(chi_lm1, diff_l):
163-
chi_l = pt.roll(chi_lm1, -1, axis=0)
164-
# z_xi_l = pt.set_subtensor(z_xi_l[j_last], z_l)
165-
# z_xi_l[j_last] = z_l
166-
return pt.set_subtensor(chi_l[j_last], diff_l)
167-
168-
def no_op(chi_lm1, diff_l):
169-
return chi_lm1
170-
171-
def scan_body(update_mask_l, diff_l, chi_lm1):
172-
return pt.switch(update_mask_l, chi_update(chi_lm1, diff_l), no_op(chi_lm1, diff_l))
173-
174-
update_mask = pt.concatenate([pt.as_tensor([False], dtype="bool"), update_mask], axis=-1)
175-
diff = pt.concatenate([pt.zeros((1, N), dtype="float64"), diff], axis=0)
176-
177-
chi_init = pt.zeros((J, N))
178-
chi_mat, _ = pytensor.scan(
179-
fn=scan_body,
180-
outputs_info=chi_init,
181-
sequences=[
182-
update_mask,
183-
diff,
184-
],
185-
)
186-
187-
chi_mat = chi_mat.dimshuffle(0, 2, 1)
188-
189-
return chi_mat
190-
191-
192-
def _get_s_xi_z_xi(x, g, update_mask, J):
193-
L, N = x.shape
194-
S, Z = _get_delta_x_delta_g(x, g)
195-
196-
s_xi = _get_chi_matrix(S, update_mask, J)
197-
z_xi = _get_chi_matrix(Z, update_mask, J)
198-
199-
return s_xi, z_xi
200-
201-
202153
def alpha_recover(x, g, epsilon: float = 1e-11):
203154
"""
204155
epsilon: float
@@ -229,8 +180,9 @@ def scan_body(update_mask_l, s_l, z_l, alpha_lm1):
229180
return_alpha_lm1(alpha_lm1, s_l, z_l),
230181
)
231182

232-
L, N = x.shape
233-
S, Z = _get_delta_x_delta_g(x, g)
183+
Lp1, N = x.shape
184+
S = pt.diff(x, axis=0)
185+
Z = pt.diff(g, axis=0)
234186
alpha_l_init = pt.ones(N)
235187
SZ = (S * Z).sum(axis=-1)
236188

@@ -241,20 +193,54 @@ def scan_body(update_mask_l, s_l, z_l, alpha_lm1):
241193
fn=scan_body,
242194
outputs_info=alpha_l_init,
243195
sequences=[update_mask, S, Z],
244-
n_steps=L - 1,
196+
n_steps=Lp1 - 1,
245197
strict=True,
246198
)
247199

248-
# alpha: (L, N), update_mask: (L-1, N)
249-
alpha = pt.concatenate([pt.ones(N)[None, :], alpha], axis=0)
200+
# alpha: (L, N), update_mask: (L,)
201+
# alpha = pt.concatenate([pt.ones(N)[None, :], alpha], axis=0)
250202
# assert np.all(alpha.eval() > 0), "alpha cannot be negative"
251-
return alpha, update_mask
203+
return alpha, S, Z, update_mask
204+
205+
206+
def inverse_hessian_factors(alpha, S, Z, update_mask, J):
207+
def get_chi_matrix(diff, update_mask, J):
208+
L, N = diff.shape
209+
j_last = pt.as_tensor(J - 1) # since indexing starts at 0
210+
211+
def chi_update(chi_lm1, diff_l):
212+
chi_l = pt.roll(chi_lm1, -1, axis=0)
213+
# z_xi_l = pt.set_subtensor(z_xi_l[j_last], z_l)
214+
# z_xi_l[j_last] = z_l
215+
return pt.set_subtensor(chi_l[j_last], diff_l)
216+
217+
def no_op(chi_lm1, diff_l):
218+
return chi_lm1
219+
220+
def scan_body(update_mask_l, diff_l, chi_lm1):
221+
return pt.switch(update_mask_l, chi_update(chi_lm1, diff_l), no_op(chi_lm1, diff_l))
222+
223+
# NOTE: removing first index so that L starts at 1
224+
# update_mask = pt.concatenate([pt.as_tensor([False], dtype="bool"), update_mask], axis=-1)
225+
# diff = pt.concatenate([pt.zeros((1, N), dtype="float64"), diff], axis=0)
226+
227+
chi_init = pt.zeros((J, N))
228+
chi_mat, _ = pytensor.scan(
229+
fn=scan_body,
230+
outputs_info=chi_init,
231+
sequences=[
232+
update_mask,
233+
diff,
234+
],
235+
)
252236

237+
chi_mat = chi_mat.dimshuffle(0, 2, 1)
238+
239+
return chi_mat
253240

254-
def inverse_hessian_factors(alpha, x, g, update_mask, J):
255241
L, N = alpha.shape
256-
# s_xi, z_xi = get_s_xi_z_xi(x, g, update_mask, J)
257-
s_xi, z_xi = _get_s_xi_z_xi(x, g, update_mask, J)
242+
s_xi = get_chi_matrix(S, update_mask, J)
243+
z_xi = get_chi_matrix(Z, update_mask, J)
258244

259245
# (L, J, J)
260246
sz_xi = pt.matrix_transpose(s_xi) @ z_xi
@@ -414,7 +400,7 @@ def neg_dlogp_func(x):
414400
# TODO: apply the above excerpt to the Pathfinder algorithm.
415401
"""
416402

417-
history = lbfgs(
403+
lbfgs_history = lbfgs(
418404
fn=neg_logp_func,
419405
grad_fn=neg_dlogp_func,
420406
x0=ip_map.data,
@@ -425,14 +411,21 @@ def neg_dlogp_func(x):
425411
maxls=maxls,
426412
)
427413

428-
alpha, update_mask = alpha_recover(history.x, history.g, epsilon=epsilon)
414+
# x_full, g_full: (L+1, N)
415+
x_full = pt.as_tensor(lbfgs_history.x, dtype="float64")
416+
g_full = pt.as_tensor(lbfgs_history.g, dtype="float64")
417+
418+
# ignore initial point - x, g: (L, N)
419+
x = x_full[1:]
420+
g = g_full[1:]
429421

430-
beta, gamma = inverse_hessian_factors(alpha, history.x, history.g, update_mask, J=maxcor)
422+
alpha, S, Z, update_mask = alpha_recover(x_full, g_full, epsilon=epsilon)
423+
beta, gamma = inverse_hessian_factors(alpha, S, Z, update_mask, J=maxcor)
431424

432425
phi, logQ_phi = bfgs_sample(
433426
num_samples=num_elbo_draws,
434-
x=history.x,
435-
g=history.g,
427+
x=x,
428+
g=g,
436429
alpha=alpha,
437430
beta=beta,
438431
gamma=gamma,
@@ -450,8 +443,8 @@ def neg_dlogp_func(x):
450443

451444
psi, logQ_psi = bfgs_sample(
452445
num_samples=num_draws,
453-
x=history.x[lstar],
454-
g=history.g[lstar],
446+
x=x[lstar],
447+
g=g[lstar],
455448
alpha=alpha[lstar],
456449
beta=beta[lstar],
457450
gamma=gamma[lstar],

tests/test_pathfinder.py

Lines changed: 12 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -62,25 +62,28 @@ def test_bfgs_sample():
6262
)
6363

6464
"""test BFGS sampling"""
65-
L, N = 8, 10
65+
Lp1, N = 8, 10
66+
L = Lp1 - 1
6667
J = 6
6768
num_samples = 1000
6869

6970
# mock data
70-
x = np.random.randn(L, N)
71-
g = np.random.randn(L, N)
71+
x_data = np.random.randn(Lp1, N)
72+
g_data = np.random.randn(Lp1, N)
7273

7374
# get factors
74-
x_tensor = pt.as_tensor(x, dtype="float64")
75-
g_tensor = pt.as_tensor(g, dtype="float64")
76-
alpha, update_mask = alpha_recover(x_tensor, g_tensor)
77-
beta, gamma = inverse_hessian_factors(alpha, x_tensor, g_tensor, update_mask, J)
75+
x_full = pt.as_tensor(x_data, dtype="float64")
76+
g_full = pt.as_tensor(g_data, dtype="float64")
77+
x = x_full[1:]
78+
g = g_full[1:]
79+
alpha, S, Z, update_mask = alpha_recover(x_full, g_full)
80+
beta, gamma = inverse_hessian_factors(alpha, S, Z, update_mask, J)
7881

7982
# sample
8083
phi, logq = bfgs_sample(
8184
num_samples=num_samples,
82-
x=x_tensor,
83-
g=g_tensor,
85+
x=x,
86+
g=g,
8487
alpha=alpha,
8588
beta=beta,
8689
gamma=gamma,
@@ -109,30 +112,3 @@ def test_fit_pathfinder_backends(inference_backend):
109112
)
110113
assert isinstance(idata, az.InferenceData)
111114
assert "posterior" in idata
112-
113-
114-
def test_process_multipath_results():
115-
"""Test processing of multipath results"""
116-
from pymc_experimental.inference.pathfinder.pathfinder import (
117-
PathfinderResults,
118-
process_multipath_pathfinder_results,
119-
)
120-
121-
num_paths = 3
122-
num_draws = 100
123-
num_dims = 2
124-
125-
results = PathfinderResults(num_paths, num_draws, num_dims)
126-
127-
# Add data to all paths
128-
for i in range(num_paths):
129-
samples = np.random.randn(num_draws, num_dims)
130-
logP = np.random.randn(num_draws)
131-
logQ = np.random.randn(num_draws)
132-
results.add_path_data(i, samples, logP, logQ)
133-
134-
samples, logP, logQ = process_multipath_pathfinder_results(results)
135-
136-
assert samples.shape == (num_paths * num_draws, num_dims)
137-
assert logP.shape == (num_paths * num_draws,)
138-
assert logQ.shape == (num_paths * num_draws,)

0 commit comments

Comments
 (0)