Skip to content
Discussion options

You must be logged in to vote

Is this

import torch
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.nn.pool import global_add_pool
from torch_geometric.nn.pool import global_mean_pool
from torch_geometric.utils import softmax

class Model(torch.nn.Module):
    """Toy graph-level readout that pools node features with per-graph attention.

    Attention weights are a sparse softmax of the node features, computed
    separately over the nodes of each graph (grouped by ``batch``). The pooled
    graph embedding is the attention-weighted *sum* of the node features.
    """

    def __init__(self):
        super().__init__()
        # other vars

    def forward(self, x, batch=None):
        """Pool node features ``x`` into one vector per graph.

        Args:
            x: Node feature matrix, one row per node (assumed
                ``[num_nodes, num_features]`` -- TODO confirm for real inputs).
            batch: Graph id of each node, as produced by PyG's ``DataLoader``.
                ``None`` means every node belongs to a single graph.

        Returns:
            A tensor with one row per graph, holding ``sum_i att_i * x_i`` over
            that graph's nodes.
        """
        if batch is None:
            # PyG's softmax needs an explicit group index; put all nodes in graph 0.
            batch = torch.zeros(x.size(0), dtype=torch.long, device=x.device)
        # something useful here
        att = softmax(x, index=batch)
        # The softmax already normalizes the weights to sum to 1 within each graph.
        # The weighted *sum* is therefore the attention-weighted average. A mean
        # pool would divide by the node count again and shrink the embeddings of
        # larger graphs.
        y = global_add_pool(x * att, batch=batch)
        return y

# Two toy graphs (3 nodes and 2 nodes, 2 features each), batched together so the
# model receives a `batch` vector containing more than one graph id.
dataset = [
    Data(x=torch.tensor([[1.0, 1.0], [1.0, 1.1], [1.0, 0.9]])),
    Data(x=torch.tensor([[1.0, 1.0], [1.0, 1.1]])),
]
loader = DataLoader(dataset, batch_size=2)
model = Model()
for step, data in enumerate(loade…

Replies: 1 comment 1 reply

Comment options

You must be logged in to vote
1 reply
@fedeotto
Comment options

Answer selected by fedeotto
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Category
Q&A
Labels
None yet
2 participants