-
Hi all,

I would like to train my GP (DKL) model using an auxiliary model. To this end, I have implemented a VAE composed of a DKL encoder (a deterministic NN-based encoder followed by a GP over the latent features) and an NN-based decoder. The idea is to train the decoder, the GP, and the encoder end-to-end using the standard VAE loss. See code example --> However, while the NNs do train, the GP hyperparameters and variational parameters remain constant and never update. For sampling from the GP, I tried both the standard reparametrization trick and rsample(), but nothing changed. Could you help me with this issue?

Thank you in advance!
Kind regards,
Nicolò
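For context, by "standard loss for VAEs" I mean a reconstruction term plus the closed-form Gaussian KL, with sampling via the reparametrization trick. A minimal NumPy sketch of these pieces (the function names here are illustrative, not my actual code):

```python
import numpy as np

rng = np.random.default_rng(0)

def reparameterize(mu, log_var):
    # z = mu + sigma * eps with eps ~ N(0, I), so gradients can flow
    # through mu and log_var (rsample() does the same thing in PyTorch).
    eps = rng.standard_normal(np.shape(mu))
    return mu + np.exp(0.5 * log_var) * eps

def gaussian_kl(mu, log_var):
    # Closed-form KL( N(mu, diag(sigma^2)) || N(0, I) ), summed over latent dims.
    return -0.5 * np.sum(1.0 + log_var - mu**2 - np.exp(log_var))

def vae_loss(x, x_recon, mu, log_var):
    # Standard VAE objective for one example: reconstruction error + KL regularizer.
    recon = np.sum((x - x_recon) ** 2)  # e.g. squared-error reconstruction
    return recon + gaussian_kl(mu, log_var)
```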
-
Dear all,

Any update on this issue?

Kind regards,
Nicolò
-
Hi, sorry for the slow reply. I'll try to take a look by the end of this week -- I've had success doing this type of joint modeling in the past, so it's possible it's a straightforward fix.
-
So I took a look at what the issue is, and I'm fairly confident it's inside your loss function: nothing in it ever touches the forward distribution of the variational GP. For the GridVariationalStrategy, your forward pass uses solely the variational distribution (for a code ref see here). This is pointed out (although really buried) in the appendix of the SV-DKL paper that describes the strategy, see appendix B around eq. 12: https://arxiv.org/pdf/1611.00336.pdf.

Thus, to train the GP hypers you need to either a) use a different variational strategy, or b) refactor your loss to include the KL(q||p) term between the approximate GP posterior and the GP prior, like:

```python
kl_term = model.gp_layer.variational_strategy.kl_divergence() / 50000  # scaling should be by the total number of data pts
loss = loss_function(recon_batch, data, mean, variance)
loss += kl_term
```

Hope that helps.
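To see why the division by the dataset size is there: the GP's KL(q||p) is a single dataset-level quantity, while the reconstruction loss is computed per minibatch, so the KL is scaled by 1/N to keep each minibatch's loss an unbiased share of the full objective. A plain-Python sketch (the sizes here are made up for illustration):

```python
# Hypothetical sizes, purely to illustrate the 1/N scaling of the GP KL term.
N_TRAIN = 50_000   # total number of training points (the divisor above)
BATCH = 100        # minibatch size

def minibatch_loss(recon_per_point, gp_kl_full):
    # recon_per_point: reconstruction loss averaged over the minibatch.
    # gp_kl_full: KL(q||p) of the whole variational GP (batch-independent).
    # Dividing the KL by N_TRAIN makes this a per-point estimate of the
    # full negative ELBO divided by N_TRAIN.
    return recon_per_point + gp_kl_full / N_TRAIN

# Sanity check: summed over one epoch, the KL enters the objective exactly once.
gp_kl = 7.0
epoch_total = sum(BATCH * minibatch_loss(0.0, gp_kl)
                  for _ in range(N_TRAIN // BATCH))
```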