GPs on two plates with low-level Pyro interface #1744
-
Hi! I'm building a Pyro model with the low-level Pyro interface of GPyTorch. I have GPs f_km with scales s_km and lengthscales l_km on two plates (K and M). The evaluation points x_m differ between GPs on the M plate, but are shared between GPs on the K plate. I'm wondering what the best way to implement this is. I tried to illustrate the relevant part of the model in a graph: [graph not reproduced]
-
Sorry for the slow reply @flcello. I imagine the best thing to do is not to use batch mode, but instead to keep a list of models on the M plate. Something like:

```python
import gpytorch
import torch


class M_Plate_GP(gpytorch.models.ApproximateGP):
    def __init__(self, num_inducing=64, name_prefix="mixture_gp"):
        # ...
        self.name_prefix = name_prefix

    def forward(self, x):
        mean = self.mean_module(x)
        covar = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean, covar)


class Model(gpytorch.Module):
    def __init__(self, m):
        super().__init__()
        # A ModuleList (not a plain list) registers each GP's parameters;
        # a unique name prefix per GP keeps the Pyro sample sites from clashing
        self.m_plate_gps = torch.nn.ModuleList(
            [M_Plate_GP(name_prefix=f"m_plate_gp_{i}") for i in range(m)]
        )

    def guide(self, xs):
        # xs[i] holds the evaluation points x_m of the i-th GP on the M plate
        for gp, x in zip(self.m_plate_gps, xs):
            gp.pyro_guide(x, name_prefix=gp.name_prefix)
        # ...

    def model(self, xs):
        for gp, x in zip(self.m_plate_gps, xs):
            gp.pyro_model(x, name_prefix=gp.name_prefix)
        # ...
```