I think in this case it is best to move the embedding layer out of the to_hetero call and apply it separately, e.g., via

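import copy

import torch
from torch.nn import Embedding
from torch_geometric.nn import to_hetero

class MyModel(torch.nn.Module):
    def __init__(self, ...):
        super().__init__()
        # The embedding lives outside of the to_hetero-transformed GNN.
        self.emb = Embedding(...)
        self.model = to_hetero(MyGNN(), metadata)

    def forward(self, x_dict, edge_index_dict):
        # Shallow copy so that the caller's dict is not modified in-place.
        x_dict = copy.copy(x_dict)
        x_dict[key] = self.emb(x_dict[key])
        return self.model(x_dict, edge_index_dict)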

Answer selected by IskaNeumann