Open
Description
I had some trouble running chebai-graph with torch_geometric version 2.3.1.
When loading a checkpoint trained with version 2.6.1, I got a size mismatch in the weights of the ResGatedGraphConv layers.
Apparently, in version 2.3.1, the input dimension of the lin_key, lin_query, and lin_value weights depends only on in_channels, while in later versions it is in_channels plus the edge_attr dimension.
Here is the output for in_channels=158, edge_attr=7:
RuntimeError: Error(s) in loading state_dict for ResGatedGraphConvNetGraphPred:
size mismatch for gnn.convs.0.lin_key.weight: copying a param with shape torch.Size([256, 165]) from checkpoint, the shape in current model is torch.Size([256, 158]).
size mismatch for gnn.convs.0.lin_query.weight: copying a param with shape torch.Size([256, 165]) from checkpoint, the shape in current model is torch.Size([256, 158]).
size mismatch for gnn.convs.0.lin_value.weight: copying a param with shape torch.Size([256, 165]) from checkpoint, the shape in current model is torch.Size([256, 158]).
size mismatch for gnn.convs.1.lin_key.weight: copying a param with shape torch.Size([256, 263]) from checkpoint, the shape in current model is torch.Size([256, 256]).
size mismatch for gnn.convs.1.lin_query.weight: copying a param with shape torch.Size([256, 263]) from checkpoint, the shape in current model is torch.Size([256, 256]).
size mismatch for gnn.convs.1.lin_value.weight: copying a param with shape torch.Size([256, 263]) from checkpoint, the shape in current model is torch.Size([256, 256]).
size mismatch for gnn.convs.2.lin_key.weight: copying a param with shape torch.Size([256, 263]) from checkpoint, the shape in current model is torch.Size([256, 256]).
size mismatch for gnn.convs.2.lin_query.weight: copying a param with shape torch.Size([256, 263]) from checkpoint, the shape in current model is torch.Size([256, 256]).
size mismatch for gnn.convs.2.lin_value.weight: copying a param with shape torch.Size([256, 263]) from checkpoint, the shape in current model is torch.Size([256, 256]).
size mismatch for gnn.convs.3.lin_key.weight: copying a param with shape torch.Size([256, 263]) from checkpoint, the shape in current model is torch.Size([256, 256]).
size mismatch for gnn.convs.3.lin_query.weight: copying a param with shape torch.Size([256, 263]) from checkpoint, the shape in current model is torch.Size([256, 256]).
size mismatch for gnn.convs.3.lin_value.weight: copying a param with shape torch.Size([256, 263]) from checkpoint, the shape in current model is torch.Size([256, 256]).
size mismatch for gnn.final_conv.lin_key.weight: copying a param with shape torch.Size([512, 263]) from checkpoint, the shape in current model is torch.Size([512, 256]).
size mismatch for gnn.final_conv.lin_query.weight: copying a param with shape torch.Size([512, 263]) from checkpoint, the shape in current model is torch.Size([512, 256]).
size mismatch for gnn.final_conv.lin_value.weight: copying a param with shape torch.Size([512, 263]) from checkpoint, the shape in current model is torch.Size([512, 256]).
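The widths in the log are consistent with the edge attribute dimension being added to each layer's input in the checkpoint: 158 + 7 = 165 for the first layer and 256 + 7 = 263 for the others. A minimal sketch of that arithmetic follows; the helper `lin_in_features` and its `adds_edge_dim` flag are hypothetical names for illustration, not part of torch_geometric, and the hidden width of 256 is read off the log above.

```python
# Reproduce the input widths from the error log under the two behaviors.
# `lin_in_features` and `adds_edge_dim` are hypothetical, for illustration only.

def lin_in_features(in_channels, edge_dim, adds_edge_dim):
    """Input width of lin_key / lin_query / lin_value under each behavior."""
    return in_channels + edge_dim if adds_edge_dim else in_channels


NODE_FEATURES = 158  # in_channels of the first layer (from the issue)
EDGE_FEATURES = 7    # edge_attr dimension (from the issue)
HIDDEN = 256         # input width of the later layers (from the log)

# Input channels per layer, as they appear in the log.
layer_inputs = {
    "gnn.convs.0": NODE_FEATURES,
    "gnn.convs.1": HIDDEN,
    "gnn.convs.2": HIDDEN,
    "gnn.convs.3": HIDDEN,
    "gnn.final_conv": HIDDEN,
}


def mismatch_table(layers, edge_dim):
    """Map layer name -> (width in checkpoint, width in current model)."""
    return {
        name: (
            lin_in_features(channels, edge_dim, adds_edge_dim=True),
            lin_in_features(channels, edge_dim, adds_edge_dim=False),
        )
        for name, channels in layers.items()
    }


table = mismatch_table(layer_inputs, EDGE_FEATURES)
for name, (checkpoint_width, current_width) in table.items():
    print(f"{name}: checkpoint={checkpoint_width}, current={current_width}")
```

Every pair matches the second dimension of the shapes reported in the traceback, which supports the explanation above but does not by itself confirm which release introduced the change.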
Todo
- Require torch_geometric>=2.4
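Besides pinning the dependency in the package requirements, a runtime guard could fail early with a clearer message than the state_dict error. This is only a sketch under assumptions: the function names are hypothetical, the 2.4 minimum is taken from the todo item above, and it assumes the installed distribution is discoverable under the name `torch_geometric`.

```python
import re
from importlib import metadata

MIN_VERSION = (2, 4)  # minimum from the todo item above


def parse_major_minor(version):
    """Extract (major, minor) from a version string such as '2.6.1'."""
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"Unrecognized version string: {version!r}")
    return int(match.group(1)), int(match.group(2))


def is_supported(version, minimum=MIN_VERSION):
    """True if the version is at least the required major.minor."""
    return parse_major_minor(version) >= minimum


def check_installed(dist_name="torch_geometric"):
    """Raise a readable error if the installed version is too old."""
    try:
        installed = metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        raise RuntimeError(f"{dist_name} is not installed") from None
    if not is_supported(installed):
        required = ".".join(str(part) for part in MIN_VERSION)
        raise RuntimeError(
            f"{dist_name} {installed} found, but >={required} is required "
            "to load checkpoints trained with newer versions."
        )
    return installed
```

Calling `check_installed()` before loading a checkpoint would then report the version problem directly instead of a list of size mismatches.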