2
2
# "Relational inductive biases, deep learning, and graph networks"
3
3
4
4
"""
5
- propagate(mp, g, X, E, U, aggr)
5
+ propagate(mp, g, aggr, [X, E, U]) -> X′, E′, U′
6
+ propagate(mp, g, aggr) -> g′
6
7
7
8
Perform the sequence of operations implementing the message-passing scheme
8
- on graph `g` with convolution layer `mp` .
9
+ of gnn layer `mp` on graph `g` .
9
10
Updates the node, edge, and global features `X`, `E`, and `U` respectively.
10
11
11
12
The computation involved is the following:
@@ -14,7 +15,7 @@ The computation involved is the following:
14
15
M = compute_batch_message(mp, g, X, E, U)
15
16
M̄ = aggregate_neighbors(mp, aggr, g, M)
16
17
X′ = update(mp, X, M̄, U)
17
- E′ = update_edge(mp, M, E , U)
18
+ E′ = update_edge(mp, E, M , U)
18
19
U′ = update_global(mp, U, X′, E′)
19
20
```
20
21
@@ -25,7 +26,7 @@ this method in the forward pass:
25
26
```julia
26
27
function (l::MyLayer)(g, X)
27
28
... some prepocessing if needed ...
28
- propagate(l, g, X, E, U, + )
29
+ propagate(l, g, +, X, E, U)
29
30
end
30
31
```
31
32
@@ -34,28 +35,28 @@ See also [`message`](@ref) and [`update`](@ref).
34
35
function propagate end
35
36
36
37
function propagate (mp, g:: GNNGraph , aggr)
37
- X, E, U = propagate (mp, g,
38
- node_features (g), edge_features (g), global_features (g),
39
- aggr)
40
- GNNGraph (g, ndata= X, edata= E, gdata= U)
38
+ X, E, U = propagate (mp, g, aggr,
39
+ node_features (g), edge_features (g), global_features (g))
40
+
41
+ return GNNGraph (g, ndata= X, edata= E, gdata= U)
41
42
end
42
43
43
- function propagate (mp, g:: GNNGraph , E , X, U, aggr )
44
+ function propagate (mp, g:: GNNGraph , aggr , X, E = nothing , U = nothing )
44
45
# TODO consider g.graph_indicator in propagating U
45
- M = compute_batch_message (mp, g, E, X , U)
46
- E = update_edge (mp, M, E, U )
47
- M̄ = aggregate_neighbors (mp, aggr, g, M )
48
- X = update (mp, M̄, X , U)
49
- U = update_global (mp, E , X, U )
50
- return E, X , U
46
+ M = compute_batch_message (mp, g, X, E , U)
47
+ M̄ = aggregate_neighbors (mp, g, aggr, M )
48
+ X′ = update (mp, X, M̄, U )
49
+ E′ = update_edge (mp, E, M , U)
50
+ U′ = update_global (mp, U , X′, E′ )
51
+ return X′, E′ , U′
51
52
end
52
53
53
54
"""
54
55
message(mp, x_i, x_j, [e_ij, u])
55
56
56
57
Message function for the message-passing scheme,
57
58
returning the message from node `j` to node `i` .
58
- In the message-passing scheme. the incoming messages
59
+ In the message-passing scheme, the incoming messages
59
60
from the neighborhood of `i` will later be aggregated
60
61
in order to [`update`](@ref) the features of node `i`.
61
62
@@ -75,7 +76,7 @@ See also [`update`](@ref) and [`propagate`](@ref).
75
76
function message end
76
77
77
78
"""
78
- update(mp, m̄, x , [u])
79
+ update(mp, x, m̄ , [u])
79
80
80
81
Update function for the message-passing scheme,
81
82
returning a new set of node features `x′` based on old
@@ -96,47 +97,45 @@ See also [`message`](@ref) and [`propagate`](@ref).
96
97
"""
97
98
function update end
98
99
99
-
100
100
_gather (x, i) = NNlib. gather (x, i)
101
101
_gather (x:: Nothing , i) = nothing
102
102
103
103
# # Step 1.
104
104
105
- function compute_batch_message (mp, g, E, X, u )
105
+ function compute_batch_message (mp, g, X, E, U )
106
106
s, t = edge_index (g)
107
107
Xi = _gather (X, t)
108
108
Xj = _gather (X, s)
109
- M = message (mp, Xi, Xj, E, u )
109
+ M = message (mp, Xi, Xj, E, U )
110
110
return M
111
111
end
112
112
113
- # @inline message(mp, i, j, x_i, x_j, e_ij, u) = message(mp, x_i, x_j, e_ij, u) # TODO add in the future
114
113
@inline message (mp, x_i, x_j, e_ij, u) = message (mp, x_i, x_j, e_ij)
115
114
@inline message (mp, x_i, x_j, e_ij) = message (mp, x_i, x_j)
116
115
@inline message (mp, x_i, x_j) = x_j
117
116
118
- # # Step 2
119
-
120
- @inline update_edge (mp, M, E, u) = update_edge (mp, M, E)
121
- @inline update_edge (mp, M, E) = E
122
-
123
- # # Step 3
117
+ # # Step 2
124
118
125
- function aggregate_neighbors (mp, aggr, g , E)
119
+ function aggregate_neighbors (mp, g, aggr , E)
126
120
s, t = edge_index (g)
127
121
NNlib. scatter (aggr, E, t)
128
122
end
129
123
130
- aggregate_neighbors (mp, aggr:: Nothing , g, E) = nothing
124
+ aggregate_neighbors (mp, g, aggr:: Nothing , E) = nothing
125
+
126
+ # # Step 3
127
+
128
+ @inline update (mp, x, m̄, u) = update (mp, x, m̄)
129
+ @inline update (mp, x, m̄) = m̄
131
130
132
131
# # Step 4
133
132
134
- # @inline update(mp, i, m̄, x, u) = update(mp, m, x, u)
135
- @inline update (mp, m̄, x, u) = update (mp, m̄, x)
136
- @inline update (mp, m̄, x) = m̄
133
+ @inline update_edge (mp, E, M, U) = update_edge (mp, E, M)
134
+ @inline update_edge (mp, E, M) = E
137
135
138
136
# # Step 5
139
137
140
- @inline update_global (mp, E, X, u) = u
138
+ @inline update_global (mp, U, X, E) = update_global (mp, U, X)
139
+ @inline update_global (mp, U, X) = U
141
140
142
141
# ## end steps ###
0 commit comments