
Sorry for the slow reply, @flcello. Rather than using batch mode, I think the best approach is to keep a list of separate GP models, one per element of the M plate. Something like:

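import gpytorch
import torch


class M_Plate_GP(gpytorch.models.ApproximateGP):
    def __init__(self, num_inducing=64, name_prefix="mixture_gp"):
        # Assumption: standard SVGP setup, as in the GPyTorch Pyro examples
        inducing_points = torch.linspace(0, 1, num_inducing)
        variational_strategy = gpytorch.variational.VariationalStrategy(
            self, inducing_points,
            gpytorch.variational.CholeskyVariationalDistribution(num_inducing),
        )
        super().__init__(variational_strategy)
        self.name_prefix = name_prefix
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())

    def forward(self, x):
        mean = self.mean_module(x)
        covar = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean, covar)


class Model(gpytorch.Module):
    def __init__(self, m):
        super().__init__()
        # ModuleList (not a plain list) so each GP's parameters are registered
        self.m_plate_gps = torch.nn.ModuleList(
            M_Plate_GP(name_prefix=f"mixture_gp_{i}") for i in range(m)
        )

    def guide(self, x):
        # Assumption: the original snippet is cut off here; one plausible
        # completion draws q(u) for each GP under its own unique name prefix
        for gp in self.m_plate_gps:
            gp.pyro_guide(x, name_prefix=gp.name_prefix)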
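Whichever way the guide is finished, the container for the per-plate models matters: `torch.nn.Module` only registers submodules (and therefore their parameters) when they are stored as modules, so a plain Python list hides them. A minimal sketch of the difference, using `torch.nn.Linear` as a stand-in for the per-plate GPs (the class names `PlainListModel` and `ModuleListModel` are hypothetical, for illustration only):

```python
import torch


class PlainListModel(torch.nn.Module):
    def __init__(self, m):
        super().__init__()
        # A plain list: the submodules are invisible to torch.nn.Module
        self.gps = [torch.nn.Linear(3, 1) for _ in range(m)]


class ModuleListModel(torch.nn.Module):
    def __init__(self, m):
        super().__init__()
        # ModuleList registers every submodule and its parameters
        self.gps = torch.nn.ModuleList(torch.nn.Linear(3, 1) for _ in range(m))


plain = PlainListModel(m=2)
registered = ModuleListModel(m=2)

# Each Linear has a weight and a bias, so 2 modules -> 4 parameter tensors
n_plain = len(list(plain.parameters()))
n_registered = len(list(registered.parameters()))
print(n_plain, n_registered)  # 0 4
```

In practice this means an optimizer built from `model.parameters()` would silently skip every GP held in a plain list.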

Answer selected by florinwalter