Question about training a heterogeneous GNN model with PyTorch Lightning (to_hetero changes the type of the model) #7371
-
Hello! I have a GNN-based model for heterogeneous graph classification. The model is built from homogeneous GNN layers and is then transformed into a heterogeneous one with the `to_hetero` function.
The dataset is an instance of `HeteroData` that follows the structure provided here. To enable mini-batch training, `NeighborLoader` instances are created.
When I train/test the model with a standard PyTorch training loop, it runs without errors and the model trains. However, I cannot train/test it with PyTorch Lightning, because `to_hetero` transforms the model into a `GraphModule` (from `torch.fx`), while the PyTorch Lightning `Trainer` expects an instance of `LightningModule`. How can I train heterogeneous GNN models with PyTorch Lightning? Thanks in advance!
-
We have several examples of using PyTorch Lightning with PyG models here. You might have to wrap your model and use some of the other interfaces PyG provides to use PyTorch Lightning with PyG.
-
@mtoshevska I haven't tried this yet, but simply applying `to_hetero` to your pure PyTorch model instead of your `LightningModule` might resolve your issue:

```python
import pytorch_lightning as pl
from torch_geometric.nn import GraphSAGE, to_hetero
class YourLightningModule(pl.LightningModule):
    def __init__(self, ...):
        super().__init__()
        ...
        self.model = to_hetero(GraphSAGE(...), ...)  # only the inner model is converted
```