I am seeking advice on how to reduce memory usage for the optimization problem I describe below.
Thank you to the developers of jax for this amazing package. I am a newcomer to jax, Python, and Ubuntu, and I have limited knowledge of optimization, so please excuse me if I miss something obvious in my questions and comments below.
The problem: finding an appropriate low-memory-usage approach
I am solving a highly nonlinear minimization problem that currently has about 1,000 variables in testing; this may rise to 10,000 or more in the near future. The objective function is quite complex, so an analytic Jacobian is not practical. There can be numerical challenges (I am trying to use float64 throughout), and finite differences don't always work well. However, jacfwd works great, and I have no difficulty solving the problem at its current size with several different methods.
My goal is to have at least one robust method of solving this problem, particularly as it scales up, that uses very little memory. All of my efforts at this, which I describe below, have failed.
I am solving this on an Ubuntu 20.10 machine with 64 GB of RAM and no GPU, using jax 0.2.13 and jaxlib 0.1.65.
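JAX uses 32-bit floats by default, so the float64 computation mentioned above has to be switched on explicitly. A minimal sketch of the standard switch, which needs to run at startup before any arrays are created:

```python
import jax

# JAX defaults to float32; this flag enables float64 and must be set before arrays exist
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp

x = jnp.ones(3)
print(x.dtype)  # float64 once the flag is set
```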
The problem can be structured in several nearly equivalent ways:
- as a root-finding problem with a vector-valued function that has the same number of inputs and outputs (e.g., with 1,000 variables, the function returns 1,000 results);
- as a least-squares problem for optimizers that minimize the sum of squares of a vector-valued function; it has the same structure, returning 1,000 residuals in this example;
- as a sum-of-squares minimization problem for optimizers that minimize a scalar-valued function (the sum of squared residuals).
The Jacobian can be quite dense, although sometimes it is sparse.
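The three formulations differ only in how one residual function is consumed: root finders and least-squares solvers use the residual vector and its Jacobian, while scalar minimizers use the sum of squares and its gradient, which equals `2 * J.T @ r`. A minimal sketch with a hypothetical toy residual `residuals` standing in for f:

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

def residuals(x):
    # hypothetical toy system with as many outputs as inputs (stand-in for f)
    return x**3 + x - jnp.arange(1.0, x.size + 1.0)

# 1) root finding: solve residuals(x) == 0 (square Jacobian)
# 2) least squares: minimize 0.5 * ||residuals(x)||^2 given the residual vector
# 3) scalar minimization: minimize the sum of squared residuals directly
def ssr(x):
    r = residuals(x)
    return jnp.dot(r, r)

x = jnp.zeros(4)
J = jax.jacfwd(residuals)(x)  # (4, 4) Jacobian used by formulations 1 and 2
g = jax.grad(ssr)(x)          # gradient used by formulation 3
print(J.shape, jnp.allclose(g, 2.0 * J.T @ residuals(x)))  # (4, 4) True
```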
I have no problem solving this with 1,000 variables using scipy.optimize.least_squares with the full Jacobian computed with jax.jacfwd, and using the important scipy.optimize.least_squares option x_scale='jac' (more on that below). However, as the problem grows I will want to avoid both the memory use and the computational expense of the full Jacobian. Larger versions of my problem use more than all of my physical RAM, spill into swap space, and take a long time to construct the Jacobian.
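For scale, a dense float64 Jacobian with n inputs and n outputs takes 8n² bytes: 8 MB at n = 1,000, 0.8 GB at n = 10,000, and 80 GB at n = 100,000. A quick way to check such numbers (`dense_jacobian_gb` is just an illustrative helper):

```python
def dense_jacobian_gb(n_outputs, n_inputs, bytes_per_entry=8):
    """Memory (decimal GB) for a dense Jacobian with float64 (8-byte) entries."""
    return n_outputs * n_inputs * bytes_per_entry / 1e9

for n in (1_000, 10_000, 100_000):
    print(f"n = {n:>7,}: {dense_jacobian_gb(n, n):g} GB")
```

By this arithmetic the stored matrix alone would not fill 64 GB until n approaches roughly 90,000, so at 1,000 to 10,000 variables the memory pressure probably comes mostly from intermediate arrays created while the Jacobian is being computed.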
Attempts at finding an appropriate low-memory-usage approach
Here are my failed efforts at solving it using very little memory.
I have simplified the code snippets below to remove complexity related to additional function arguments, to use simple and meaningful variable names, and to remove additional options for optimization routines. I don't believe I made any errors when simplifying, but the simplified code is not tested. As far as I can tell the full operational code works exactly as intended.
I have approached it as a root-finding problem and hand-coded Newton's method, but instead of using the full Jacobian to calculate each step, I use a linear operator built from the jax-computed Jacobian-vector product, and I solve for the step using scipy.optimize.lsq_linear (scipy, not jax.scipy; as far as I can see, lsq_linear is not available in jax). (The Jacobian is highly ill-conditioned, which is why I don't simply use jax.scipy.linalg.solve to get the step.) This is a low-memory approach for well-behaved problems, but the main problem with Newton's method is that its success depends greatly on the starting point. I don't know how to choose an excellent starting point, so it fails on many real-world problems. Thus, my current thinking is that Newton's method with the jvp (wrapped in a linear operator) will not work for my problems.
To make things concrete, here is a simplified version of the code, where:
- f is a vector-valued function of the vector x that I wish to solve for;
- f returns f_at_x, a vector of the same length as x (1,000 in my example);
- get_linop wraps jvp and vjp functions in a scipy linear operator;
- within each Newton iteration, the code calls get_linop and solves for the step vector needed for the next Newton iteration.
Define the linear operator function:
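```python
import jax
import jax.numpy as jnp
import numpy as np
import scipy.sparse.linalg

def get_linop(x, f, f_at_x):
    # matrix-free Jacobian of f at point x, wrapped as a scipy LinearOperator
    _, f_vjp = jax.vjp(f, x)  # linearize the reverse pass once per point
    # J @ v: scipy may pass v with shape (n,) or (n, 1), so cast and reshape first
    l_jvp = lambda v: np.asarray(jax.jvp(f, (x,), (jnp.asarray(v, x.dtype).reshape(x.shape),))[1])
    # J.T @ u: the vjp function returns a tuple with one entry per primal input
    l_vjp = lambda u: np.asarray(f_vjp(jnp.asarray(u, f_at_x.dtype).reshape(f_at_x.shape))[0])
    return scipy.sparse.linalg.LinearOperator((f_at_x.size, x.size), matvec=l_jvp,
                                              rmatvec=l_vjp, dtype=x.dtype)
```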
As noted, this works fine in the sense that it is fast, uses little memory, and is correct with well-behaved problems. But in a challenging problem, if the initial x is not good it will not compute a good step or converge to a good solution. Therefore I am inclined against Newton's method.
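A minimal sketch of how a full matrix-free Newton loop around such an operator might look, assuming a square system and float64. It builds the operator inline (mirroring get_linop) and, unlike the plain Newton iteration described here, adds a backtracking line search on `||f(x)||^2`. That damping is a standard way to globalize Newton's method and reduce its sensitivity to the starting point, but it is not part of the code described above; `newton_matrix_free` and its parameters are hypothetical.

```python
import jax
import jax.numpy as jnp
import numpy as np
import scipy.sparse.linalg
from scipy.optimize import lsq_linear

jax.config.update("jax_enable_x64", True)  # float64 throughout

def newton_matrix_free(f, x0, max_iter=50, tol=1e-10):
    """Damped Newton for f(x) = 0 using only JVPs/VJPs of f (no dense Jacobian)."""
    x = np.asarray(x0, dtype=np.float64)
    fx = np.asarray(f(x))
    for _ in range(max_iter):
        if np.linalg.norm(fx) < tol:
            break
        xj = jnp.asarray(x)
        _, f_vjp = jax.vjp(f, xj)
        J = scipy.sparse.linalg.LinearOperator(
            (fx.size, x.size), dtype=np.float64,
            matvec=lambda v: np.asarray(jax.jvp(f, (xj,), (jnp.asarray(v).reshape(xj.shape),))[1]),
            rmatvec=lambda u: np.asarray(f_vjp(jnp.asarray(u).reshape(fx.shape))[0]))
        # Newton step: least-squares solution of J @ step = -f(x), solved iteratively (LSMR)
        step = lsq_linear(J, -fx).x
        # backtracking: halve the step until ||f||^2 decreases (or the step is tiny)
        t = 1.0
        while True:
            x_new = x + t * step
            fx_new = np.asarray(f(x_new))
            if fx_new @ fx_new < fx @ fx or t < 1e-8:
                break
            t *= 0.5
        x, fx = x_new, fx_new
    return x
```

On a toy system such as `f(z) = z**3 + z - c` started far from the root, this converges; whether backtracking alone is enough for the real problem is untested.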
I have solved it directly as a nonlinear least squares problem, using scipy.optimize.least_squares (from scipy, not jax.scipy -- as far as I can tell, this is not implemented in jax). This works fine with the full Jacobian. I tried two variants in an effort to reduce memory usage:
a) Constructing the Jacobian a column at a time or a row at a time using jvp or vjp, following examples in the Autodiff cookbook. I did this because I thought maybe it was not the Jacobian itself that was taking a lot of memory, but the process of constructing the Jacobian, and perhaps column-by-column or row-by-row construction would not use as much memory. Either approach solved my problem just fine, but didn't seem to use any less memory than the full Jacobian.
Here is a simplified snippet of code to show how I implement the jvp version of this (I do not show the vjp version); f, x, and f_at_x are the same as above, and myx0 is a simple starting point that may not be near the true solution:
```python
import jax
import jax.numpy as jnp
import numpy as np
import scipy.optimize

# f and myx0 are defined elsewhere (see above)
def jacfun(x):
    # build the Jacobian column by column: column i is J @ e_i
    x = jnp.asarray(x)
    _jvp = lambda v: jax.jvp(f, (x,), (v,))[1]  # J @ v for a single tangent v
    Jt = jax.vmap(_jvp, in_axes=1)(jnp.eye(x.size, dtype=x.dtype))  # row i holds J @ e_i
    return jnp.transpose(Jt)

def Jsolver(x):
    # return the Jacobian at a specific point as a numpy array
    return np.asarray(jacfun(x))

result = scipy.optimize.least_squares(
    fun=f,
    x0=myx0,
    method='trf',
    jac=Jsolver,
    x_scale='jac')  # an important scaling option provided by scipy.optimize.least_squares
```
As mentioned, this works fine on all problems, but as far as I can tell building the Jacobian column-by-column with the jvp approach does not appear to use less memory than if I computed the full Jacobian with jacfwd.
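One likely reason for the lack of savings: jax.jacfwd is itself implemented by vmapping jax.jvp over the standard basis, so vmapping _jvp over `jnp.eye(len(x))` pushes all n tangent vectors through f in a single batch, and every intermediate array inside f gains an extra dimension of size n. A sketch that builds the Jacobian a block of columns at a time, bounding that overhead to `chunk` tangents (the finished dense matrix still needs 8n² bytes); `jac_by_columns` and `chunk` are hypothetical names:

```python
import jax
import jax.numpy as jnp
import numpy as np

def jac_by_columns(f, x, chunk=100):
    """Dense Jacobian of f at x, built `chunk` columns at a time with jvp."""
    x = jnp.asarray(x)
    n = x.size
    fx = np.asarray(f(x))
    J = np.empty((fx.size, n), dtype=fx.dtype)  # final result lives in numpy memory
    # vmap only over one block of tangents, so intermediates scale with chunk, not n
    block_jvp = jax.jit(jax.vmap(lambda v: jax.jvp(f, (x,), (v,))[1],
                                 in_axes=1, out_axes=1))
    for start in range(0, n, chunk):
        stop = min(start + chunk, n)
        E = np.zeros((n, stop - start))
        E[np.arange(start, stop), np.arange(stop - start)] = 1.0  # basis vectors e_start .. e_(stop-1)
        J[:, start:stop] = np.asarray(block_jvp(jnp.asarray(E, dtype=x.dtype)))
    return J
```

Setting `chunk=1` gives a true column-at-a-time loop; larger chunks trade memory for fewer, larger batched JVPs.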
b) I took advantage of scipy.sparse.linalg.LinearOperator (not jax.scipy...) and constructed a linear operator using jvp and vjp, which I passed to scipy.optimize.least_squares rather than using the full Jacobian. This is extremely fast and extremely low memory usage, and it works great on some problems. However, when using a linear operator, scipy.optimize.least_squares does not allow you to use the variable scaling option x_scale='jac'; this turns out to be a crucial option for ill-behaved versions of my problem, which is most of them. Scipy allows you to construct your own scaling factors as an alternative, but I do not know a way to construct good scaling factors for my problem's variables. Furthermore, even if I did, they must be fixed scaling factors that do not change from iteration to iteration, whereas x_scale='jac' has factors that change during the course of the solution.
As a result, I currently think I do not have a viable low-memory solution using the least-squares minimization approach with scipy.optimize.least_squares.
Here is the simplified implementation of this approach:
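```python
import jax
import jax.numpy as jnp
import numpy as np
import scipy.optimize
import scipy.sparse.linalg

# define a linear operator to pass to scipy's nonlinear least_squares
def jvp_linop(x):
    # least_squares calls jac(x) at each iterate; return J(x) as a LinearOperator
    x = jnp.asarray(x)
    _, f_vjp = jax.vjp(f, x)  # reverse-mode linearization at x, computed once per call

    def f_jvp(v):
        # the operator may pass v with shape (n,) or (n, 1); jvp needs x's shape and dtype
        v = jnp.asarray(v, x.dtype).reshape(x.shape)
        return np.asarray(jax.jvp(f, (x,), (v,))[1])

    def l_vjp(u):
        # the vjp function returns a tuple with one entry per primal input
        return np.asarray(f_vjp(jnp.asarray(u, x.dtype).reshape(-1))[0])

    return scipy.sparse.linalg.LinearOperator((x.size, x.size), dtype=np.float64,
                                              matvec=f_jvp, rmatvec=l_vjp)

result = scipy.optimize.least_squares(
    fun=f,
    x0=myx0,
    method='trf',
    jac=jvp_linop,  # note that we use the linear operator here
    x_scale=1.0)  # x_scale='jac' is not allowed with a LinearOperator, so use 1.0
```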
As noted, this is fast AND uses little memory. But because we cannot use the scipy...least_squares option of x_scale='jac' it produces bad results on difficult problems.
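SciPy's documentation describes x_scale='jac' as scaling by the inverse norms of the Jacobian's columns. Those norms can be estimated without forming J: if r has independent ±1 entries, then `E[(J.T @ r)[i]**2] = sum_k J[k, i]**2`, so averaging squared VJPs estimates the squared column norms. This Hutchinson-style estimator is a swapped-in technique, not something least_squares does; `estimate_col_norms` is a hypothetical helper, and its accuracy on the real problem is untested:

```python
import jax
import jax.numpy as jnp

def estimate_col_norms(f, x, num_samples=64, seed=0):
    """Estimate ||J[:, i]|| for every column i of J = df/dx using only VJPs."""
    x = jnp.asarray(x)
    fx, f_vjp = jax.vjp(f, x)  # one forward pass; f_vjp(u) returns (J.T @ u,)
    keys = jax.random.split(jax.random.PRNGKey(seed), num_samples)

    def squared_vjp(key):
        # Rademacher probe: independent +1/-1 entries in output space
        r = jnp.where(jax.random.bernoulli(key, 0.5, fx.shape), 1.0, -1.0).astype(fx.dtype)
        return f_vjp(r)[0] ** 2  # unbiased estimate of the squared column norms

    # lax.map runs the probes sequentially, so only one VJP is in flight at a time
    return jnp.sqrt(jax.lax.map(squared_vjp, keys).mean(axis=0))
```

The result could be turned into a fixed scaling such as `x_scale = 1.0 / np.maximum(norms, 1e-12)` for the least_squares call that uses jvp_linop. Unlike x_scale='jac', that scaling stays fixed during the solve, so it would need refreshing (for example, by restarting from the latest iterate) to follow changes in the Jacobian.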
I tried a completely different approach, trying to minimize the sum of squared residuals using jax.scipy.optimize.minimize. (Incidentally, if any jax developers are reading this, FWIW I seem unable to call this by importing jax and using jax.scipy.optimize.minimize; instead I have to do something like "from jax.scipy.optimize import minimize", and then it works fine. It took me a lot of hair-pulling to figure that out.)
All you have to do here is pass the objective function that computes the sum of squared residuals (SSR) to jax.scipy.optimize.minimize (the jax.scipy version, not the scipy version), and it does whatever it does under the hood using the BFGS algorithm. I assumed jax would use memory-saving techniques such as Jacobian-vector products (or perhaps Hessian-vector products in this case). However, the documentation is spartan and there seem to be few options. Although I don't know exactly what this method is doing, it works great on well-behaved problems: it is super fast, produces better objective function values than some other methods, and uses almost no memory. Unfortunately, it seems to fail as soon as I try it on real-world problems of modest size (e.g., 1,000 variables). Because it does not return much information, I don't know exactly how it fails; I can see that when it fails it usually does no more than 40 function evaluations and only 1 major iteration. It returns a status code of 3, which clearly indicates failure, but I am not sure precisely what it means. In any event, for now I am ruling this approach out as a way to keep memory usage low.
Here is simplified code for this approach. I leave out the BFGS options I used, which didn't seem to have any effect:
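```python
import jax
import jax.numpy as jnp
# jax.scipy.optimize is a submodule that `import jax` alone may not load;
# importing it explicitly (as here, or via `import jax.scipy.optimize`) works
from jax.scipy.optimize import minimize

def f_sumsq(x):
    # scalar objective: sum of squared residuals
    return jnp.square(f(x)).sum()

result = minimize(fun=f_sumsq, x0=myx0, method='BFGS')
# result.success is a bool; status, nit, and nfev help diagnose a failure
print(result.success, result.status, result.nit, result.nfev, result.fun)
```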
As noted, this is fast, uses little memory, and is correct for easy problems. However, it fails for indeterminate reasons (the returned status code is 3) on more difficult problems that can easily be solved using the full Jacobian with scipy's least_squares and the x_scale='jac' option.
Seeking advice
I am not completely out of ideas for how to solve larger versions of this problem while using relatively little memory, but I am close to out of ideas. I would much appreciate any advice on how this might be done.
Many thanks in advance.