[Doc] Update ResGatedGraphConv documentation to include edge feature handling #10578

@aditya0by0

📚 Describe the documentation issue

Hi PyG team,

Currently, the ResGatedGraphConv documentation shows the equations:

$$x'_i = W_1 x_i + \sum_{j \in \mathcal{N}(i)} \eta_{ij} \odot W_2 x_j$$

with

$$\eta_{ij} = \sigma(W_3 x_i + W_4 x_j)$$

This is accurate when no edge features are provided. However, in the implementation, when edge_attr is passed, it is concatenated to the node features before computing the key, query, and value projections:

if edge_attr is not None:
    k_i = self.lin_key(torch.cat([k_i, edge_attr], dim=-1))
    q_j = self.lin_query(torch.cat([q_j, edge_attr], dim=-1))
    v_j = self.lin_value(torch.cat([v_j, edge_attr], dim=-1))

Thus, the actual behavior is closer to:

$$x'_i = W_1 x_i + \sum_{j \in \mathcal{N}(i)} \eta_{ij} \odot (W_2 x_j + W_5 e_{ij})$$

$$\eta_{ij} = \sigma(W_3 x_i + W_4 x_j + W_6 e_{ij})$$

where $e_{ij}$ denotes the features of the edge from $j$ to $i$.
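
To illustrate why the concatenation in the snippet leads to the additive form above, here is a minimal sketch in plain Python (lists rather than torch tensors, so it runs without dependencies). It checks that one weight matrix applied to the concatenation of $x_j$ and $e_{ij}$ equals a node block plus an edge block. The dimensions, the `matvec` helper, and the random weights are hypothetical and not taken from PyG, and bias terms are omitted for brevity.

```python
import random

def matvec(W, x):
    """Multiply a matrix (list of rows) by a vector (list of floats)."""
    return [sum(w * v for w, v in zip(row, x)) for row in W]

random.seed(0)
node_dim, edge_dim, out_dim = 3, 2, 4  # illustrative sizes, not from PyG

# A single weight matrix acting on the concatenation [x_j, e_ij],
# mirroring lin_value(torch.cat([v_j, edge_attr], dim=-1)) without bias.
W = [[random.uniform(-1, 1) for _ in range(node_dim + edge_dim)]
     for _ in range(out_dim)]
x_j = [random.uniform(-1, 1) for _ in range(node_dim)]
e_ij = [random.uniform(-1, 1) for _ in range(edge_dim)]

# Path 1: concatenate first, then apply the single matrix.
concat_out = matvec(W, x_j + e_ij)

# Path 2: split W column-wise into a node block (W_2) and an edge block (W_5),
# apply each block separately, and add the results.
W_2 = [row[:node_dim] for row in W]
W_5 = [row[node_dim:] for row in W]
split_out = [a + b for a, b in zip(matvec(W_2, x_j), matvec(W_5, e_ij))]

# The two paths agree up to floating-point rounding.
max_diff = max(abs(a - b) for a, b in zip(concat_out, split_out))
print(max_diff < 1e-12)
```

The same argument applies to the key and query projections. Since both receive the edge features, the $W_6$ term in the gate can be read as the sum of the edge blocks of those two layers.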

Suggest a potential alternative/fix

No response
