
Commit 8b134b7

Reduced size of compute graph with pathfinder_body_fn
Summary of changes:
- Remove multiprocessing code in favour of reusing the compiled function for each path
- Takes only random_seed as an argument for each path
- Compute graph significantly smaller by using a pure PyTensor Op and symbolic variables
- Added LBFGSOp to compile with pytensor.function
- Cleaned up code using PyTensor variables
1 parent ef2956f · commit 8b134b7
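
To illustrate the compile-once pattern described in the summary, here is a minimal sketch (not part of this commit; the real graph is the full single-path Pathfinder body rather than a toy normal draw):

```python
import pytensor
from pytensor.tensor.random.utils import RandomStream

# Stand-in for the single-path Pathfinder graph (LBFGS trajectory, ELBO,
# draws from the approximation, ...).
srng = RandomStream(seed=0)
draws = srng.normal(size=(5,))

# Compiled once with pytensor.function; the same compiled function is then
# reused for every path, replacing the removed multiprocessing code.
pathfinder_body_fn = pytensor.function([], draws)

paths = []
for random_seed in range(4):
    srng.seed(random_seed)  # only the seed changes between paths
    paths.append(pathfinder_body_fn())
```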

File tree

3 files changed: +382 −319 lines


pymc_experimental/inference/pathfinder/importance_sampling.py

Lines changed: 30 additions & 14 deletions
```diff
@@ -2,15 +2,35 @@
 
 import arviz as az
 import numpy as np
+import pytensor.tensor as pt
+
+from pytensor.graph import Apply, Op
+from pytensor.tensor.variable import TensorVariable
 
-logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
 logger = logging.getLogger(__name__)
 
 
+class PSIS(Op):
+    __props__ = ()
+
+    def make_node(self, inputs):
+        logweights = pt.as_tensor(inputs)
+        psislw = pt.dvector()
+        pareto_k = pt.dscalar()
+        return Apply(self, [logweights], [psislw, pareto_k])
+
+    def perform(self, node: Apply, inputs, outputs) -> None:
+        logweights = inputs[0]
+        psislw, pareto_k = az.psislw(logweights)
+        outputs[0][0] = psislw
+        outputs[1][0] = pareto_k
+
+
 def psir(
-    samples: np.ndarray,
-    logP: np.ndarray,
-    logQ: np.ndarray,
+    samples: TensorVariable,
+    # logP: TensorVariable,
+    # logQ: TensorVariable,
+    logiw: TensorVariable,
     num_draws: int = 1000,
     random_seed: int | None = None,
 ) -> np.ndarray:
@@ -48,14 +68,10 @@ def psir(
 
     Zhang, L., Carpenter, B., Gelman, A., & Vehtari, A. (2022). Pathfinder: Parallel quasi-Newton variational inference. Journal of Machine Learning Research, 23(306), 1-49.
     """
-
-    def logsumexp(x):
-        c = x.max()
-        return c + np.log(np.sum(np.exp(x - c)))
-
-    logiw = np.reshape(logP - logQ, -1, order="F")
-    psislw, pareto_k = az.psislw(logiw)
-
+    # logiw = np.reshape(logP - logQ, (-1,), order="F")
+    # logiw = (logP - logQ).ravel()
+    psislw, pareto_k = PSIS()(logiw)
+    pareto_k = pareto_k.eval()
     # FIXME: pareto_k is mostly bad, find out why!
     if pareto_k <= 0.70:
         pass
@@ -68,6 +84,6 @@ def logsumexp(x):
             "consider reparametrising the model, increasing ftol, gtol or maxcor parameters"
         )
 
-    p = np.exp(psislw - logsumexp(psislw))
+    p = pt.exp(psislw - pt.logsumexp(psislw)).eval()
     rng = np.random.default_rng(random_seed)
-    return rng.choice(samples, size=num_draws, p=p, shuffle=False, axis=0)
+    return rng.choice(samples, size=num_draws, replace=True, p=p, shuffle=False, axis=0)
```
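
As a minimal usage sketch of the new symbolic path (an assumption, not from the commit: logP and logQ here are (paths, draws) log-density matrices, flattened as in the commented-out lines above):

```python
import numpy as np
import pytensor
import pytensor.tensor as pt

from pymc_experimental.inference.pathfinder.importance_sampling import PSIS

logP = pt.dmatrix("logP")
logQ = pt.dmatrix("logQ")
logiw = (logP - logQ).ravel()  # one log importance weight per draw

# PSIS wraps az.psislw as a PyTensor Op, so smoothing stays in the graph
psislw, pareto_k = PSIS()(logiw)
smooth_fn = pytensor.function([logP, logQ], [psislw, pareto_k])

lw, k = smooth_fn(np.random.randn(4, 100), np.random.randn(4, 100))
```

Inside `psir`, the smoothed log weights are then normalised with `pt.logsumexp` into a probability vector for the `rng.choice` resampling step.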

pymc_experimental/inference/pathfinder/lbfgs.py

Lines changed: 66 additions & 16 deletions
```diff
@@ -2,49 +2,46 @@
 from typing import NamedTuple
 
 import numpy as np
+import pytensor.tensor as pt
 
+from pytensor.graph import Apply, Op
 from scipy.optimize import minimize
 
 
 class LBFGSHistory(NamedTuple):
     x: np.ndarray
-    f: np.ndarray
     g: np.ndarray
 
 
 class LBFGSHistoryManager:
-    def __init__(self, fn: Callable, grad_fn: Callable, x0: np.ndarray, maxiter: int):
+    def __init__(self, grad_fn: Callable, x0: np.ndarray, maxiter: int):
         dim = x0.shape[0]
         maxiter_add_one = maxiter + 1
         # Pre-allocate arrays to save memory and improve speed
         self.x_history = np.empty((maxiter_add_one, dim), dtype=np.float64)
-        self.f_history = np.empty(maxiter_add_one, dtype=np.float64)
         self.g_history = np.empty((maxiter_add_one, dim), dtype=np.float64)
         self.count = 0
-        self.fn = fn
         self.grad_fn = grad_fn
-        self.add_entry(x0, fn(x0), grad_fn(x0))
+        self.add_entry(x0, grad_fn(x0))
 
-    def add_entry(self, x, f, g=None):
+    def add_entry(self, x, g):
         self.x_history[self.count] = x
-        self.f_history[self.count] = f
-        if self.g_history is not None and g is not None:
-            self.g_history[self.count] = g
+        self.g_history[self.count] = g
         self.count += 1
 
     def get_history(self):
-        # Return trimmed arrays up to the number of entries actually used
+        # Return trimmed arrays up to L << L^max
         x = self.x_history[: self.count]
-        f = self.f_history[: self.count]
-        g = self.g_history[: self.count] if self.g_history is not None else None
+        g = self.g_history[: self.count]
         return LBFGSHistory(
             x=x,
-            f=f,
             g=g,
         )
 
     def __call__(self, x):
-        self.add_entry(x, self.fn(x), self.grad_fn(x))
+        grad = self.grad_fn(x)
+        if np.all(np.isfinite(grad)):
+            self.add_entry(x, grad)
@@ -62,7 +59,6 @@ def callback(xk):
         lbfgs_history_manager(xk)
 
     lbfgs_history_manager = LBFGSHistoryManager(
-        fn=fn,
         grad_fn=grad_fn,
         x0=x0,
         maxiter=maxiter,
@@ -89,4 +85,58 @@ def callback(xk):
         callback=callback,
         **lbfgs_kwargs,
     )
-    return lbfgs_history_manager.get_history()
+    lbfgs_history = lbfgs_history_manager.get_history()
+    return lbfgs_history.x, lbfgs_history.g
+
+
+class LBFGSOp(Op):
+    def __init__(self, fn, grad_fn, maxcor, maxiter=1000, ftol=1e-5, gtol=1e-8, maxls=1000):
+        self.fn = fn
+        self.grad_fn = grad_fn
+        self.maxcor = maxcor
+        self.maxiter = maxiter
+        self.ftol = ftol
+        self.gtol = gtol
+        self.maxls = maxls
+
+    def make_node(self, x0):
+        x0 = pt.as_tensor_variable(x0)
+        x_history = pt.dmatrix()
+        g_history = pt.dmatrix()
+        return Apply(self, [x0], [x_history, g_history])
+
+    def perform(self, node, inputs, outputs):
+        x0 = inputs[0]
+        x0 = np.array(x0, dtype=np.float64)
+
+        history_manager = LBFGSHistoryManager(grad_fn=self.grad_fn, x0=x0, maxiter=self.maxiter)
+
+        minimize(
+            self.fn,
+            x0,
+            method="L-BFGS-B",
+            jac=self.grad_fn,
+            callback=history_manager,
+            options={
+                "maxcor": self.maxcor,
+                "maxiter": self.maxiter,
+                "ftol": self.ftol,
+                "gtol": self.gtol,
+                "maxls": self.maxls,
+            },
+        )
+
+        # fmin_l_bfgs_b(
+        #     func=self.fn,
+        #     fprime=self.grad_fn,
+        #     x0=x0,
+        #     pgtol=self.gtol,
+        #     factr=self.ftol / np.finfo(float).eps,
+        #     maxls=self.maxls,
+        #     maxiter=self.maxiter,
+        #     m=self.maxcor,
+        #     callback=history_manager,
+        # )
+
+        outputs[0][0] = history_manager.get_history().x
+        outputs[1][0] = history_manager.get_history().g
```
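
A minimal sketch of compiling the new LBFGSOp with pytensor.function, as the commit summary describes (the quadratic objective and its gradient are toy stand-ins for the model's compiled logp/dlogp functions):

```python
import numpy as np
import pytensor
import pytensor.tensor as pt

from pymc_experimental.inference.pathfinder.lbfgs import LBFGSOp


def fn(x):  # toy objective: 0.5 * ||x||^2
    return 0.5 * np.dot(x, x)


def grad_fn(x):  # its gradient
    return x


x0 = pt.dvector("x0")
x_history, g_history = LBFGSOp(fn=fn, grad_fn=grad_fn, maxcor=5)(x0)

# The whole L-BFGS run is a single node in the compute graph
lbfgs_fn = pytensor.function([x0], [x_history, g_history])
xs, gs = lbfgs_fn(np.array([3.0, -2.0]))  # rows are the recorded iterates
```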
