Description
Feature Request
Accelerate validation prediction during training by reusing the solve from CG/Cholesky.
Motivation
Is your feature request related to a problem? Please describe.
Maximizing the exact marginal log likelihood is not guaranteed to prevent overfitting, so one might want to monitor performance on a hold-out dataset in order to perform early stopping. Currently, this requires a fully separate prediction step, which can be computationally expensive.
Pitch
It would be nice to re-use computations of the log likelihood calculation. Parts which could be re-used (please complete this list with your ideas!):
- Training covariance matrix
- Solve of the training covariance matrix against the observed values ($a = K^{-1}y$)
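For reference, the overlap between the two computations can be written out as follows. This assumes a zero prior mean and takes $K$ to be the training covariance including the noise term; $K_{*}$ (the cross-covariance between validation and training inputs) and $n$ (the number of training points) are notation introduced here, not taken from the library.

$$
\log p(y \mid X) = -\tfrac{1}{2} y^\top a - \tfrac{1}{2} \log\lvert K \rvert - \tfrac{n}{2} \log 2\pi, \qquad a = K^{-1} y
$$

$$
\mu_{*} = K_{*} a
$$

Once $a$ is available from the likelihood computation, the validation mean costs only one extra matrix-vector product.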
Describe the solution you'd like
- Extend LinearOperator.inv_quad_logdet() to also return the matrix solve without gradients (ctx.mark_non_differentiable).
- Extend the forward pass in ExactGP to take additional validation data (or pass this data in the constructor). In the forward pass, the only additional cost is the validation prior evaluation (mean & covar) and the matmul with the solve.
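A minimal NumPy sketch of the idea, to make the proposed data flow concrete. It is not the GPyTorch or LinearOperator implementation: the function names `rbf_kernel` and `mll_and_val_mean`, the RBF kernel, the zero prior mean, and the fixed noise value are all assumptions chosen for illustration. The point is only that the solve computed for the likelihood is reused for the validation mean.

```python
import numpy as np


def rbf_kernel(x1, x2, lengthscale=1.0):
    """Squared-exponential kernel between two sets of 1-D inputs (illustrative choice)."""
    diff = x1[:, None] - x2[None, :]
    return np.exp(-0.5 * (diff / lengthscale) ** 2)


def mll_and_val_mean(x_train, y_train, x_val, noise=0.1, lengthscale=1.0):
    """Return (exact marginal log likelihood, validation mean), sharing one solve.

    Hypothetical helper: assumes a zero prior mean and homoskedastic noise.
    """
    n = x_train.shape[0]
    # Training covariance including the noise term.
    K = rbf_kernel(x_train, x_train, lengthscale) + noise * np.eye(n)
    L = np.linalg.cholesky(K)

    # a = K^{-1} y via two solves with the Cholesky factor.
    a = np.linalg.solve(L.T, np.linalg.solve(L, y_train))

    # Exact marginal log likelihood, using a and log|K| = 2 * sum(log(diag(L))).
    logdet = 2.0 * np.sum(np.log(np.diag(L)))
    mll = -0.5 * y_train @ a - 0.5 * logdet - 0.5 * n * np.log(2.0 * np.pi)

    # Validation mean: only a cross-covariance evaluation and one matmul with a.
    K_val = rbf_kernel(x_val, x_train, lengthscale)
    val_mean = K_val @ a
    return mll, val_mean


x_train = np.linspace(0.0, 1.0, 8)
y_train = np.sin(6.0 * x_train)
x_val = np.array([0.15, 0.55])
mll, val_mean = mll_and_val_mean(x_train, y_train, x_val)
```

In an autograd setting, the reused solve would be treated as a constant (no gradients flow through the validation prediction), which is what the `ctx.mark_non_differentiable` suggestion above is meant to achieve.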
Describe alternatives you've considered
Somehow cache the relevant values using @cached. I prefer the explicit approach, as caching always raises the question of when to acquire or release the cache, which can make memory usage unpredictable.
Are you willing to open a pull request?
I actually have a working implementation of this for a special case. However, I still want to take care of #2288 first, and this proposal might need some design discussion, as it will definitely change the API of some frequently used functions.
Additional context
It might be possible to extend this idea from exact to variational GPs. Of course, this approach only allows for the calculation of the mean prediction without a variance estimate in its current form.
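To spell out why only the mean comes for free in the exact case, assume a zero prior mean and let $K$ be the training covariance including noise, with $K_{*}$ the validation/training cross-covariance and $K_{**}$ the validation covariance (notation introduced here for illustration):

$$
\mu_{*} = K_{*} K^{-1} y, \qquad \Sigma_{*} = K_{**} - K_{*} K^{-1} K_{*}^\top
$$

The mean depends on $K^{-1}$ only through $K^{-1} y$, which is exactly the solve being reused. The covariance needs solves against $K_{*}^\top$ instead, so it would require additional work: new CG solves, or triangular solves with a stored Cholesky factor.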