# "Relational inductive biases, deep learning, and graph networks"
"""
-     propagate(mp, g::GNNGraph, aggr)
-     propagate(mp, g::GNNGraph, E, X, u, aggr)
+     propagate(mp, g, aggr, [X, E, U]) -> X′, E′, U′
+     propagate(mp, g, aggr) -> g′

- Perform the sequence of operation implementing the message-passing scheme
- and updating node, edge, and global features `X`, `E`, and `u` respectively.
+ Perform the sequence of operations implementing the message-passing scheme
+ of the GNN layer `mp` on graph `g`.
+ Updates the node, edge, and global features `X`, `E`, and `U` respectively.

The computation involved is the following:

```julia
- M = compute_batch_message(mp, g, E, X, u)
- E = update_edge(mp, M, E, u)
+ M = compute_batch_message(mp, g, X, E, U)
- M̄ = aggregate_neighbors(mp, aggr, g, M)
+ M̄ = aggregate_neighbors(mp, g, aggr, M)
- X = update(mp, M̄, X, u)
- u = update_global(mp, E, X, u)
+ X′ = update(mp, X, M̄, U)
+ E′ = update_edge(mp, E, M, U)
+ U′ = update_global(mp, U, X′, E′)
```

Custom layers typically define their own [`update`](@ref)
- and [`message`](@ref) function, then call
+ and [`message`](@ref) functions, then call
this method in the forward pass:

```julia
function (l::MyLayer)(g, X)
    ... some preprocessing if needed ...
-     E = nothing
-     u = nothing
-     propagate(l, g, E, X, u, +)
+     propagate(l, g, +, X, E, U)
end
```
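
For concreteness, a minimal custom layer might look like the following sketch. The names `MyConv`, `W`, and the `tanh` nonlinearity are illustrative assumptions, not part of this file, and `update`/`propagate` are assumed to be in scope (e.g. inside the package module):

```julia
struct MyConv{T}
    W::T   # learnable weight matrix
end

# The default `message` already returns the neighbor features x_j,
# so only `update` needs a specialized method here.
update(l::MyConv, x, m̄) = tanh.(l.W * (x .+ m̄))

function (l::MyConv)(g, X)
    X′, _, _ = propagate(l, g, +, X)   # E and U default to nothing
    return X′
end
```
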
@@ -36,28 +35,28 @@ See also [`message`](@ref) and [`update`](@ref).
function propagate end

function propagate(mp, g::GNNGraph, aggr)
-     E, X, U = propagate(mp, g,
-         edge_features(g), node_features(g), global_features(g),
-         aggr)
-     GNNGraph(g, ndata=X, edata=E, gdata=U)
+     X, E, U = propagate(mp, g, aggr,
+         node_features(g), edge_features(g), global_features(g))
+
+     return GNNGraph(g, ndata=X, edata=E, gdata=U)
end

- function propagate(mp, g::GNNGraph, E, X, U, aggr)
+ function propagate(mp, g::GNNGraph, aggr, X, E=nothing, U=nothing)
    # TODO consider g.graph_indicator in propagating U
-     M = compute_batch_message(mp, g, E, X, U)
-     E = update_edge(mp, M, E, U)
-     M̄ = aggregate_neighbors(mp, aggr, g, M)
-     X = update(mp, M̄, X, U)
-     U = update_global(mp, E, X, U)
-     return E, X, U
+     M = compute_batch_message(mp, g, X, E, U)
+     M̄ = aggregate_neighbors(mp, g, aggr, M)
+     X′ = update(mp, X, M̄, U)
+     E′ = update_edge(mp, E, M, U)
+     U′ = update_global(mp, U, X′, E′)
+     return X′, E′, U′
end

"""
    message(mp, x_i, x_j, [e_ij, u])

Message function for the message-passing scheme,
returning the message from node `j` to node `i`.
- In the message-passing scheme. the incoming messages
+ In the message-passing scheme, the incoming messages
from the neighborhood of `i` will later be aggregated
in order to [`update`](@ref) the features of node `i`.
@@ -77,7 +76,7 @@ See also [`update`](@ref) and [`propagate`](@ref).
function message end
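
As an illustration only (not part of this file), a layer that re-weights each neighbor's features by a scalar edge feature could extend `message` as follows, assuming `message` is in scope and `e_ij` carries one weight per edge:

```julia
struct EdgeWeightedSum end   # hypothetical layer without parameters

# x_j: neighbor features gathered per edge (d × num_edges)
# e_ij: edge features, here assumed to be 1 × num_edges
message(l::EdgeWeightedSum, x_i, x_j, e_ij) = x_j .* e_ij
```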

"""
-     update(mp, m̄, x, [u])
+     update(mp, x, m̄, [u])

Update function for the message-passing scheme,
returning a new set of node features `x′` based on old
@@ -98,47 +97,45 @@ See also [`message`](@ref) and [`propagate`](@ref).
"""
function update end
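
Purely as an illustration (not part of this file), a residual-style update combining the old node features with the aggregated messages could be written as:

```julia
struct ResidualLayer end   # hypothetical parameter-free layer

# x: previous node features, m̄: aggregated incoming messages
update(l::ResidualLayer, x, m̄) = x .+ m̄
```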

-
_gather(x, i) = NNlib.gather(x, i)
_gather(x::Nothing, i) = nothing

## Step 1.

- function compute_batch_message(mp, g, E, X, u)
+ function compute_batch_message(mp, g, X, E, U)
    s, t = edge_index(g)
    Xi = _gather(X, t)
    Xj = _gather(X, s)
-     M = message(mp, Xi, Xj, E, u)
+     M = message(mp, Xi, Xj, E, U)
    return M
end

- # @inline message(mp, i, j, x_i, x_j, e_ij, u) = message(mp, x_i, x_j, e_ij, u) # TODO add in the future
@inline message(mp, x_i, x_j, e_ij, u) = message(mp, x_i, x_j, e_ij)
@inline message(mp, x_i, x_j, e_ij) = message(mp, x_i, x_j)
@inline message(mp, x_i, x_j) = x_j
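
A toy illustration of what the gather step in `compute_batch_message` produces (the graph and values are invented for the example):

```julia
using NNlib

# Graph with 3 nodes and edges 1→2, 1→3, 2→3 (sources s, targets t).
s = [1, 1, 2]
t = [2, 3, 3]
X = [10.0 20.0 30.0]    # 1 × 3 node-feature matrix

NNlib.gather(X, s)      # [10.0 10.0 20.0]: source-node features per edge (Xj)
NNlib.gather(X, t)      # [20.0 30.0 30.0]: target-node features per edge (Xi)
```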

- ## Step 2
-
- @inline update_edge(mp, M, E, u) = update_edge(mp, M, E)
- @inline update_edge(mp, M, E) = E
-
- ## Step 3
+ ## Step 2

- function aggregate_neighbors(mp, aggr, g, E)
+ function aggregate_neighbors(mp, g, aggr, E)
    s, t = edge_index(g)
    NNlib.scatter(aggr, E, t)
end

- aggregate_neighbors(mp, aggr::Nothing, g, E) = nothing
+ aggregate_neighbors(mp, g, aggr::Nothing, E) = nothing
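
Correspondingly, a toy illustration of the scatter-based aggregation (values invented for the example, continuing the graph above):

```julia
using NNlib

t = [2, 3, 3]            # target node of each edge
M = [1.0 2.0 3.0]        # 1 × 3 matrix of per-edge messages

NNlib.scatter(+, M, t)   # [0.0 1.0 5.0]: messages summed per target node
```
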
+
+ ## Step 3
+
+ @inline update(mp, x, m̄, u) = update(mp, x, m̄)
+ @inline update(mp, x, m̄) = m̄

## Step 4

- # @inline update(mp, i, m̄, x, u) = update(mp, m, x, u)
- @inline update(mp, m̄, x, u) = update(mp, m̄, x)
- @inline update(mp, m̄, x) = m̄
+ @inline update_edge(mp, E, M, U) = update_edge(mp, E, M)
+ @inline update_edge(mp, E, M) = E

## Step 5

- @inline update_global(mp, E, X, u) = u
+ @inline update_global(mp, U, X, E) = update_global(mp, U, X)
+ @inline update_global(mp, U, X) = U

### end steps ###