JIT function frozen... minimization failure #10648
Unanswered
jecampagne asked this question in Q&A
Replies: 1 comment 24 replies
-
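Ha, I got it! In fact, in the code below, the function argument of

```python
import jaxopt
from jax import jit

@jit
def _test(x, a):
    print("compile....")  # Python side effect: runs only while JAX traces (compiles) _test
    return (x - a) ** 2

class A():
    def __init__(self):
        tol = 1e-3
        method = 'L-BFGS-B'
        options = {'disp': False, 'ftol': tol, 'gtol': tol, 'maxiter': 600}
        # _test is a module-level function, so it does not read self.a
        self.jscMin = jaxopt.ScipyBoundedMinimize(fun=_test,
                                                  method=method,
                                                  tol=tol,
                                                  options=options)
        self.Init()

    def Init(self):
        self.a = None

    def set_a(self, val):
        self.a = val

    # Previous approach (kept commented out): reads self.a inside the function
    # instead of receiving it as an argument
    # def test(self, x):
    #     return _test(x, self.a)

    def optimize(self, init_val, low_bnd, high_bnd):
        # run() forwards a to _test as an ordinary (traced) argument,
        # so every call uses the current value of self.a
        res = self.jscMin.run(init_val, bounds=(low_bnd, high_bnd), a=self.a)
        return res
```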
It yields:
=> a single compilation, and all minimisations are correct! Is this a good solution, or do you see possible failure cases?
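To see why passing `a` explicitly gives both properties at once, here is a minimal sketch in plain JAX, without jaxopt. It swaps `L-BFGS-B` for a simple projected gradient descent (hypothetical step size `LR` and iteration count `STEPS`), and uses a hypothetical `trace_count` dict in place of the `print("compile....")` trick. Because `a` is a regular argument of the jitted `minimize`, it is traced rather than baked in as a constant, so one compilation serves every value of `a`.

```python
import jax
import jax.numpy as jnp

trace_count = {"n": 0}  # Python side effect: incremented only while JAX traces `minimize`

LR = 0.1     # hypothetical step size
STEPS = 200  # hypothetical number of iterations

def loss(x, a):
    # Same objective as `_test` in the thread; `a` is an ordinary argument.
    return (x - a) ** 2

@jax.jit
def minimize(x0, lo, hi, a):
    trace_count["n"] += 1
    grad_fn = jax.grad(loss)  # derivative with respect to x (argument 0)

    def step(_, x):
        # Projected gradient step: move downhill, then clip back into [lo, hi].
        return jnp.clip(x - LR * grad_fn(x, a), lo, hi)

    x_init = jnp.clip(jnp.asarray(x0, jnp.float32), lo, hi)
    return jax.lax.fori_loop(0, STEPS, step, x_init)

results = {a: float(minimize(0.5, 0.0, 5.0, a)) for a in (1.0, 3.0, 4.0, 7.0)}
print(results)           # minima ≈ 1, 3, 4 and 5 (the last one clipped to the upper bound)
print(trace_count["n"])  # 1: the same compiled function served every value of `a`
```

Marking `a` as static instead (e.g. `jax.jit(minimize, static_argnames="a")`) would also give correct minima, but it recompiles for every distinct value of `a`, which is exactly what the thread is trying to avoid.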
-
Hi,
While trying to write a helper function that allows jit-compiled methods in a class, I ran into a problem with the following minimization use case, where I would like to minimize the number of recompilations of the `_test` function. For the sake of completeness, I use the `jaxopt.ScipyBoundedMinimize` minimizer, which cannot be jitted.

Below is the snippet. For "res 1", I instantiate a new `A` object each time, which triggers a recompilation every time but gives the expected solution (the minimum is equal to `a`). For the "res 2" case, a single `A` object is instantiated once and for all; it yields a single compilation of `_test`, BUT the results are only correct for the first `a` value.

I have not yet managed to reconcile 1) the right solution and 2) a single compilation. Do you have any comments or ideas? Thanks in advance.
leading to
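The "res 2" symptom (one compilation, but only the first `a` is honoured) matches a well-known jit pitfall. The snippet itself is not reproduced here, so this is a minimal sketch under the assumption that the jitted function reads `self.a` from the instance; `FrozenA` and `trace_count` are hypothetical names. The value of `self.a` is read while JAX traces the function, becomes a compile-time constant, and later calls with the same input shape and dtype reuse the cached executable.

```python
import jax
import jax.numpy as jnp

trace_count = {"n": 0}  # incremented only while JAX traces the jitted method

class FrozenA:
    """Hypothetical stand-in for the 'res 2' pattern: a jitted function reading self.a."""

    def __init__(self):
        self.a = None
        # Wrapped once: self.a is captured as a constant at the first trace.
        self.loss = jax.jit(self._loss)

    def _loss(self, x):
        trace_count["n"] += 1
        return (x - self.a) ** 2

    def set_a(self, val):
        self.a = val

obj = FrozenA()
obj.set_a(1.0)
first = float(obj.loss(jnp.float32(0.0)))   # traced with a = 1.0: (0 - 1)^2 = 1
obj.set_a(3.0)
second = float(obj.loss(jnp.float32(0.0)))  # cache hit, still uses a = 1.0: gives 1, not 9
print(first, second, trace_count["n"])
```

In this sketch, creating a fresh object for every call (the "res 1" pattern) builds a new jit wrapper, so each call retraces with the current `a`: correct results, at the cost of one compilation per object.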