+ import logging
from collections.abc import Sequence
from copy import copy
from typing import cast
from scipy.optimize import root as scipy_root

from pytensor import Variable, function, graph_replace
- from pytensor.gradient import grad, jacobian
+ from pytensor.gradient import grad, hessian, jacobian
from pytensor.graph import Apply, Constant, FunctionGraph
from pytensor.graph.basic import graph_inputs, truncated_graph_inputs
from pytensor.graph.op import ComputeMapType, HasInnerGraph, Op, StorageMapType
from pytensor.tensor.variable import TensorVariable


+ _log = logging.getLogger(__name__)
+
+
+ class LRUCache1:
+     """
+     Simple LRU cache with a memory size of 1.
+
+     This cache is only usable for a function that takes a single input `x` and returns a single output. The
+     function can also take any number of additional arguments `*args`, but these are assumed to be constant
+     between function calls.
+
+     The purpose of this cache is to allow Hessian computations to be reused when calling scipy.optimize
+     functions. It is very often the case that some sub-computations are repeated between the objective,
+     gradient, and Hessian functions, but by default scipy only allows the objective and gradient to be fused.
+
+     By using this cache, all three functions can be fused, which can significantly speed up the optimization
+     process for expensive functions.
+     """
+
+     def __init__(self, fn):
+         self.fn = fn
+         self.last_x = None
+         self.last_result = None
+
+         self.cache_hits = 0
+         self.cache_misses = 0
+
+         self.value_and_grad_calls = 0
+         self.hess_calls = 0
+
+     def __call__(self, x, *args):
+         """
+         Call the cached function with the given input `x` and additional arguments `*args`.
+
+         If the input `x` is the same as the last input, return the cached result. Otherwise update the cache
+         with the new input and result.
+         """
+         cache_hit = np.all(x == self.last_x)
+
+         if self.last_x is None or not cache_hit:
+             self.cache_misses += 1
+             result = self.fn(x, *args)
+             self.last_x = x
+             self.last_result = result
+             return result
+         else:
+             self.cache_hits += 1
+             return self.last_result
+
+     def value(self, x, *args):
+         self.value_and_grad_calls += 1
+         res = self(x, *args)
+         if isinstance(res, tuple):
+             return res[0]
+         else:
+             return res
+
+     def value_and_grad(self, x, *args):
+         self.value_and_grad_calls += 1
+         return self(x, *args)[:2]
+
+     def hess(self, x, *args):
+         self.hess_calls += 1
+         return self(x, *args)[-1]
+
+     def report(self):
+         _log.info(f"Value and Grad calls: {self.value_and_grad_calls}")
+         _log.info(f"Hess calls: {self.hess_calls}")
+         _log.info(f"Hits: {self.cache_hits}")
+         _log.info(f"Misses: {self.cache_misses}")
+
+     def clear_cache(self):
+         self.last_x = None
+         self.last_result = None
+         self.cache_hits = 0
+         self.cache_misses = 0
+         self.value_and_grad_calls = 0
+         self.hess_calls = 0
+
+
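A minimal sketch of the intended call pattern (not part of the diff; `fused` below is a hypothetical stand-in for a compiled inner function that returns a `(value, grad, hess)` tuple):

    import numpy as np
    from scipy.optimize import minimize as scipy_minimize

    def fused(x):
        # One shared evaluation produces all three quantities.
        value = 0.5 * np.dot(x, x)
        grad = x
        hess = np.eye(x.size)
        return value, grad, hess

    cached = LRUCache1(fused)
    res = scipy_minimize(
        fun=cached.value_and_grad,  # returns (value, grad), hence jac=True
        jac=True,
        hess=cached.hess,  # called at the same x -> served from the cache
        x0=np.full(3, 5.0),
        method="trust-ncg",
    )
    cached.report()  # cache hits show how many hess calls were free
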
class ScipyWrapperOp(Op, HasInnerGraph):
    """Shared logic for scipy optimization ops"""

@@ -44,9 +126,9 @@ def build_fn(self):
            def fn_wrapper(x, *args):
                return fn(x.squeeze(), *args)

-             self._fn_wrapped = fn_wrapper
+             self._fn_wrapped = LRUCache1(fn_wrapper)
        else:
-             self._fn_wrapped = fn
+             self._fn_wrapped = LRUCache1(fn)

    @property
    def fn(self):
@@ -120,6 +202,7 @@ def perform(self, node, inputs, outputs):
            **self.optimizer_kwargs,
        )

+         f.clear_cache()
        outputs[0][0] = np.array(res.x)
        outputs[1][0] = np.bool_(res.success)

@@ -211,6 +294,12 @@ def __init__(
            )
            self.fgraph.add_output(grad_wrt_x)

+         if hess:
+             hess_wrt_x = cast(
+                 Variable, hessian(self.fgraph.outputs[0], self.fgraph.inputs[0])
+             )
+             self.fgraph.add_output(hess_wrt_x)
+
        self.jac = jac
        self.hess = hess
        self.hessp = hessp
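With the Hessian appended as an extra fgraph output, the compiled inner function now returns `(value, grad, hess)` in a single call, which is exactly the tuple that LRUCache1's `value`, `value_and_grad`, and `hess` accessors slice apart. A standalone illustration of the helper this relies on (a sketch, independent of the diff):

    import pytensor.tensor as pt
    from pytensor.gradient import grad, hessian

    x = pt.vector("x")
    cost = (x**2).sum()
    g = grad(cost, x)     # symbolic gradient, 2 * x
    H = hessian(cost, x)  # symbolic Hessian, a diagonal of 2s
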
@@ -225,14 +314,17 @@ def perform(self, node, inputs, outputs):
        x0, *args = inputs

        res = scipy_minimize(
-             fun=f,
+             fun=f.value_and_grad if self.jac else f.value,
            jac=self.jac,
            x0=x0,
            args=tuple(args),
+             hess=f.hess if self.hess else None,
            method=self.method,
            **self.optimizer_kwargs,
        )

+         f.clear_cache()
+
        outputs[0][0] = res.x
        outputs[1][0] = np.bool_(res.success)

@@ -283,6 +375,7 @@ def minimize(
    x: TensorVariable,
    method: str = "BFGS",
    jac: bool = True,
+     hess: bool = False,
    optimizer_kwargs: dict | None = None,
):
    """
@@ -325,6 +418,7 @@ def minimize(
        objective=objective,
        method=method,
        jac=jac,
+         hess=hess,
        optimizer_kwargs=optimizer_kwargs,
    )

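A hypothetical end-to-end use of the new flag (assuming this module is importable as `pytensor.tensor.optimize` and, per MinimizeOp's two outputs, that `minimize` returns the solution along with a success flag):

    import pytensor.tensor as pt
    from pytensor.tensor.optimize import minimize

    x = pt.vector("x")
    objective = ((x - 2) ** 2).sum()

    # With hess=True, a Newton-type method can use the symbolic Hessian.
    x_opt, success = minimize(objective, x, method="Newton-CG", jac=True, hess=True)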