|
1 | 1 | """
|
2 | 2 | propagate(f, g, aggr; xi, xj, e) -> m̄
|
3 | 3 |
|
4 |
| -Performs the message passing scheme on graph `g`. |
5 |
| -Returns the aggregated node features `m̄` computed |
| 4 | +Performs message passing on graph `g`. |
6 | 5 |
|
7 |
| -The computational steps are the following: |
| 6 | +Takes care of materializing the node features on each edge, |
| 7 | +applying the message function, and returning an aggregated message ``\bar{\mathbf{m}}`` |
| 8 | +(depending on the return value of `f`, an array or a named tuple of |
| 9 | +arrays with last dimension's size `g.num_nodes`). |
| 10 | +
|
| 11 | +It can be decomposed in two steps: |
8 | 12 |
|
9 | 13 | ```julia
|
10 | 14 | m = apply_edges(f, g, xi, xj, e)
|
11 | 15 | m̄ = aggregate_neighbors(g, aggr, m)
|
12 | 16 | ```
|
13 | 17 |
|
14 |
| -GNN layers typically call propagate in their forward pass. |
| 18 | +GNN layers typically call `propagate` in their forward pass, |
| 19 | +providing as input `f` a closure. |
15 | 20 |
|
16 | 21 | # Arguments
|
17 | 22 |
|
| 23 | +- `g`: A `GNNGraph`. |
| 24 | +- `xi`: An array or a named tuple containing arrays whose last dimension's size |
| 25 | + is `g.num_nodes`. It will be appropriately materialized on the |
| 26 | + target node of each edge (see also [`edge_index`](@ref)). |
| 27 | +- `xj`: As `xj`, but to be materialized on edges' sources. |
| 28 | +- `e`: An array or a named tuple containing arrays whose last dimension's size is `g.num_edges`. |
18 | 29 | - `f`: A generic function that will be passed over to [`apply_edges`](@ref).
|
19 |
| - Takes as inputs `xi`, `xj`, and `e` |
20 |
| - (target nodes' features, source nodes' features, and edge features |
21 |
| - respetively) and returns new edge features `m`. |
| 30 | + Has to take as inputs the edge-materialized `xi`, `xj`, and `e` |
| 31 | + (arrays or named tuples of arrays whose last dimension' size is the size of |
| 32 | + a batch of edges). Its output has to be an array or a named tuple of arrays |
| 33 | + with the same batch size. |
| 34 | +- `aggr`: Neighborhood aggregation operator. Use `+`, `mean`, `max`, or `min`. |
22 | 35 |
|
23 |
| -# Usage example |
| 36 | +# Usage Examples |
24 | 37 |
|
25 | 38 | ```julia
|
26 | 39 | using GraphNeuralNetworks, Flux
|
|
68 | 81 | """
|
69 | 82 | apply_edges(f, xi, xj, e)
|
70 | 83 |
|
71 |
| -Message function for the message-passing scheme |
72 |
| -started by [`propagate`](@ref). |
73 | 84 | Returns the message from node `j` to node `i` .
|
74 | 85 | In the message-passing scheme, the incoming messages
|
75 | 86 | from the neighborhood of `i` will later be aggregated
|
76 | 87 | in order to update the features of node `i`.
|
77 | 88 |
|
78 | 89 | The function operates on batches of edges, therefore
|
79 | 90 | `xi`, `xj`, and `e` are tensors whose last dimension
|
80 |
| -is the batch size, or can be tuple/namedtuples of |
81 |
| -such tensors, according to the input to propagate. |
82 |
| -
|
83 |
| -By default, the function returns `xj`. |
84 |
| -Custom layer should specialize this method with the desired behavior. |
85 |
| -
|
| 91 | +is the batch size, or can be named tuples of |
| 92 | +such tensors. |
| 93 | + |
86 | 94 | # Arguments
|
87 | 95 |
|
88 |
| -- `f`: A function that takes as inputs `xi`, `xj`, and `e` |
89 |
| - (target nodes' features, source nodes' features, and edge features |
90 |
| - respetively) and returns new edge features `m`. |
91 |
| -- `xi`: Features of the central node `i`. |
92 |
| -- `xj`: Features of the neighbor `j` of node `i`. |
93 |
| -- `eij`: Features of edge `(i,j)`. |
| 96 | +- `g`: A `GNNGraph`. |
| 97 | +- `xi`: An array or a named tuple containing arrays whose last dimension's size |
| 98 | + is `g.num_nodes`. It will be appropriately materialized on the |
| 99 | + target node of each edge (see also [`edge_index`](@ref)). |
| 100 | +- `xj`: As `xj`, but to be materialized on edges' sources. |
| 101 | +- `e`: An array or a named tuple containing arrays whose last dimension's size is `g.num_edges`. |
| 102 | +- `f`: A function that takes as inputs the edge-materialized `xi`, `xj`, and `e`. |
| 103 | + These are arrays (or named tuples of arrays) whose last dimension' size is the size of |
| 104 | + a batch of edges. The output of `f` has to be an array (or a named tuple of arrays) |
| 105 | + with the same batch size. |
94 | 106 |
|
95 | 107 | See also [`propagate`](@ref).
|
96 | 108 | """
|
|
0 commit comments