
If you want to learn the importance of nodes jointly with the prediction task, you could apply a learnable soft mask to the node features at the beginning of your model:

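import torch

class Net(torch.nn.Module):  # hypothetical wrapper around your model
    def __init__(self, num_nodes):
        super().__init__()
        self.mask = torch.nn.Parameter(torch.randn(num_nodes, 1))  # one logit per node

    def forward(self, x, edge_index):  # add further arguments as needed
        x = self.mask.sigmoid() * x  # scale each node's features by a weight in (0, 1)
        ...  # rest of your model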
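To see the mask being trained jointly with the prediction task end to end, here is a minimal sketch in plain PyTorch. Everything beyond the masking line is an assumption for illustration: the `MaskedNet` class name, the toy 6-node ring graph with random features and labels, and the `index_add`-based neighbor sum, which is a hypothetical stand-in for a real GNN layer (e.g. a PyTorch Geometric convolution).

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)


class MaskedNet(torch.nn.Module):
    """Toy node classifier with a learnable soft mask over its input nodes."""

    def __init__(self, num_nodes, in_channels, num_classes):
        super().__init__()
        # One importance logit per node, broadcast over the feature dimension.
        self.mask = torch.nn.Parameter(torch.randn(num_nodes, 1))
        self.lin = torch.nn.Linear(in_channels, num_classes)

    def forward(self, x, edge_index):
        # Soft mask: scale each node's features by a weight in (0, 1).
        x = self.mask.sigmoid() * x
        # Hypothetical stand-in for a GNN layer: each node keeps its own
        # (masked) features and adds those of its neighbors, with messages
        # flowing from edge_index[0] to edge_index[1].
        src, dst = edge_index
        agg = x.index_add(0, dst, x[src])
        return self.lin(agg)


num_nodes, in_channels, num_classes = 6, 4, 2
x = torch.randn(num_nodes, in_channels)
# Undirected ring 0-1-2-3-4-5-0, stored as directed edges in both directions.
ring = torch.arange(num_nodes)
edge_index = torch.stack([
    torch.cat([ring, (ring + 1) % num_nodes]),
    torch.cat([(ring + 1) % num_nodes, ring]),
])
y = torch.randint(0, num_classes, (num_nodes,))

model = MaskedNet(num_nodes, in_channels, num_classes)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

losses = []
for epoch in range(200):
    optimizer.zero_grad()
    loss = F.cross_entropy(model(x, edge_index), y)
    loss.backward()  # gradients reach both self.mask and self.lin
    optimizer.step()
    losses.append(loss.item())

# Learned per-node importance in (0, 1); a higher value means the model
# kept more of that node's features.
importance = model.mask.detach().sigmoid().view(-1)
ranking = importance.argsort(descending=True)
print(importance, ranking)
```

Because the mask has exactly `num_nodes` rows, this only makes sense when every forward pass sees the same node set in the same order (a transductive setting); for mini-batching or new graphs you would need a different way to index the mask.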

Let me know if this is helpful to you :)

Replies: 1 comment 5 replies

5 replies
@vctorwei
@rusty1s
@vctorwei
@rusty1s
@vctorwei

Answer selected by vctorwei
Category: Q&A
Labels: None yet
2 participants