1
- # Adapted message passing from paper
2
- # "Relational inductive biases, deep learning, and graph networks"
3
-
4
1
"""
5
- propagate(l, g, aggr, [X, E ]) -> X ′, E ′
2
+ propagate(l, g, aggr, [x, e ]) -> x ′, e ′
6
3
propagate(l, g, aggr) -> g′
7
4
8
- Perform the sequence of operations implementing the message-passing scheme
9
- of gnn layer `l` on graph `g` .
10
- Updates the node, edge, and global features `X`, `E`, and `U` respectively.
5
+ Performs the message-passing for GNN layer `l` on graph `g` .
6
+ Returns updated node and edge features `x` and `e`.
7
+
8
+ In case no input and edge features are given as input,
9
+ extracts them from `g` and returns the same graph
10
+ with updated feautres.
11
11
12
- The computation involved is the following:
12
+ The computational steps are the following:
13
13
14
14
```julia
15
- M = compute_batch_message(l, g, X, E)
16
- M̄ = aggregate_neighbors(l, aggr, g, M )
17
- X ′ = update (l, X, M̄ )
18
- E ′ = update_edge(l, E, M )
15
+ m = compute_batch_message(l, g, x, e) # calls `message`
16
+ m̄ = aggregate_neighbors(l, aggr, g, m )
17
+ x ′ = update_node (l, m̄, x )
18
+ e ′ = update_edge(l, m, e )
19
19
```
20
20
21
- Custom layers typically define their own [`update `](@ref)
21
+ Custom layers typically define their own [`update_node `](@ref)
22
22
and [`message`](@ref) functions, then call
23
23
this method in the forward pass:
24
24
25
- ```julia
26
- function (l::MyLayer)(g, X)
27
- ... some prepocessing if needed ...
28
- propagate(l, g, +, X, E)
25
+ # Usage example
26
+
27
+ ```
28
+ using GraphNeuralNetworks, Flux
29
+
30
+ struct GNNConv <: GNNLayer
31
+ W
32
+ b
33
+ σ
34
+ end
35
+
36
+ Flux.@functor GNNConv
37
+
38
+ function GNNConv(ch::Pair{Int,Int}, σ=identity;
39
+ init=glorot_uniform, bias::Bool=true)
40
+ in, out = ch
41
+ W = init(out, in)
42
+ b = Flux.create_bias(W, bias, out)
43
+ GNNConv(W, b, σ, aggr)
44
+ end
45
+
46
+ message(l::GNNConv, x_i, x_j, e_ij) = l.W * x_j
47
+ update_node(l::GNNConv, m̄, x) = l.σ.(m̄ .+ l.bias)
48
+
49
+ function (l::GNNConv)(g::GNNGraph, x::AbstractMatrix)
50
+ x, _ = propagate(l, g, +, x)
51
+ return x
29
52
end
30
53
```
31
54
32
- See also [`message`](@ref) and [`update `](@ref).
55
+ See also [`message`](@ref) and [`update_node `](@ref).
33
56
"""
34
57
function propagate end
35
58
36
59
function propagate (l, g:: GNNGraph , aggr)
37
- X, E = propagate (l, g, aggr, node_features (g), edge_features (g))
38
-
39
- return GNNGraph (g, ndata= X, edata= E)
60
+ x, e = propagate (l, g, aggr, node_features (g), edge_features (g))
61
+ return GNNGraph (g, ndata= x, edata= e)
40
62
end
41
63
42
- function propagate (l, g:: GNNGraph , aggr, X, E = nothing )
43
- M = compute_batch_message (l, g, X, E )
44
- M̄ = aggregate_neighbors (l, g, aggr, M )
45
- X ′ = update (l, X, M̄ )
46
- E ′ = update_edge (l, E, M )
47
- return X ′, E′, U ′
64
+ function propagate (l, g:: GNNGraph , aggr, x, e = nothing )
65
+ m = compute_batch_message (l, g, x, e )
66
+ m̄ = aggregate_neighbors (l, g, aggr, m )
67
+ x ′ = update_node (l, m̄, x )
68
+ e ′ = update_edge (l, m, e )
69
+ return x ′, e ′
48
70
end
49
71
72
+ # # Step 1.
73
+
50
74
"""
51
75
message(l, x_i, x_j, [e_ij])
52
76
53
77
Message function for the message-passing scheme,
54
78
returning the message from node `j` to node `i` .
55
79
In the message-passing scheme, the incoming messages
56
80
from the neighborhood of `i` will later be aggregated
57
- in order to [`update`](@ref) the features of node `i`.
81
+ in order to update (see [`update_node`](@ref)) the features of node `i`.
82
+
83
+ The function operates on batches of edges, therefore
84
+ `x_i`, `x_j`, and `e_ij` are tensors whose last dimention
85
+ is the batch size.
58
86
59
87
By default, the function returns `x_j`.
60
88
Custom layer should specialize this method with the desired behavior.
@@ -66,63 +94,69 @@ Custom layer should specialize this method with the desired behavior.
66
94
- `x_j`: Features of the neighbor `j` of node `i`.
67
95
- `e_ij`: Features of edge `(i,j)`.
68
96
69
- See also [`update `](@ref) and [`propagate`](@ref).
97
+ See also [`update_node `](@ref) and [`propagate`](@ref).
70
98
"""
71
99
function message end
72
100
73
- """
74
- update(l, x, m̄)
75
-
76
- Update function for the message-passing scheme,
77
- returning a new set of node features `x′` based on old
78
- features `x` and the incoming message from the neighborhood
79
- aggregation `m̄`.
80
-
81
- By default, the function returns `m̄`.
82
- Custom layers should specialize this method with the desired behavior.
83
-
84
- # Arguments
85
-
86
- - `l`: A gnn layer.
87
- - `m̄`: Aggregated edge messages from the [`message`](@ref) function.
88
- - `x`: Node features to be updated.
89
- - `u`: Global features.
90
-
91
- See also [`message`](@ref) and [`propagate`](@ref).
92
- """
93
- function update end
101
+ @inline message (l, x_i, x_j, e_ij) = message (l, x_i, x_j)
102
+ @inline message (l, x_i, x_j) = x_j
94
103
95
104
_gather (x, i) = NNlib. gather (x, i)
96
105
_gather (x:: Nothing , i) = nothing
97
106
98
- # # Step 1.
99
-
100
- function compute_batch_message (l, g, X, E)
107
+ function compute_batch_message (l, g, x, e)
101
108
s, t = edge_index (g)
102
- Xi = _gather (X , t)
103
- Xj = _gather (X , s)
104
- M = message (l, Xi, Xj, E )
105
- return M
109
+ xi = _gather (x , t)
110
+ xj = _gather (x , s)
111
+ m = message (l, xi, xj, e )
112
+ return m
106
113
end
107
114
108
- @inline message (l, x_i, x_j, e_ij) = message (l, x_i, x_j)
109
- @inline message (l, x_i, x_j) = x_j
110
-
111
115
# # Step 2
112
116
113
- function aggregate_neighbors (l, g, aggr, E )
117
+ function aggregate_neighbors (l, g, aggr, e )
114
118
s, t = edge_index (g)
115
- NNlib. scatter (aggr, E , t)
119
+ NNlib. scatter (aggr, e , t)
116
120
end
117
121
118
- aggregate_neighbors (l, g, aggr:: Nothing , E ) = nothing
122
+ aggregate_neighbors (l, g, aggr:: Nothing , e ) = nothing
119
123
120
124
# # Step 3
121
125
122
- @inline update (l, x, m̄) = m̄
126
+ """
127
+ update_node(l, m̄, x)
128
+
129
+ Node update function for the GNN layer `l`,
130
+ returning a new set of node features `x′` based on old
131
+ features `x` and the aggregated message `m̄` from the neighborhood.
132
+
133
+ By default, the function returns `m̄`.
134
+ Custom layers should specialize this method with the desired behavior.
135
+
136
+ See also [`message`](@ref), [`update_edge`](@ref), and [`propagate`](@ref).
137
+ """
138
+ function update_node end
139
+
140
+ @inline update_node (l, m̄, x) = m̄
123
141
124
142
# # Step 4
125
143
126
- @inline update_edge (l, E, M) = E
144
+
145
+ """
146
+ update_edge(l, m, e)
147
+
148
+ Edge update function for the GNN layer `l`,
149
+ returning a new set of edge features `e′` based on old
150
+ features `e` and the newly computed messages `m`
151
+ from the [`message`](@ref) function.
152
+
153
+ By default, the function returns `e`.
154
+ Custom layers should specialize this method with the desired behavior.
155
+
156
+ See also [`message`](@ref), [`update_node`](@ref), and [`propagate`](@ref).
157
+ """
158
+ function update_edge end
159
+
160
+ @inline update_edge (l, m, e) = e
127
161
128
162
# ## end steps ###
0 commit comments