reduce linear operator overhead in exact marginal log likelihood computation #2682
Conversation
cc @saitcakmak, who has done some investigation in the past into removing LinearOperator overhead. I'll let him give the details, but IIRC the high-level tl;dr was that while we did see speedups, those didn't amount to dramatic savings in the grand scheme of things for our typical use cases. But our use cases aren't necessarily standard, so I think there could be meaningful value in allowing one to short-circuit the linear operator overhead in the small-data regime.
I think this is great! What I had done was to take out linear operator and any parts of GPyTorch that we didn't need for ExactGPs and create a bare-bones version of the library (~20% of GPyTorch and no linear operator). I had seen up to 2x faster execution for some model operations, but at the end of the day, it wasn't that significant when considered as part of the whole BO loop in Ax. I think introducing small (depending on where we measure from) improvements like this to core GPyTorch is great. I'd definitely use this for ExactGPs in BoTorch.
I had a question while implementing this PR, which I think is worth some discussion. This PR implements a custom backward pass for the log determinant, which recovers the gradient from the Cholesky factor. In terms of running time, this custom backward pass for logdet is indeed faster than PyTorch's default backward pass. Numerically, computing the inverse by two triangular solves shouldn't be bad, since the backward pass of the Cholesky decomposition requires two triangular solves anyway. (I am assuming PyTorch implements the Cholesky backward pass by something similar to Eq. (9) in Iain Murray's note.)
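For concreteness, here is a minimal sketch of the idea under discussion (my illustration, not the PR's actual code): the log determinant gets a handwritten backward that reuses the saved Cholesky factor, forming the gradient d log|A| / dA = A^{-1} with two triangular solves instead of going through PyTorch's generic Cholesky backward.

```python
import torch


class CholeskyLogDet(torch.autograd.Function):
    """Hypothetical sketch (not the PR's code): logdet with a custom backward."""

    @staticmethod
    def forward(ctx, A):
        # log|A| = 2 * sum(log(diag(L))) with L = cholesky(A)
        L = torch.linalg.cholesky(A)
        ctx.save_for_backward(L)
        return 2.0 * L.diagonal(dim1=-2, dim2=-1).log().sum(-1)

    @staticmethod
    def backward(ctx, grad_output):
        (L,) = ctx.saved_tensors
        eye = torch.eye(L.size(-1), dtype=L.dtype, device=L.device)
        # d log|A| / dA = A^{-1} = L^{-T} L^{-1}, via two triangular solves
        L_inv = torch.linalg.solve_triangular(L, eye, upper=False)
        A_inv = torch.linalg.solve_triangular(L.mT, L_inv, upper=True)
        return grad_output[..., None, None] * A_inv
```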
I am not sure if there were any special considerations at the time - @gpleiss or @jacobrgardner might know?
I dug around, and the only issue even remotely related to Cholesky derivatives I could find was this one: pytorch/pytorch#18825. I think we just assumed the PyTorch default derivatives for these ops would be fast. I certainly have no problem with us merging a custom backward pass in the name of 10-40% speed ups.
`_default = True`
`class use_torch_tensors(_feature_flag):`
What do we think about making this on by default up to some N?
Indeed, the first version of this PR turned on this flag up to some N, as you suggested. But the benchmark shows a speedup even for N=1000 (whereas the default threshold for Cholesky decomposition is N=800). So I decided to turn this on whenever Cholesky decomposition is used for training and inference.
I think the design here is intertwined with your comments below---what would happen for larger N. I'll circle back on this once we have benchmark results for larger N.
@kayween One thing I'm noticing: the speedup is increasing with matrix size! That is unexpected: one would assume that the additional overhead of linear operator packaging becomes negligible as the matrix size increases. Can we extend the benchmark out to larger N? If there isn't some point where we stop seeing increasing speed-ups, there's obviously some serious problem with the default pytorch ops.
I dug along this line a bit. @jacobrgardner and anyone who is curious: the current PyTorch implementation of the Cholesky backward pass involves two triangular solves and a matmul. I haven't dug into
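For reference, here is what that structure looks like in plain PyTorch: a paraphrase of the standard reverse-mode Cholesky formula (cf. Eq. (9) in Murray's note), not a copy of PyTorch's source.

```python
import torch


def cholesky_backward_sketch(L, grad_L):
    """Reverse-mode gradient of A -> L = cholesky(A), given grad_L = dLoss/dL.

    One matmul plus two triangular solves; Phi(X) below denotes the lower
    triangle of X with its diagonal halved.
    """
    P = L.mT @ grad_L                           # the matmul: L^T grad_L
    P = P.tril()
    P.diagonal(dim1=-2, dim2=-1).mul_(0.5)      # Phi(L^T grad_L)
    S = torch.linalg.solve_triangular(L.mT, P, upper=True)            # L^{-T} Phi
    S = torch.linalg.solve_triangular(L, S, upper=False, left=False)  # ... L^{-1}
    return 0.5 * (S + S.mT)                     # symmetrize: dLoss/dA
```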
@jacobrgardner Here are more benchmark results for larger N. The code generating the results is available in this gist. As before, all benchmark runs use Cholesky decomposition for GP training.
Running time (in seconds) on synthetic datasets:
The time reduction doesn't seem to saturate as we increase the size of the training data! For example, we're seeing about 61% time reduction on CPU when
I don't think the issue is necessarily with PyTorch ops. There are probably two explanations for why the speed-up does not stop at larger dataset sizes.
@kayween this doesn't make sense to me. The representation tree at the end of the day should just be carrying around the underlying pytorch tensors as pointers and reconstructing the thin python wrapper. Like, in a vanilla GP, we have an
All that is to say: there's no way that python overhead, which should essentially be constant, dominates the linear algebra being done by
Sure, this makes sense to me.
@jacobrgardner The following code snippet shows that linear operators still have about 30% overhead compared to PyTorch, even at the scale of
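A hypothetical micro-benchmark in the same spirit (my own sketch; the matrix size `n` is a placeholder and the ~30% figure is the thread's claim, not something this sketch asserts): it times `inv_quad_logdet` through a `DenseLinearOperator` against the equivalent plain-torch Cholesky computation, forcing the Cholesky path via `max_cholesky_size`.

```python
import time

import torch
from linear_operator import settings
from linear_operator.operators import DenseLinearOperator

n = 1000  # placeholder size; the exact value discussed in the thread is omitted above
A = torch.randn(n, n, dtype=torch.float64)
K = A @ A.mT + n * torch.eye(n, dtype=torch.float64)  # well-conditioned SPD matrix
y = torch.randn(n, 1, dtype=torch.float64)


def timeit(fn, repeats=20):
    start = time.perf_counter()
    for _ in range(repeats):
        fn()
    return (time.perf_counter() - start) / repeats


def via_linear_operator():
    # Raise max_cholesky_size so both versions run the same Cholesky-based math.
    with settings.max_cholesky_size(10 * n):
        return DenseLinearOperator(K).inv_quad_logdet(inv_quad_rhs=y, logdet=True)


def via_torch():
    L = torch.linalg.cholesky(K)
    alpha = torch.linalg.solve_triangular(L, y, upper=False)
    return alpha.square().sum(), 2.0 * L.diagonal().log().sum()


print("linear_operator:", timeit(via_linear_operator))
print("plain torch    :", timeit(via_torch))
```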
Linear operators are great for large structured matrices, but they can incur noticeable overhead for moderate-sized dense matrices. This PR kick-starts an exploration of how much headroom there is in reducing that overhead.
What's Changed?
Wrap the model training code in the following context manager. Then, the exact log marginal likelihood is computed by doing linear algebra operations directly on torch tensors (as opposed to linear operators).
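A hedged usage sketch (the flag name comes from the diff in this PR; the exact `gpytorch.settings` import path and the toy model below are my assumptions):

```python
import torch
import gpytorch


class ExactGPModel(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood):
        super().__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())

    def forward(self, x):
        return gpytorch.distributions.MultivariateNormal(
            self.mean_module(x), self.covar_module(x)
        )


train_x = torch.linspace(0, 1, 500)
train_y = torch.sin(6 * train_x) + 0.1 * torch.randn(500)

likelihood = gpytorch.likelihoods.GaussianLikelihood()
model = ExactGPModel(train_x, train_y, likelihood)
mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)

model.train()
likelihood.train()

# Flag name taken from the diff in this PR; the gpytorch.settings path is assumed.
with gpytorch.settings.use_torch_tensors():
    for _ in range(50):
        optimizer.zero_grad()
        loss = -mll(model(train_x), train_y)
        loss.backward()
        optimizer.step()
```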
Under the hood, this PR modifies `MultivariateNormal.log_prob` and computes the log marginal likelihood via a custom `inv_quad_logdet` implementation that takes tensors as inputs. The exact GP prediction strategy is not modified yet, so test-time behavior is not affected.
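Roughly, the tensor-level computation amounts to the following (a sketch of the standard Cholesky-based formulas, not the PR's actual implementation):

```python
import torch


def inv_quad_logdet_dense(K, y):
    """Sketch: y^T K^{-1} y and log|K| from one Cholesky factorization,
    using plain torch tensors (no LinearOperator wrapping)."""
    L = torch.linalg.cholesky(K)
    alpha = torch.linalg.solve_triangular(L, y.unsqueeze(-1), upper=False)  # L^{-1} y
    inv_quad = alpha.square().sum(dim=(-2, -1))                             # ||L^{-1} y||^2
    logdet = 2.0 * L.diagonal(dim1=-2, dim2=-1).log().sum(-1)
    return inv_quad, logdet
```

The marginal log likelihood is then `-0.5 * (inv_quad + logdet + n * log(2 * pi))`, with the observation noise folded into `K` by the likelihood.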
Benchmark Model Fitting Time
The numbers in the following tables are obtained from this notebook.
The runtime improvement seems consistent, though the size of the reduction varies across settings. The most significant speedup happens with double precision on CPUs, where the runtime reduction is up to 43% (which might be a bit surprising). Meanwhile, the improvement with single precision on GPUs is less dramatic.
GP Prediction Performance
The MAE of the trained GPs on the synthetic data is virtually the same, so this PR does not seem to regress model performance.