@@ -13,7 +13,7 @@ The computational steps are the following:
13
13
14
14
```julia
15
15
m = compute_batch_message(l, g, x, e) # calls `compute_message`
16
- m̄ = aggregate_neighbors(l , aggr, g , m)
16
+ m̄ = aggregate_neighbors(g , aggr, m)
17
17
x′ = update_node(l, m̄, x)
18
18
e′ = update_edge(l, m, e)
19
19
```
63
63
64
64
function propagate (l, g:: GNNGraph , aggr, x, e= nothing )
65
65
m = compute_batch_message (l, g, x, e)
66
- m̄ = aggregate_neighbors (l, g, aggr, m)
66
+ m̄ = aggregate_neighbors (g, aggr, m)
67
67
x′ = update_node (l, m̄, x)
68
68
e′ = update_edge (l, m, e)
69
69
return x′, e′
74
74
"""
75
75
compute_message(l, x_i, x_j, [e_ij])
76
76
77
- Message function for the message-passing scheme,
78
- returning the message from node `j` to node `i` .
77
+ Message function for the message-passing scheme
78
+ started by [`propagate`](@ref).
79
+ Returns the message from node `j` to node `i` .
79
80
In the message-passing scheme, the incoming messages
80
81
from the neighborhood of `i` will later be aggregated
81
82
in order to update (see [`update_node`](@ref)) the features of node `i`.
82
83
83
84
The function operates on batches of edges, therefore
84
85
`x_i`, `x_j`, and `e_ij` are tensors whose last dimension
85
- is the batch size.
86
+ is the batch size, or can be tuple/namedtuples of
87
+ such tensors, according to the input to propagate.
86
88
87
89
By default, the function returns `x_j`.
88
90
Custom layer should specialize this method with the desired behavior.
@@ -106,7 +108,7 @@ _gather(x::Tuple, i) = map(x -> _gather(x, i), x)
106
108
_gather (x:: AbstractArray , i) = NNlib. gather (x, i)
107
109
_gather (x:: Nothing , i) = nothing
108
110
109
- function compute_batch_message (l, g, x, e)
111
+ function compute_batch_message (l, g:: GNNGraph , x, e)
110
112
s, t = edge_index (g)
111
113
xi = _gather (x, t)
112
114
xj = _gather (x, s)
@@ -121,12 +123,12 @@ _scatter(aggr, e::Tuple, t) = map(e -> _scatter(aggr, e, t), e)
121
123
_scatter (aggr, e:: AbstractArray , t) = NNlib. scatter (aggr, e, t)
122
124
_scatter (aggr, e:: Nothing , t) = nothing
123
125
124
- function aggregate_neighbors (l, g , aggr, e)
126
+ function aggregate_neighbors (g :: GNNGraph , aggr, e)
125
127
s, t = edge_index (g)
126
128
_scatter (aggr, e, t)
127
129
end
128
130
129
- aggregate_neighbors (l, g , aggr:: Nothing , e) = nothing
131
+ aggregate_neighbors (g :: GNNGraph , aggr:: Nothing , e) = nothing
130
132
131
133
# # Step 3
132
134
@@ -137,6 +139,9 @@ Node update function for the GNN layer `l`,
137
139
returning a new set of node features `x′` based on old
138
140
features `x` and the aggregated message `m̄` from the neighborhood.
139
141
142
+ The input `m̄` is an array, a tuple or a named tuple,
143
+ reflecting the output of [`compute_message`](@ref).
144
+
140
145
By default, the function returns `m̄`.
141
146
Custom layers should specialize this method with the desired behavior.
142
147
@@ -155,7 +160,7 @@ function update_node end
155
160
Edge update function for the GNN layer `l`,
156
161
returning a new set of edge features `e′` based on old
157
162
features `e` and the newly computed messages `m`
158
- from the [`message `](@ref) function.
163
+ from the [`compute_message `](@ref) function.
159
164
160
165
By default, the function returns `e`.
161
166
Custom layers should specialize this method with the desired behavior.
0 commit comments