Commit c2c6cfe

remove MessagePassing type, add GNNLayer
1 parent 65de740 commit c2c6cfe

File tree

10 files changed: +55 additions, −100 deletions


src/GraphNeuralNetworks.jl

Lines changed: 7 additions & 14 deletions
@@ -27,7 +27,10 @@ export
     adjacency_matrix,

     # layers/msgpass
-    MessagePassing,
+
+    # layers/basic
+    GNNLayer,
+    GNNChain,

     # layers/conv
     GCNConv,
@@ -38,33 +41,23 @@ export
     EdgeConv,
     GINConv,

-    # layer/pool
+    # layers/pool
     GlobalPool,
     LocalPool,
     TopKPool,
-    topk_index,
+    topk_index

-    # models
-    GAE,
-    VGAE,
-    InnerProductDecoder,
-    VariationalEncoder,
-    summarize,
-    sample,

-    # layer/selector
-    bypass_graph


 include("gnngraph.jl")
 include("graph_conversions.jl")
 include("utils.jl")

 include("layers/msgpass.jl")
-
+include("layers/basic.jl")
 include("layers/conv.jl")
 include("layers/pool.jl")
-include("layers/misc.jl")


 end

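`GNNChain` is exported here even though its definition is not part of this diff (the new `layers/basic.jl` below only adds `GNNLayer`). Assuming it is intended as a `Flux.Chain`-style container over graph layers, a purely hypothetical usage sketch follows; the `in => out` constructor signatures are assumptions, not taken from this commit.

```julia
using Flux, GraphNeuralNetworks

# Hypothetical: chain two graph convolutions; `in => out` constructors assumed.
model = GNNChain(GCNConv(16 => 32, relu),
                 GCNConv(32 => 32, relu))
```
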
src/layers/basic.jl

Lines changed: 8 additions & 0 deletions
@@ -0,0 +1,8 @@
+"""
+    abstract type GNNLayer end
+
+An abstract type from which graph neural network layers are derived.
+
+See also [`GNNChain`](@ref).
+"""
+abstract type GNNLayer end

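The new file adds only the abstract supertype. Below is a minimal sketch of how a downstream layer might plug into it; the layer name, fields, and constructor are hypothetical and not part of the package.

```julia
using Flux, GraphNeuralNetworks

# Hypothetical node-wise layer subtyping the new abstract type.
struct NodeDense{W<:AbstractMatrix, B, F} <: GNNLayer
    weight::W
    bias::B
    σ::F
end

# Flux-style constructor from input/output feature dimensions.
NodeDense(in::Int, out::Int, σ=identity) =
    NodeDense(Flux.glorot_uniform(out, in), zeros(Float32, out), σ)

# Apply to a node-feature matrix of size (in, num_nodes).
(l::NodeDense)(X::AbstractMatrix) = l.σ.(l.weight * X .+ l.bias)
```
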
src/layers/conv.jl

Lines changed: 7 additions & 7 deletions
@@ -20,7 +20,7 @@ of size `(num_features, num_nodes)`.
 - `bias`: Add learnable bias.
 - `init`: Weights' initializer.
 """
-struct GCNConv{A<:AbstractMatrix, B, F} <: MessagePassing
+struct GCNConv{A<:AbstractMatrix, B, F} <: GNNLayer
     weight::A
     bias::B
     σ::F
@@ -97,7 +97,7 @@ with ``\hat{L}`` the [`scaled_laplacian`](@ref).
 - `bias`: Add learnable bias.
 - `init`: Weights' initializer.
 """
-struct ChebConv{A<:AbstractArray{<:Number,3}, B}
+struct ChebConv{A<:AbstractArray{<:Number,3}, B} <: GNNLayer
     weight::A
     bias::B
     k::Int
@@ -160,7 +160,7 @@ where the aggregation type is selected by `aggr`.
 - `bias`: Add learnable bias.
 - `init`: Weights' initializer.
 """
-struct GraphConv{A<:AbstractMatrix, B} <: MessagePassing
+struct GraphConv{A<:AbstractMatrix, B} <: GNNLayer
     weight1::A
     weight2::A
     bias::B
@@ -229,7 +229,7 @@ with ``z_i`` a normalization factor.
 - `concat`: Concatenate layer output or not. If not, layer output is averaged over the heads.
 - `negative_slope::Real`: Keyword argument, the parameter of LeakyReLU.
 """
-struct GATConv{T, A<:AbstractMatrix{T}, B} <: MessagePassing
+struct GATConv{T, A<:AbstractMatrix{T}, B} <: GNNLayer
     weight::A
     bias::B
     a::A
@@ -313,7 +313,7 @@ Implements the recursion
 - `aggr`: Aggregation operator for the incoming messages (e.g. `+`, `*`, `max`, `min`, and `mean`).
 - `init`: Weight initialization function.
 """
-struct GatedGraphConv{A<:AbstractArray{<:Number,3}, R} <: MessagePassing
+struct GatedGraphConv{A<:AbstractArray{<:Number,3}, R} <: GNNLayer
     weight::A
     gru::R
     out_ch::Int
@@ -376,7 +376,7 @@ where `f` typically denotes a learnable function, e.g. a linear layer or a multi
 - `f`: A (possibly learnable) function acting on edge features.
 - `aggr`: Aggregation operator for the incoming messages (e.g. `+`, `*`, `max`, `min`, and `mean`).
 """
-struct EdgeConv <: MessagePassing
+struct EdgeConv <: GNNLayer
     nn
     aggr
 end
@@ -420,7 +420,7 @@ where `f` typically denotes a learnable function, e.g. a linear layer or a multi
 - `f`: A (possibly learnable) function acting on node features.
 - `eps`: Weighting factor.
 """
-struct GINConv{R<:Real} <: MessagePassing
+struct GINConv{R<:Real} <: GNNLayer
     nn
     eps::R
 end

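Every convolutional layer above now shares `GNNLayer` as its supertype, so generic code can dispatch on it instead of on the removed `MessagePassing`. A short hedged check; the `in => out` constructors are assumed from the package's docstrings and are not shown in this diff.

```julia
using Flux, GraphNeuralNetworks

# Constructors assumed to take an `in => out` pair, as in the docstrings.
@assert GCNConv(4 => 8, relu) isa GNNLayer
@assert GraphConv(4 => 8)     isa GNNLayer
```
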
src/layers/misc.jl

Lines changed: 0 additions & 13 deletions
This file was deleted.

src/layers/msgpass.jl

Lines changed: 25 additions & 37 deletions
@@ -1,18 +1,9 @@
 # Adapted message passing from paper
 # "Relational inductive biases, deep learning, and graph networks"
-"""
-    MessagePassing
-
-The abstract type from which all message passing layers are derived.
-
-Related methods are [`propagate`](@ref), [`message`](@ref),
-[`update`](@ref), [`update_edge`](@ref), and [`update_global`](@ref).
-"""
-abstract type MessagePassing end

 """
-    propagate(mp::MessagePassing, g::GNNGraph, aggr)
-    propagate(mp::MessagePassing, g::GNNGraph, E, X, u, aggr)
+    propagate(mp, g::GNNGraph, aggr)
+    propagate(mp, g::GNNGraph, E, X, u, aggr)

 Perform the sequence of operation implementing the message-passing scheme
 and updating node, edge, and global features `X`, `E`, and `u` respectively.
@@ -27,9 +18,8 @@ X = update(mp, M̄, X, u)
 u = update_global(mp, E, X, u)
 ```

-Custom layers sub-typing [`MessagePassing`](@ref)
-typically call define their own [`update`](@ref)
-and [`message`](@ref) function, than call
+Custom layers typically define their own [`update`](@ref)
+and [`message`](@ref) function, then call
 this method in the forward pass:

 ```julia
@@ -45,14 +35,14 @@ See also [`message`](@ref) and [`update`](@ref).
 """
 function propagate end

-function propagate(mp::MessagePassing, g::GNNGraph, aggr)
+function propagate(mp, g::GNNGraph, aggr)
     E, X, u = propagate(mp, g,
                         edge_feature(g), node_feature(g), global_feature(g),
                         aggr)
     GNNGraph(g, nf=X, ef=E, gf=u)
 end

-function propagate(mp::MessagePassing, g::GNNGraph, E, X, u, aggr)
+function propagate(mp, g::GNNGraph, E, X, u, aggr)
     M = compute_batch_message(mp, g, E, X, u)
     E = update_edge(mp, M, E, u)
     M̄ = aggregate_neighbors(mp, aggr, g, M)
@@ -62,7 +52,7 @@ function propagate(mp::MessagePassing, g::GNNGraph, E, X, u, aggr)
 end

 """
-    message(mp::MessagePassing, x_i, x_j, [e_ij, u])
+    message(mp, x_i, x_j, [e_ij, u])

 Message function for the message-passing scheme,
 returning the message from node `j` to node `i` .
@@ -71,12 +61,11 @@ from the neighborhood of `i` will later be aggregated
 in order to [`update`](@ref) the features of node `i`.

 By default, the function returns `x_j`.
-Layers subtyping [`MessagePassing`](@ref) should
-specialize this method with custom behavior.
+Custom layer should specialize this method with the desired behavior.

 ## Arguments

-- `mp`: A [`MessagePassing`](@ref) layer.
+- `mp`: A gnn layer.
 - `x_i`: Features of the central node `i`.
 - `x_j`: Features of the neighbor `j` of node `i`.
 - `e_ij`: Features of edge (`i`, `j`).
@@ -87,20 +76,19 @@ See also [`update`](@ref) and [`propagate`](@ref).
 function message end

 """
-    update(mp::MessagePassing, m̄, x, [u])
+    update(mp, m̄, x, [u])

 Update function for the message-passing scheme,
 returning a new set of node features `x′` based on old
 features `x` and the incoming message from the neighborhood
 aggregation `m̄`.

 By default, the function returns `m̄`.
-Layers subtyping [`MessagePassing`](@ref) should
-specialize this method with custom behavior.
+Custom layers should specialize this method with the desired behavior.

 ## Arguments

-- `mp`: A [`MessagePassing`](@ref) layer.
+- `mp`: A gnn layer.
 - `m̄`: Aggregated edge messages from the [`message`](@ref) function.
 - `x`: Node features to be updated.
 - `u`: Global features.
@@ -115,41 +103,41 @@ _gather(x::Nothing, i) = nothing

 ## Step 1.

-function compute_batch_message(mp::MessagePassing, g, E, X, u)
+function compute_batch_message(mp, g, E, X, u)
     s, t = edge_index(g)
     Xi = _gather(X, t)
     Xj = _gather(X, s)
     M = message(mp, Xi, Xj, E, u)
     return M
 end

-# @inline message(mp::MessagePassing, i, j, x_i, x_j, e_ij, u) = message(mp, x_i, x_j, e_ij, u) # TODO add in the future
-@inline message(mp::MessagePassing, x_i, x_j, e_ij, u) = message(mp, x_i, x_j, e_ij)
-@inline message(mp::MessagePassing, x_i, x_j, e_ij) = message(mp, x_i, x_j)
-@inline message(mp::MessagePassing, x_i, x_j) = x_j
+# @inline message(mp, i, j, x_i, x_j, e_ij, u) = message(mp, x_i, x_j, e_ij, u) # TODO add in the future
+@inline message(mp, x_i, x_j, e_ij, u) = message(mp, x_i, x_j, e_ij)
+@inline message(mp, x_i, x_j, e_ij) = message(mp, x_i, x_j)
+@inline message(mp, x_i, x_j) = x_j

 ## Step 2

-@inline update_edge(mp::MessagePassing, M, E, u) = update_edge(mp, M, E)
-@inline update_edge(mp::MessagePassing, M, E) = E
+@inline update_edge(mp, M, E, u) = update_edge(mp, M, E)
+@inline update_edge(mp, M, E) = E

 ## Step 3

-function aggregate_neighbors(mp::MessagePassing, aggr, g, E)
+function aggregate_neighbors(mp, aggr, g, E)
     s, t = edge_index(g)
     NNlib.scatter(aggr, E, t)
 end

-aggregate_neighbors(mp::MessagePassing, aggr::Nothing, g, E) = nothing
+aggregate_neighbors(mp, aggr::Nothing, g, E) = nothing

 ## Step 4

-# @inline update(mp::MessagePassing, i, m̄, x, u) = update(mp, m, x, u)
-@inline update(mp::MessagePassing, m̄, x, u) = update(mp, m̄, x)
-@inline update(mp::MessagePassing, m̄, x) = m̄
+# @inline update(mp, i, m̄, x, u) = update(mp, m, x, u)
+@inline update(mp, m̄, x, u) = update(mp, m̄, x)
+@inline update(mp, m̄, x) = m̄

 ## Step 5

-@inline update_global(mp::MessagePassing, E, X, u) = u
+@inline update_global(mp, E, X, u) = u

 ### end steps ###

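With the `MessagePassing` supertype gone, `propagate` and its steps are duck-typed: any layer that defines `message`/`update` methods can call `propagate` in its forward pass, as the docstrings above describe. Below is a sketch under those assumptions; the `SumNeighbors` layer is hypothetical, and `propagate`/`message`/`update` are assumed reachable from the package module.

```julia
using GraphNeuralNetworks
import GraphNeuralNetworks: message, update, propagate

# Hypothetical layer: adds the summed neighbor features to each node.
struct SumNeighbors <: GNNLayer end

# Message from neighbor j to node i: forward the neighbor's features.
message(::SumNeighbors, x_i, x_j) = x_j

# Update: old node features plus the aggregated messages m̄.
update(::SumNeighbors, m̄, x) = x .+ m̄

# Forward pass: run the scheme with `+` aggregation; edge and global
# features are passed through as `nothing`.
function (l::SumNeighbors)(g::GNNGraph, X::AbstractMatrix)
    _, X′, _ = propagate(l, g, nothing, X, nothing, +)
    return X′
end
```
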
test/cuda/layers/msgpass.jl

Lines changed: 1 addition & 1 deletion
@@ -9,7 +9,7 @@ adj = [0 1 0 0 0 0
        0 1 0 1 0 1
        0 1 1 0 1 0]

-struct NewCudaLayer <: MessagePassing
+struct NewCudaLayer
     weight
 end
 NewCudaLayer(m, n) = NewCudaLayer(randn(T, m,n))

test/layers/basic.jl

Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
+@testset "basic" begin
+
+end

test/layers/misc.jl

Lines changed: 0 additions & 24 deletions
This file was deleted.

test/layers/msgpass.jl

Lines changed: 3 additions & 3 deletions
@@ -1,4 +1,4 @@
-@testset "MessagePassing" begin
+@testset "message passing" begin
     in_channel = 10
     out_channel = 5
     num_V = 6
@@ -12,7 +12,7 @@
            0 1 0 1 0 1
            0 1 1 0 1 0]

-    struct NewLayer{G} <: MessagePassing end
+    struct NewLayer{G} end

     X = rand(T, in_channel, num_V)
     E = rand(T, in_channel, num_E)
@@ -90,7 +90,7 @@
     @test size(global_feature(fg_)) == (in_channel,)
 end

-struct NewLayerW{G} <: MessagePassing
+struct NewLayerW{G}
     weight
 end

test/runtests.jl

Lines changed: 1 addition & 1 deletion
@@ -16,9 +16,9 @@ include("cuda/test_utils.jl")
 tests = [
     "gnngraph",
     "layers/msgpass",
+    "layers/basic",
     "layers/conv",
     "layers/pool",
-    "layers/misc",
 ]

 !CUDA.functional() && @warn("CUDA unavailable, not testing GPU support")
