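It looks like you are missing a global readout step. By default, the basic GNNs we provide return node-level representations, i.e., one output row per node (29520 in your batch) rather than one per graph (615). You can fix this, e.g., by pooling the node embeddings of each graph before the classifier:

```python
import torch
from torch.nn import Linear as Lin
from torch_geometric.nn import GCN, global_mean_pool

class MyGNN(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, num_layers, num_classes):
        super().__init__()
        self.gnn = GCN(in_channels, hidden_channels, num_layers)  # node-level embeddings
        self.classifier = Lin(hidden_channels, num_classes)       # graph-level head

    def forward(self, x, edge_index, batch):
        x = self.gnn(x, edge_index)     # [num_nodes, hidden_channels]
        x = global_mean_pool(x, batch)  # [num_graphs, hidden_channels]
        return self.classifier(x)       # [num_graphs, num_classes]
```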
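To make the shape change concrete, here is a plain-PyTorch sketch of what a per-graph mean readout does. It assumes, as in PyG's `DataBatch`, that `batch[i]` holds the graph index of node `i`. `mean_pool` is a hypothetical helper written for illustration, not PyG's actual `global_mean_pool` implementation, and `hidden_channels = 16` is an arbitrary value.

```python
import torch

def mean_pool(x: torch.Tensor, batch: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: average node features per graph.

    x:     [num_nodes, channels] node-level features
    batch: [num_nodes] graph index of each node (0 .. num_graphs - 1)
    """
    num_graphs = int(batch.max()) + 1
    # Sum the features of all nodes that belong to the same graph.
    sums = torch.zeros(num_graphs, x.size(1), dtype=x.dtype).index_add_(0, batch, x)
    # Divide by the node count of each graph to get the mean.
    counts = torch.bincount(batch, minlength=num_graphs).clamp(min=1).to(x.dtype)
    return sums / counts.unsqueeze(1)

# Sizes from the question: 615 graphs x 48 nodes, 53 classes; hidden size is illustrative.
num_graphs, nodes_per_graph, hidden_channels, num_classes = 615, 48, 16, 53
batch = torch.arange(num_graphs).repeat_interleave(nodes_per_graph)    # [29520]
node_emb = torch.randn(num_graphs * nodes_per_graph, hidden_channels)  # GNN output: [29520, 16]

graph_emb = mean_pool(node_emb, batch)                             # [615, 16]
logits = torch.nn.Linear(hidden_channels, num_classes)(graph_emb)  # [615, 53]
print(logits.shape)  # torch.Size([615, 53])
```

Pooling happens before the classifier, so the classifier sees exactly one row per graph, which matches the one label per graph in `y`.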
Dataset introduction
I'm trying to train the pre-defined models on my dataset for graph classification.
One sample looks like `Data(x=[48, 1], edge_index=[2, 53], y=0)`, and batching all 615 samples with a `DataLoader` gives `DataBatch(x=[29520, 1], edge_index=[2, 32595], y=[615], batch=[29520], ptr=[616])`.
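For reference, the `batch` and `ptr` entries in that `DataBatch` follow directly from the per-graph node counts. The sketch below rebuilds them in plain PyTorch, assuming every graph has 48 nodes as in the sample above; it mirrors the layout PyG's batching produces but is not PyG's own code.

```python
import torch

# Per-graph node counts: 615 graphs with 48 nodes each (from the question).
num_nodes_per_graph = torch.full((615,), 48, dtype=torch.long)

# ptr[i] is the index of the first node of graph i; ptr[-1] is the total node count.
ptr = torch.cat([torch.zeros(1, dtype=torch.long), num_nodes_per_graph.cumsum(0)])

# batch[j] is the graph index of node j.
batch = torch.arange(len(num_nodes_per_graph)).repeat_interleave(num_nodes_per_graph)

print(ptr.shape, batch.shape)  # torch.Size([616]) torch.Size([29520])
```

Because `x` is indexed per node while `y` is indexed per graph, any model that returns node-level outputs produces 29520 rows for this batch, not 615.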
The `DataLoader` is constructed from a list of these `Data` objects.

Pre-defined model I used
I tried the following, where `out_channels` is 53 because my dataset has 53 labels:

```python
models.GCN(
    in_channels=data.num_node_features,    # 1
    out_channels=len(np.unique(y_train)),  # 53
)
```

Problem encountered
The input has shape `torch.Size([29520, 1])` and the output has shape `torch.Size([29520, 53])`. The loss function `nn.CrossEntropyLoss()` cannot handle this shape and reports `Expected input batch_size (29520) to match target batch_size (615)`. The output shape should be `torch.Size([615, 53])`.
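The mismatch can be reproduced with random tensors of the shapes from the post. `nn.CrossEntropyLoss` expects an input of shape `[N, C]` and a target of shape `[N]`; in this sketch, random logits and labels stand in for real model outputs and data.

```python
import torch
from torch import nn

loss_fn = nn.CrossEntropyLoss()
target = torch.randint(0, 53, (615,))  # one class label per graph

# Node-level output: one row per node, so the batch sizes disagree.
node_out = torch.randn(29520, 53)
raised = False
try:
    loss_fn(node_out, target)
except (ValueError, RuntimeError) as err:  # caught broadly; the exact type may depend on the PyTorch version
    raised = True
    print("mismatch:", err)

# Graph-level output: one row per graph, as CrossEntropyLoss expects.
graph_out = torch.randn(615, 53)
loss = loss_fn(graph_out, target)
print(loss.dim())  # 0 (a scalar loss)
```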
User-defined model works
In my self-defined model, the tensor is reshaped with `x = x.reshape([int(len(x) / 48), 48 * hidden_channels])` before the output linear layer, so the node features, which start as `torch.Size([29520, 1])`, reach the output layer as `torch.Size([615, 48*hidden_channels])`. The output layer maps `48*hidden_channels` inputs to 53 outputs, and that's how I make it work.

Question
How can I make the pre-defined model output a shape that the loss function accepts? Or is something else wrong? I can't figure it out because the old-fashioned way works. I'd really appreciate your advice.
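The reshape trick from the user-defined model can be checked in plain PyTorch. This sketch assumes, as in the question, that every graph has exactly 48 nodes stored contiguously in the batch; `hidden_channels = 16` is an arbitrary illustrative value. Under that assumption, row `i` of the reshaped tensor is the concatenation of graph `i`'s node embeddings, which is why the trick works for this dataset but would not for graphs with varying node counts, where a per-graph readout such as global mean pooling is needed.

```python
import torch

num_graphs, nodes_per_graph, hidden_channels = 615, 48, 16  # hidden size is illustrative
x = torch.randn(num_graphs * nodes_per_graph, hidden_channels)  # [29520, 16] node embeddings
batch = torch.arange(num_graphs).repeat_interleave(nodes_per_graph)

# The reshape from the question: concatenate each graph's 48 node embeddings into one row.
x_flat = x.reshape(len(x) // nodes_per_graph, nodes_per_graph * hidden_channels)
print(x_flat.shape)  # torch.Size([615, 768])

# Row i holds exactly the nodes of graph i (valid only because every graph has 48 nodes).
i = 3
assert torch.equal(x_flat[i], x[batch == i].flatten())

# An output layer mapping 48 * hidden_channels -> 53 then yields one row per graph.
out_layer = torch.nn.Linear(nodes_per_graph * hidden_channels, 53)
print(out_layer(x_flat).shape)  # torch.Size([615, 53])
```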