Is there some way to speed up my code in pyg library #7337
-
Dear PyG community, my task requires computing node similarity between two graphs. Given node features, my implementation is a triple nested loop, and it is very slow (20 mins per epoch grew to 1 h 20 mins). Is there a way to speed up my code? Here is my code:

```python
def forward(self, graph1, graph2):
    g, mask = to_dense_batch(graph1.x, graph1.batch)
    mx_num = g.size(1)
    cross_atten = torch.zeros((8, mx_num, mx_num)).to(g.device)
    for idp in range(graph1.x.size(0)):
        for idq in range(graph2.x.size(0)):
            if graph1.batch[idp] != graph2.batch[idq]:
                continue
            b = graph1.batch[idp]
            idx1 = idp - graph1.ptr[graph1.batch[idp]]
            idx2 = idq - graph2.ptr[graph2.batch[idq]]
            cross_atten[b][idx1][idx2] = self.neiMat(idp, idq, graph1, graph2, self.beta)

def neiMat(self, e1, e2, graph1, graph2, beta):
    '''
    :param e1: node id in graph1
    :param e2: node id in graph2
    '''
    subset1, _, _, _ = k_hop_subgraph(e1, 1, graph1.edge_index)
    subset2, _, _, _ = k_hop_subgraph(e2, 1, graph2.edge_index)
    center1, neigh1 = subset1[0], subset1[1:]
    center2, neigh2 = subset2[0], subset2[1:]
    center_feat1, neigh_feat1 = graph1.x[center1], graph1.x[neigh1]  # (1, dim), (n, dim)
    center_feat2, neigh_feat2 = graph2.x[center2], graph2.x[neigh2]  # (1, dim), (m, dim)
    # aggregation weights
    agg_wei1 = torch.matmul(center_feat1, torch.matmul(self.weight_matrix, neigh_feat1.transpose(-1, -2)))
    agg_wei1 = torch.softmax(agg_wei1, dim=-1).unsqueeze(dim=0).transpose(-1, 0)  # (n, 1)
    agg_wei2 = torch.matmul(center_feat2, torch.matmul(self.weight_matrix, neigh_feat2.transpose(-1, -2)))
    agg_wei2 = torch.softmax(agg_wei2, dim=-1).unsqueeze(dim=0).transpose(-1, 0)  # (m, 1)
    # neighbor matching scores
    neigh_matching = torch.matmul(neigh_feat1, neigh_feat2.transpose(-1, -2))  # (n, m)
    match_p2q = torch.softmax(neigh_matching, dim=-1)
    match_q2p = torch.softmax(neigh_matching, dim=0)
    message_p = neigh_feat1 - torch.matmul(match_p2q, neigh_feat2)  # (n, dim)
    message_q = neigh_feat2 - torch.matmul(match_q2p.transpose(-1, -2), neigh_feat1)  # (m, dim)
    neigh_feat1 = torch.cat([neigh_feat1, message_p * beta], dim=-1)  # (n, 2*dim)
    neigh_feat2 = torch.cat([neigh_feat2, message_q * beta], dim=-1)  # (m, 2*dim)
    neigh_sum1 = (agg_wei1 * self.act(self.gate(neigh_feat1))).sum(dim=0)
    g1 = self.N(neigh_sum1)
    neigh_sum2 = (agg_wei2 * self.act(self.gate(neigh_feat2))).sum(dim=0)
    g2 = self.N(neigh_sum2)
    center1 = torch.cat([center_feat1, g1], dim=-1)
    center2 = torch.cat([center_feat2, g2], dim=-1)
    return torch.exp(-torch.linalg.norm(center1 - center2))  # exp(-dis)
```

Thanks in advance!
-
Thanks for the issue. This is a bit hard to say; I would start with some timing/benchmarking to see which call is the bottleneck. Obviously, the double for-loop is not ideal, so maybe there is a way to parallelize it or compute it in batches?
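As a starting point for that benchmarking, here is a minimal sketch (plain PyTorch, with made-up tensor sizes and a hypothetical `time_fn` helper) that times a Python double loop against the equivalent single batched matmul:

```python
import time
import torch

def time_fn(fn, *args, repeats=3):
    # Hypothetical helper: average wall-clock time of fn(*args) over `repeats` runs.
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # don't let queued GPU work leak into the timing
    start = time.perf_counter()
    for _ in range(repeats):
        out = fn(*args)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return (time.perf_counter() - start) / repeats, out

# Made-up sizes: 100 nodes per graph, 64-dim features.
x = torch.randn(100, 64)
y = torch.randn(100, 64)

def loop_sim(a, b):
    # Pairwise dot products via an explicit Python double loop (the slow pattern).
    out = torch.empty(a.size(0), b.size(0))
    for i in range(a.size(0)):
        for j in range(b.size(0)):
            out[i, j] = a[i] @ b[j]
    return out

t_loop, out_loop = time_fn(loop_sim, x, y)
t_mm, out_mm = time_fn(lambda a, b: a @ b.t(), x, y)  # same result, one batched matmul
print(f"loop: {t_loop:.4f}s, matmul: {t_mm:.6f}s")
```

Wrapping each candidate bottleneck (`neiMat`, the `k_hop_subgraph` calls, the matmuls) in a timer like this should quickly show where the time goes.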
The idea here would be to compute the similarity in batches/at once, e.g., `sim` would be a tensor of shape `[num_nodes_x, num_nodes_y, max_neighbors_in_x, max_neighbors_in_y]`. This would basically remove the need to do this computation sequentially. For creating this tensor, the `utils.to_dense_batch` function may be useful to bring the neighborhood features to shape `[num_nodes_x, max_neighbors_in_x, num_features]`.
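To illustrate the idea, here is a minimal sketch in plain PyTorch (the toy `neigh_feats` list and the hand-rolled densification are assumptions standing in for what `utils.to_dense_batch` would produce) that computes all pairwise neighbor-matching scores in a single `einsum`:

```python
import torch

# Hypothetical toy data: 3 "center" nodes with 2, 3, and 1 neighbor
# feature vectors of dimension 4, respectively.
dim = 4
neigh_feats = [torch.randn(2, dim), torch.randn(3, dim), torch.randn(1, dim)]

# Densify to (num_nodes, max_neighbors, dim) plus a boolean validity mask,
# mimicking the output of torch_geometric.utils.to_dense_batch.
max_n = max(f.size(0) for f in neigh_feats)
dense = torch.zeros(len(neigh_feats), max_n, dim)
mask = torch.zeros(len(neigh_feats), max_n, dtype=torch.bool)
for i, f in enumerate(neigh_feats):
    dense[i, : f.size(0)] = f
    mask[i, : f.size(0)] = True

# All pairwise neighbor-matching scores at once:
# sim[p, q, i, j] = <neighbor i of node p, neighbor j of node q>
sim = torch.einsum('pid,qjd->pqij', dense, dense)

# Mark padded neighbor slots so a later softmax assigns them zero weight.
pair_mask = mask[:, None, :, None] & mask[None, :, None, :]
sim = sim.masked_fill(~pair_mask, float('-inf'))

print(sim.shape)  # (num_nodes, num_nodes, max_neighbors, max_neighbors)
```

The masking step matters because `to_dense_batch` pads variable-sized neighborhoods with zeros; without it, the padded slots would receive nonzero softmax weight.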