
However, using PyTorch Lightning to train/test is not possible because the to_hetero function transforms the model back to GraphModule while the PyTorch Lightning Trainer expects an instance of LightningModule.

@mtoshevska I haven't tried this yet, but simply applying to_hetero to your pure PyTorch model instead of your LightningModule might resolve your issue:

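import pytorch_lightning as pl
from torch_geometric.nn import to_hetero

class YourLightningModule(pl.LightningModule):
    def __init__(self, ...):
        super().__init__()
        ...
        self.model = to_hetero(GraphSAGE(...), ...)  # convert only the inner model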

Answer selected by mtoshevska