Hi everyone, I really like torch_geometric. Matthias, you've done fantastic work!! I have some questions about my network. I am building and testing a network that works on surface meshes and performs regression tasks. It is similar to the one in the [FeaStNet article](https://arxiv.org/abs/1706.05206): a series of graph convolutions is chained to build a UNet-like structure. I wanted to add some global features, so I added a global max pooling step at the bottom of the UNet and concatenated its output to the node features (the `# concat global features` part of the code below). My question is: what do you think about it? Do you think I could add the global features in a different (and maybe smarter) way?

```python
import torch
from torch_geometric.nn import max_pool, global_max_pool


# forward() of the network's torch.nn.Module subclass
def forward(self, data):
    # Input features: vertex positions concatenated with the per-vertex `norm` attribute
    x, edge_index, norm = data.pos, data.edge_index, data.norm
    x = torch.cat((x, norm), dim=-1)

    residuals = []
    edge_indeces = []
    clusters = []
    # Down path: convolve, then (except at the coarsest level) store the skip
    # connection and coarsen the graph with graclus clusters + max pooling
    for i, down in enumerate(self.down_path):
        x = down(x, edge_index)
        if i != len(self.down_path) - 1:
            residuals.append(x)
            edge_indeces.append(edge_index)
            cluster = self.graclus(edge_index)
            data.x = x
            data = max_pool(cluster, data)
            x, edge_index = data.x, data.edge_index
            clusters.append(cluster)

    # concat global features
    global_x = global_max_pool(x, data.batch).view(-1)
    global_x = global_x.repeat(x.shape[0]).view(x.shape[0], global_x.shape[0])
    x = torch.cat((x, global_x), dim=1)

    # Up path: unpool by indexing with the stored clusters, concatenate the
    # skip connection, and convolve on the finer graph
    for i, up in enumerate(self.up_path):
        cluster, _ = self.consecutive_cluster(clusters[-i - 1])
        res = residuals[-i - 1]
        edge_index = edge_indeces[-i - 1]
        upsampled_x = x[cluster]
        x = torch.cat((res, upsampled_x), dim=-1)
        x = up(x, edge_index)
    return self.last(x)
```
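The upsampling step `upsampled_x = x[cluster]` reverses the pooling by gathering: each fine node receives a copy of the feature row of the coarse node it was merged into. A minimal pure-PyTorch sketch with made-up data, assuming (as in the code above) that `cluster` holds consecutive coarse-node ids like those returned by `consecutive_cluster`; the residual tensor `res` is a hypothetical placeholder for the stored skip connection:

```python
import torch

# Made-up example: 5 fine nodes were pooled into 3 coarse nodes.
# cluster[i] = id of the coarse node that fine node i was merged into
# (consecutive ids 0..2, as consecutive_cluster would return).
cluster = torch.tensor([0, 0, 1, 2, 2])

# One feature row per coarse node, e.g. the output of a layer on the coarse graph.
x_coarse = torch.tensor([[1.0, 10.0],
                         [2.0, 20.0],
                         [3.0, 30.0]])

# Unpooling by gathering: fine node i gets a copy of row cluster[i].
upsampled_x = x_coarse[cluster]

# Hypothetical skip connection from the matching down-path level (5 nodes, 4 features).
res = torch.zeros(5, 4)
merged = torch.cat((res, upsampled_x), dim=-1)

print(upsampled_x.shape)  # torch.Size([5, 2])
print(merged.shape)       # torch.Size([5, 6])
```

Indexing by a per-node id vector is the same trick that can broadcast per-graph (global) features back to the nodes of each graph.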
---
Not sure if the above code will work correctly for `batch_size > 1`, because the output of this line

```python
global_x = global_x.repeat(x.shape[0]).view(x.shape[0], global_x.shape[0])
```

will have a shape of `num_nodes x (batches * num_feats)`, so each node's features would be concatenated with the features from all graphs in the batch. So maybe check if you can use the below code instead.
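The same gather trick gives a per-graph broadcast of the global features. The sketch below is an illustration, not the code from the original reply: it uses plain PyTorch with a hypothetical helper `global_max_pool_dense` standing in for PyG's `global_max_pool`, plus toy data, to compare the shape produced by the `repeat`/`view` line with indexing the pooled rows by `batch`:

```python
import torch


def global_max_pool_dense(x: torch.Tensor, batch: torch.Tensor, num_graphs: int) -> torch.Tensor:
    """Hypothetical stand-in for PyG's global_max_pool: per-graph max over nodes.

    x: [num_nodes, num_feats]; batch: [num_nodes] graph id of every node.
    Returns a [num_graphs, num_feats] tensor.
    """
    return torch.stack([x[batch == g].max(dim=0).values for g in range(num_graphs)])


# Toy batch (made-up data): graph 0 has nodes 0-2, graph 1 has nodes 3-4, 3 features each.
torch.manual_seed(0)
x = torch.randn(5, 3)
batch = torch.tensor([0, 0, 0, 1, 1])
pooled = global_max_pool_dense(x, batch, num_graphs=2)  # [2, 3]

# Line from the question: flatten, then repeat the whole vector for every node,
# so each node gets the pooled features of ALL graphs in the batch.
flat = pooled.view(-1)  # [2 * 3]
buggy = flat.repeat(x.shape[0]).view(x.shape[0], flat.shape[0])

# Per-graph broadcast: row i of global_x is the pooled vector of node i's own graph.
global_x = pooled[batch]
out = torch.cat((x, global_x), dim=1)

print(buggy.shape)     # torch.Size([5, 6])
print(global_x.shape)  # torch.Size([5, 3])
```

With the gather version, the width of `global_x` is `num_feats` regardless of the batch size, so the first layer of the up path sees the same input size as with `batch_size=1`. In the original `forward`, the equivalent line would presumably be `global_x = global_max_pool(x, data.batch)[data.batch]`.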