Hello, thanks for this amazing project! After training models with GraphGym, I would like to extract the node embeddings and analyze them further. Normally, loading a PyTorch model would be possible with:

How can I load a saved GraphGym model (i.e., how is `TheModelClass()` initialized?) and use it to generate the node embeddings? Thanks in advance!
It would be `model = create_model()` followed by `model.load_state_dict(torch.load(PATH))` (see https://github.com/pyg-team/pytorch_geometric/blob/master/graphgym/main.py).
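A minimal sketch of the save/load round trip. `StandInModel` and `build_model()` are hypothetical stand-ins for whatever `create_model()` builds from your GraphGym config; the key point is that the architecture must be rebuilt from the same config before the saved parameters are loaded. An in-memory buffer replaces the checkpoint file `PATH` so the example is self-contained. If your checkpoint file stores a dictionary wrapping the state_dict rather than the bare state_dict, inspect `torch.load(PATH).keys()` and pass the model entry to `load_state_dict`.

```python
import io

import torch
from torch import nn


class StandInModel(nn.Module):
    """Hypothetical stand-in for the network that create_model() would build."""

    def __init__(self, dim_in: int = 8, dim_hidden: int = 16, dim_out: int = 3):
        super().__init__()
        self.mp = nn.Sequential(nn.Linear(dim_in, dim_hidden), nn.ReLU())
        self.post_mp = nn.Linear(dim_hidden, dim_out)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.post_mp(self.mp(x))


def build_model() -> nn.Module:
    # Assumption: plays the role of create_model(); it must rebuild the SAME
    # architecture (same config) that was used when the checkpoint was saved.
    return StandInModel()


# Training side: save only the parameters (state_dict), not the module object.
trained = build_model()
buffer = io.BytesIO()  # in practice this would be a file path (PATH)
torch.save(trained.state_dict(), buffer)

# Analysis side: rebuild the architecture, then load the saved parameters.
buffer.seek(0)
model = build_model()
model.load_state_dict(torch.load(buffer))
model.eval()  # switch dropout / batch norm to inference behavior
```

Saving the state_dict rather than the whole module keeps the checkpoint independent of pickled class definitions, which is why the model object has to be constructed first.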
Good question. I think the most elegant solution is to register a PyTorch `forward_hook` on your model, which lets you capture the embeddings just before they enter the model head. I think this is the best approach, since one usually does not want to modify the model's output just for this kind of analysis.
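A minimal sketch of the forward-hook approach on a hypothetical stand-in model, where `mp` produces the node embeddings and `post_mp` is the head (naming assumed to mirror GraphGym's default GNN; run `print(model)` to find the real submodule names). In GraphGym the submodules likely pass a batch object around, in which case the hooked output would be that batch and the embeddings its `x` attribute; here plain tensors are used to keep the example self-contained.

```python
import torch
from torch import nn


class StandInModel(nn.Module):
    """Hypothetical stand-in: `mp` yields node embeddings, `post_mp` is the head."""

    def __init__(self, dim_in: int = 8, dim_hidden: int = 16, dim_out: int = 3):
        super().__init__()
        self.mp = nn.Sequential(nn.Linear(dim_in, dim_hidden), nn.ReLU())
        self.post_mp = nn.Linear(dim_hidden, dim_out)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.post_mp(self.mp(x))


model = StandInModel().eval()
captured = {}


def save_embeddings(module, inputs, output):
    # Forward hooks receive (module, inputs, output); keep a detached copy.
    captured["node_emb"] = output.detach()


# Hook the last module before the head; the model's output is left untouched.
handle = model.mp.register_forward_hook(save_embeddings)

x = torch.randn(5, 8)  # 5 dummy "nodes" with 8 features each
with torch.no_grad():
    logits = model(x)

handle.remove()  # stop capturing once the embeddings are collected

node_emb = captured["node_emb"]
print(logits.shape, node_emb.shape)  # torch.Size([5, 3]) torch.Size([5, 16])
```

Removing the handle afterwards matters if the model is reused, since the hook would otherwise keep overwriting `captured` on every forward pass.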