Different number of aggregation layers for each edge_type in HeteroData() #6471
I have a GNN set up as follows:

import torch
from torch_geometric.nn import GCNConv, GATv2Conv, HeteroConv, TransformerConv
# Initialize the GNN
in_features = 6
graph_conv_features = [64, 64, 64]
local_convs = torch.nn.ModuleList()
global_convs = torch.nn.ModuleList()
# First local layer: both 'places' and 'room' are destination node types here
local_convs.append(HeteroConv({
    ('places', 'with', 'places'): GCNConv(in_features, graph_conv_features[0]),
    ('places', 'nearby_to', 'room'): GATv2Conv(in_features, graph_conv_features[0], add_self_loops=False),
}, aggr='sum'))
# Global layer: 'room' is the only destination node type here
global_convs.append(HeteroConv({
    ('places', 'connected_by', 'room'): TransformerConv(graph_conv_features[-1], graph_conv_features[-1]),
}, aggr='sum'))
# Remaining local layers
for i in range(1, len(graph_conv_features) - 1):
    local_convs.append(HeteroConv({
        ('places', 'with', 'places'): GCNConv(graph_conv_features[i - 1], graph_conv_features[i]),
        ('places', 'nearby_to', 'room'): GATv2Conv(graph_conv_features[i - 1], graph_conv_features[i],
                                                   add_self_loops=False),
    }, aggr='sum'))
# Forward, `x_dict` is the node feature dict of the `HeteroData()` object
for conv in local_convs:
    x_dict = conv(x_dict, edge_index_dict)
    x_dict = {key: h.relu() for key, h in x_dict.items()}
for conv in global_convs:
    x_dict = conv(x_dict, edge_index_dict)
    x_dict = {key: h.relu() for key, h in x_dict.items()}

However, I found that after the forward pass the ...
Replies: 1 comment
`global_convs` only has a single GNN with messages going from `places` to `room`, so it returns a dict with only `room`. The problem in the above code is that you are overwriting `x_dict`. Instead, try something like:
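A minimal sketch of that idea, assuming the goal is to keep the existing features for any node type a layer does not return: merge each layer's output into the current dict rather than replacing it. The helper name `apply_conv_keep_missing` and the stand-in `fake_global_conv` are hypothetical, and plain Python lists stand in for tensors so the example runs without PyTorch Geometric.

```python
def apply_conv_keep_missing(conv, x_dict, edge_index_dict, act=None):
    """Run `conv` and merge its output into a copy of `x_dict`.

    Node types missing from the output of `conv` (because they are not the
    destination of any of its edge types) keep their previous features.
    """
    out_dict = conv(x_dict, edge_index_dict)
    if act is not None:
        # Apply the activation only to the freshly computed features.
        out_dict = {key: act(h) for key, h in out_dict.items()}
    # Entries in `out_dict` override entries in `x_dict`; the rest survive.
    return {**x_dict, **out_dict}


# Hypothetical stand-in for a layer whose edge types all end in 'room':
# like `global_convs` above, it returns features for 'room' only.
def fake_global_conv(x_dict, edge_index_dict):
    return {'room': [v + 1.0 for v in x_dict['room']]}


def list_relu(h):
    # ReLU on a plain list, standing in for `h.relu()` on a tensor.
    return [max(v, 0.0) for v in h]


x_dict = {'places': [1.0, 2.0], 'room': [-3.0]}
x_dict = apply_conv_keep_missing(fake_global_conv, x_dict, {}, act=list_relu)
print(sorted(x_dict))  # both 'places' and 'room' are still present
```

With real layers, the same helper could presumably be called inside the loop as `x_dict = apply_conv_keep_missing(conv, x_dict, edge_index_dict, act=torch.relu)`, so `places` keeps the features produced by the last local layer.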