📚 Describe the documentation issue
Hi PyG team,
Currently, the ResGatedGraphConv documentation shows the equations:
$$x'_i = W_1 x_i + \sum_{j \in \mathcal{N}(i)} \eta_{ij} \odot W_2 x_j$$
with
$$\eta_{ij} = \sigma(W_3 x_i + W_4 x_j)$$
This is accurate when no edge features are provided. However, in the implementation, when edge_attr is passed, it is concatenated to the node features before computing the key, query, and value projections:
if edge_attr is not None:
    k_i = self.lin_key(torch.cat([k_i, edge_attr], dim=-1))
    q_j = self.lin_query(torch.cat([q_j, edge_attr], dim=-1))
    v_j = self.lin_value(torch.cat([v_j, edge_attr], dim=-1))
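A linear layer applied to a concatenation $[x, e]$ is equivalent to the sum of separate linear maps on $x$ and on $e$. Thus, the actual behavior is closer to: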
$$x'_i = W_1 x_i + \sum_{j \in \mathcal{N}(i)} \eta_{ij} \odot (W_2 x_j + W_5 e_{ij})$$
$$\eta_{ij} = \sigma(W_3 x_i + W_4 x_j + W_6 e_{ij})$$
where $e_{ij}$ are the edge features.
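To illustrate why the concatenation shows up as an extra additive edge term, here is a minimal dependency-free sketch (not the PyG implementation). It checks numerically that one weight matrix applied to the concatenated vector equals the sum of its node block and edge block applied separately. The sizes, the random weights, and the helper names `matvec` and `split_columns` are hypothetical, and bias terms are ignored.

```python
import random


def matvec(W, x):
    """Multiply matrix W (a list of rows) by vector x."""
    return [sum(w * v for w, v in zip(row, x)) for row in W]


def split_columns(W, n_node):
    """Split W = [W_node | W_edge] into its node and edge column blocks."""
    return [row[:n_node] for row in W], [row[n_node:] for row in W]


random.seed(0)
n_node, n_edge, n_out = 4, 3, 2  # hypothetical feature sizes

x_j = [random.gauss(0, 1) for _ in range(n_node)]   # node features
e_ij = [random.gauss(0, 1) for _ in range(n_edge)]  # edge features

# One weight matrix acting on the concatenated input, mirroring
# lin_value(torch.cat([v_j, edge_attr], dim=-1)) with the bias omitted.
W = [[random.gauss(0, 1) for _ in range(n_node + n_edge)] for _ in range(n_out)]
concat_out = matvec(W, x_j + e_ij)  # list "+" concatenates the two vectors

# The same result written as W_2 x_j + W_5 e_ij.
W_2, W_5 = split_columns(W, n_node)
split_out = [a + b for a, b in zip(matvec(W_2, x_j), matvec(W_5, e_ij))]

# The two forms agree up to floating-point rounding.
max_diff = max(abs(a - b) for a, b in zip(concat_out, split_out))
assert max_diff < 1e-12
```

By the same reasoning, if the gate is computed as the sigmoid of the summed key and query projections, $W_6$ would correspond to the sum of the edge blocks of the key and query weight matrices.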
Suggest a potential alternative/fix
No response