How to get the derivative of an entry in a covariance matrix wrt a kernel parameter? #1756
I'm trying to get the derivative of an entry of a covariance matrix wrt a kernel parameter, and I'm having a bit of trouble after banging on it for a few hours. I've pasted a toy example below, but the results I get don't match my Mathematica sanity check (I expect the derivative to be 2.4176, but I'm getting 0.515859 with this code). I thought maybe I was doing something wrong with the raw/transformed parameters, but I've looked around in the code base and that doesn't seem to be it (though I could have totally misunderstood...), so I suspect I'm misunderstanding something basic about Torch, GPyTorch, or the interaction of the two. Thanks for any pointers!
import torch
from gpytorch.kernels import RBFKernel
ls = 0.24
dist = 0.23
kernel = RBFKernel()
kernel.lengthscale = ls
X = torch.tensor([0, dist], requires_grad=True)
Gram = kernel(X, X).evaluate()
print(Gram)
k = Gram[0, 1]
print(k.item())
k.backward()
dkdl = kernel.get_parameter("raw_lengthscale").grad.item()
print(dkdl)
Replies: 1 comment 1 reply
You said you might be doing something wrong with raw/transformed parameters, and it seems to me that this is definitely the case. In your Mathematica check, you're differentiating directly with respect to \ell, but in GPyTorch you're differentiating with respect to the inverse softplus of \ell (i.e., the raw lengthscale).
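In other words, GPyTorch's version of the RBF kernel is basically exp[-dist^2 / (2 * softplus(raw_ls)^2)], and what you're computing is the derivative with respect to raw_ls. The two derivatives differ by the chain-rule factor d softplus(raw_ls) / d raw_ls = sigmoid(raw_ls), which is about 0.2134 at \ell = 0.24, and 2.4176 * 0.2134 is roughly 0.5159.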
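To make the relationship between the two numbers concrete, here is a minimal sketch in plain Python (standard library only, no Torch or GPyTorch) that redoes the arithmetic by hand. It assumes the lengthscale constraint is a softplus, as described above, and that the kernel includes the usual factor of 1/2 in the exponent. The helper names `softplus`, `inv_softplus`, and `rbf` are made up for illustration and are not GPyTorch functions.

```python
import math

def softplus(x):
    # softplus(x) = log(1 + exp(x)); maps a raw parameter to a positive value
    return math.log1p(math.exp(x))

def inv_softplus(y):
    # inverse of softplus: log(exp(y) - 1)
    return math.log(math.expm1(y))

def rbf(dist, ls):
    # RBF kernel value for two points separated by `dist`
    return math.exp(-0.5 * dist**2 / ls**2)

ls, dist = 0.24, 0.23
raw_ls = inv_softplus(ls)          # the value the raw parameter would hold

k = rbf(dist, ls)
dk_dls = k * dist**2 / ls**3       # analytic d k / d ls (the Mathematica-style quantity)

# Chain rule: d ls / d raw_ls = d softplus(raw_ls) / d raw_ls = sigmoid(raw_ls)
dls_draw = 1.0 / (1.0 + math.exp(-raw_ls))
dk_draw = dk_dls * dls_draw        # derivative wrt the raw parameter

print("d k / d ls     =", dk_dls)
print("d k / d raw_ls =", dk_draw)
print("recovered d k / d ls =", dk_draw / dls_draw)
```

The first printed value should agree with the 2.4176 from the Mathematica check and the second with the 0.515859 reported in the question. Going the other way, dividing the gradient stored on the raw parameter by sigmoid(raw_ls) gives the derivative with respect to the lengthscale itself.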