
Conversation

@kayween (Collaborator) commented Nov 19, 2025

Linear operators are great for large structured matrices, but they can incur noticeable overhead for moderate-sized dense matrices. This PR kick-starts an exploration of how much headroom there is for reducing linear operator overhead.

What's Changed?

Wrap the model training code in the following context manager. Then, the exact log marginal likelihood is computed by doing linear algebra operations directly on torch tensors (as opposed to linear operators).

with settings.use_torch_tensors(True):
    output = model(train_x)
    loss = -mll(output, train_y)
    loss.backward()

Under the hood, this PR modifies MultivariateNormal.log_prob and computes the log marginal likelihood with a custom inv_quad_logdet implementation that takes tensors as inputs.
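
For reference, a dense-tensor version of inv_quad_logdet boils down to a single Cholesky factorization plus a triangular solve. Below is a minimal sketch of that computation (an illustration only, not the PR's actual implementation; the function name is made up):

import torch


def inv_quad_logdet_dense(K: torch.Tensor, y: torch.Tensor):
    # Factor K = L L^T once and reuse L for both quantities.
    L = torch.linalg.cholesky(K)
    # inv_quad = y^T K^{-1} y = ||L^{-1} y||^2, via one triangular solve.
    z = torch.linalg.solve_triangular(L, y.unsqueeze(-1), upper=False).squeeze(-1)
    inv_quad = (z**2).sum(-1)
    # logdet(K) = 2 * sum(log(diag(L))).
    logdet = 2.0 * L.diagonal(dim1=-2, dim2=-1).log().sum(-1)
    return inv_quad, logdet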

The exact GP prediction strategy is not modified yet. Thus, the test-time behavior is not affected.

Benchmark Model Fitting Time

The numbers in the following tables are obtained from this notebook.

  • All GP models are configured to use Cholesky decomposition for inference. No CG is involved.
  • The runtime in each table cell is timed over 200 replications.

The runtime improvement is consistent, though its magnitude varies across settings. The most significant speed-up happens with double precision on CPUs, where the runtime reduction reaches 43% (which might be a bit surprising). Meanwhile, the improvement with single precision on GPUs is less dramatic.

Model fitting running time (float32, CPU)

| dataset size | 50 | 100 | 500 | 1000 |
|---|---|---|---|---|
| main branch (s) | 13.8 | 15.3 | 44.6 | 169.8 |
| this PR (s) | 11.0 | 13.3 | 32.5 | 110.8 |
| runtime reduction | -20.3% | -13.1% | -27.1% | -34.7% |

Model fitting running time (float64, CPU)

| dataset size | 50 | 100 | 500 | 1000 |
|---|---|---|---|---|
| main branch (s) | 14.6 | 16.2 | 72.3 | 342.4 |
| this PR (s) | 12.3 | 14.1 | 48.4 | 192.4 |
| runtime reduction | -15.8% | -13.0% | -33.1% | -43.8% |

Model fitting running time (float32, GPU)

| dataset size | 50 | 100 | 500 | 1000 |
|---|---|---|---|---|
| main branch (s) | 19.4 | 17.2 | 19.8 | 27.3 |
| this PR (s) | 17.2 | 16.1 | 18.3 | 23.6 |
| runtime reduction | -11.3% | -6.4% | -7.6% | -13.6% |

Model fitting running time (float64, GPU)

| dataset size | 50 | 100 | 500 | 1000 |
|---|---|---|---|---|
| main branch (s) | 19.4 | 22.4 | 65.1 | 184.9 |
| this PR (s) | 17.4 | 20.2 | 56.5 | 143.5 |
| runtime reduction | -10.3% | -9.8% | -13.2% | -22.4% |

GP Prediction Performance

The MAE of the trained GPs on the synthetic data is virtually identical between the main branch and this PR, so this change does not seem to regress model performance.

@kayween (Collaborator, Author) commented Nov 19, 2025

cc @jacobrgardner @gpleiss @Balandat

@Balandat (Collaborator):

cc @saitcakmak, who has done some investigations in the past into removing LinearOperator overhead. I'll let him give the details, but IIRC the high-level tl;dr was that while we did see speedups, those didn't amount to dramatic savings in the grand scheme of things for our typical use cases. But our use cases aren't necessarily standard, so I think there could be meaningful value in allowing users to short-circuit the linear operator overhead in the small-data regime.

@saitcakmak (Collaborator):

I think this is great! What I had done was to take out linear operator and any parts of GPyTorch that we didn't need for ExactGPs and create a bare-bones version of the library (~20% of GPyTorch and no linear operator). I had seen up to 2x faster execution for some model operations, but at the end of the day it wasn't that significant when considered as part of the whole BO loop in Ax. I think introducing small (depending on where we measure from) improvements like this to core GPyTorch is great. I'd definitely use this for ExactGPs in BoTorch.

@kayween (Collaborator, Author) commented Nov 20, 2025

I had a question while implementing this PR, which I think is worth some discussion. Note that CholLinearOperator does not implement a custom backward pass for the log determinant. Instead, it relies on PyTorch's default backward pass, which backprops through the Cholesky decomposition. Were there any special considerations there?

This PR does implement a custom backward pass for the log determinant using d logdet(K) / dK = K^{-1}, which involves two triangular solves with the cached Cholesky factor of K.

In terms of the running time, this custom backward pass for logdet is indeed faster than PyTorch's default backward pass. Numerically, computing the inverse by two triangular solves shouldn't be bad since the backward pass of Cholesky decomposition requires two triangular solves anyway.

(I am assuming PyTorch implements the Cholesky backward pass by something similar to Eq (9) in Iain Murray's note.)
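
For concreteness, here is a minimal sketch of such a custom backward pass (a simplified illustration of the idea, not the exact code in this PR; the class name is made up). The forward pass caches the Cholesky factor, and the backward pass forms K^{-1} with torch.cholesky_inverse, which amounts to two triangular solves with the cached factor:

import torch


class LogDetFromCholesky(torch.autograd.Function):
    @staticmethod
    def forward(ctx, K):
        # Cache the Cholesky factor of K for the backward pass.
        L = torch.linalg.cholesky(K)
        ctx.save_for_backward(L)
        # logdet(K) = 2 * sum(log(diag(L))).
        return 2.0 * L.diagonal(dim1=-2, dim2=-1).log().sum(-1)

    @staticmethod
    def backward(ctx, grad_output):
        (L,) = ctx.saved_tensors
        # d logdet(K) / dK = K^{-1}, computed from the cached factor.
        K_inv = torch.cholesky_inverse(L)
        return grad_output[..., None, None] * K_inv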

@Balandat (Collaborator):

I am not sure if there were any special considerations at the time - @gpleiss or @jacobrgardner might know?

@jacobrgardner (Member):

I dug around, and the only issue even remotely related to Cholesky derivatives I could find was pytorch/pytorch#18825.

I think we just assumed the pytorch default derivatives for these ops would be fast. I certainly have no problem with us merging a custom backward pass in the name of 10-40% speed-ups.

_default = True


class use_torch_tensors(_feature_flag):
A Member commented on this diff:

What do we think about making this on by default up to some N?

@kayween (Collaborator, Author) replied:

Indeed, the first version of this PR turned on this flag up to some N, as you suggested. But the benchmark shows a speed-up even for N=1000 (whereas the default threshold for Cholesky decomposition is N=800). So I decided to turn this on whenever Cholesky decomposition is used for training and inference.

I think the design here is intertwined with your comment below about what would happen for larger N. I'll circle back on this once we have benchmark results for larger N.

@jacobrgardner (Member):

@kayween One thing I'm noticing: the speedup is increasing with matrix size! That is unexpected: one would assume that the additional overhead of linear operator packaging becomes negligible as the matrix size increases.

Can we extend the benchmark out to larger N? If there isn't some point where we stop seeing increasing speed-ups, there's obviously some serious problem with the default pytorch ops.

@kayween (Collaborator, Author) commented Nov 22, 2025

> I dug around, and the only issue even remotely related to Cholesky derivatives I could find was pytorch/pytorch#18825.

I dug along this line a bit. For @jacobrgardner and anyone who is curious, this is how current PyTorch implements the Cholesky backward pass; it involves two triangular solves and a matmul.

https://github.com/pytorch/pytorch/blob/a3cc252e03572835c15afde54b81fc5e8616ad27/torch/csrc/autograd/FunctionsManual.cpp#L1984-L2010

I haven't dug into cholesky_inverse too deeply since it is part of LAPACK, but I assume it is based on two triangular solves. Thus, the custom backward pass for logdet in this PR saves (at least) a matmul compared to backpropagating through the Cholesky decomposition.
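
To make the comparison concrete, here is a rough, self-contained timing sketch (matrix size and dtype chosen arbitrarily, not taken from the PR's benchmarks) contrasting backpropagation through a Cholesky-based logdet with forming the gradient directly as K^{-1} via torch.cholesky_inverse:

import torch
from torch.utils.benchmark import Timer

n = 2000
A = torch.randn(n, n, dtype=torch.float64)
K0 = A @ A.T + n * torch.eye(n, dtype=torch.float64)


def default_backward():
    # PyTorch's default path: backprop through the Cholesky decomposition.
    K = K0.clone().requires_grad_(True)
    L = torch.linalg.cholesky(K)
    logdet = 2.0 * L.diagonal().log().sum()
    logdet.backward()


def manual_gradient():
    # Same forward pass, but the gradient is formed directly: d logdet(K)/dK = K^{-1}.
    L = torch.linalg.cholesky(K0)
    logdet = 2.0 * L.diagonal().log().sum()
    grad = torch.cholesky_inverse(L)
    return logdet, grad


print(Timer(stmt="default_backward()", globals=globals()).timeit(10))
print(Timer(stmt="manual_gradient()", globals=globals()).timeit(10))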

@kayween (Collaborator, Author) commented Nov 25, 2025

@jacobrgardner Here are more benchmark results for larger N. The code generating the results is available in this gist. As before, all benchmark runs use Cholesky decomposition for GP training.

Running time in seconds on synthetic datasets (main branch → this PR)

| dataset size | CPU f32 | CPU f64 | GPU f32 | GPU f64 |
|---|---|---|---|---|
| 2000 | 5.2 → 2.3 (-55.8%) | 9.5 → 4.4 (-53.7%) | 4.5 → 3.1 (-31.1%) | 4.4 → 3.1 (-29.5%) |
| 5000 | 57.7 → 24.3 (-57.9%) | 110.3 → 47.2 (-57.2%) | 2.3 → 1.8 (-21.7%) | 57.7 → 34.2 (-40.7%) |
| 10000 | 362.8 → 138.6 (-61.8%) | 739.2 → 288.1 (-61.0%) | 13.6 → 9.5 (-30.1%) | 442.3 → 254.8 (-42.4%) |

The time reduction doesn't seem to saturate as we increase the size of training data! For example, we're seeing about 61% time reduction on CPU when n = 10000.

> If there isn't some point where we stop seeing increasing speed-ups, there's obviously some serious problem with the default pytorch ops.

I don't think the issue is necessarily with PyTorch ops. There are probably two explanations for why the speed-up does not taper off at larger dataset sizes.

  1. My hunch is that linear operators' representation tree implementation is not very efficient compared to PyTorch ops, which is where the overhead comes from. For example, InvQuad needs to reconstruct those linear operators in the forward and backward passes.
  2. This PR has a custom backward pass for logdet, which is not available in linear operators.

cc @gpleiss @Balandat @saitcakmak

@jacobrgardner (Member):

> My hunch is that linear operators' representation tree implementation is not very efficient compared to PyTorch ops, which is where the overhead comes from. For example, InvQuad needs to reconstruct those linear operators in the forward and backward passes.

@kayween this doesn't make sense to me. The representation tree at the end of the day should just be carrying around the underlying pytorch tensors as pointers and reconstructing the thin python wrapper.

Like, in a vanilla GP, we have an AddedDiagLinearOperator(DenseLinearOperator(K), DiagLinearOperator(\sigma^2)). The representation tree for this has, at its roots, tensor(K) and tensor(\sigma^2). The pytorch Function takes those tensors as inputs, and just reconstructs AddedDiagLinearOperator(...) in the forward pass of the Function.
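
As a quick illustration (a sketch assuming the public linear_operator API; names and details may differ slightly), the flattened representation is just the underlying tensors, and the representation tree rebuilds the thin wrapper from them:

import torch
from linear_operator.operators import (
    AddedDiagLinearOperator,
    DenseLinearOperator,
    DiagLinearOperator,
)

n = 4
K = torch.randn(n, n)
K = K @ K.T + torch.eye(n)
noise = torch.full((n,), 0.1)

op = AddedDiagLinearOperator(DenseLinearOperator(K), DiagLinearOperator(noise))
flat = op.representation()                 # a tuple of plain torch tensors
rebuilt = op.representation_tree()(*flat)  # reconstructs the python wrapper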

All that is to say: there's no way that python overhead, which should essentially be constant, dominates the linear algebra being done by n = 10,000. Or, more importantly, that overhead certainly shouldn't be growing with n!

> This PR has a custom backward pass for logdet, which is not available in linear operators.

Sure, this makes sense to me.

@kayween (Collaborator, Author) commented Nov 26, 2025

@jacobrgardner The following code snippet shows that linear operators still carry about 30% overhead relative to plain PyTorch even at the scale of n = 10000. This overhead may not come from the representation tree, but it has to be somewhere inside the inv_quad call.

import linear_operator
import torch

from gpytorch.kernels import RBFKernel

torch.manual_seed(42)

n = 10_000

train_x = torch.rand(n, 5)
train_y = torch.rand(n)


def compute_inv_quad(use_linop: bool):
    covar_module = RBFKernel()
    covar = covar_module(train_x).evaluate_kernel().add_jitter(1e-3)

    if use_linop:
        # LinearOperator code path; disable fast (iterative) computations so that
        # Cholesky is used, matching the dense branch below.
        with linear_operator.settings.fast_computations(False, False, False):
            return covar.inv_quad(inv_quad_rhs=train_y.unsqueeze(-1))
    else:
        # Dense-tensor code path: explicit Cholesky plus one triangular solve.
        covar = covar.to_dense()
        chol = torch.linalg.cholesky(covar)

        inv_chol_rhs = torch.linalg.solve_triangular(chol, train_y.unsqueeze(-1), upper=False).squeeze(-1)
        return (inv_chol_rhs**2).sum(-1)


# IPython magics; run in a notebook or an IPython shell.
%timeit compute_inv_quad(use_linop=True)
%timeit compute_inv_quad(use_linop=False)

# outputs:
# 1.31 s ± 34.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
# 912 ms ± 2.85 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
