Skip to content

Commit d57fb34

Browse files
start implementing GNNChain
1 parent c2c6cfe commit d57fb34

File tree

4 files changed

+67
-1
lines changed

4 files changed

+67
-1
lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
1111
KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77"
1212
LightGraphs = "093fc24a-ae57-5d10-9952-331d41423f4d"
1313
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
14+
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
1415
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
1516
NNlibCUDA = "a00861dc-f156-4864-bf3c-e6376f28a68d"
1617
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

src/GraphNeuralNetworks.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import KrylovKit
99
using CUDA
1010
using Flux
1111
using Flux: glorot_uniform, leakyrelu, GRUCell, @functor
12+
using MacroTools: @forward
1213
using NNlib, NNlibCUDA
1314
using ChainRulesCore
1415
import LightGraphs

src/layers/basic.jl

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,67 @@ An abstract type from which graph neural network layers are derived.
66
See also [`GNNChain`](@ref).
77
"""
88
abstract type GNNLayer end
9+
10+
"""
11+
GNNChain(layers...)
12+
GNNChain(name = layer, ...)
13+
14+
Collects multiple layers / functions to be called in sequence
15+
on a given input. Supports indexing and slicing, `m[2]` or `m[1:end-1]`,
16+
and if names are given, `m[:name] == m[1]` etc.
17+
18+
## Examples
19+
20+
```jldoctest
21+
julia> m = GNNChain(x -> x^2, x -> x+1);
22+
23+
julia> m(5) == 26
24+
true
25+
26+
julia> m = GNNChain(Dense(10, 5, tanh), Dense(5, 2));
27+
28+
julia> x = rand(10, 32);
29+
30+
julia> m(x) == m[2](m[1](x))
31+
true
32+
33+
julia> m2 = GNNChain(enc = GNNChain(Flux.flatten, Dense(10, 5, tanh)),
34+
dec = Dense(5, 2));
35+
36+
julia> m2(x) == (m2[:dec] ∘ m2[:enc])(x)
37+
true
38+
```
39+
"""
40+
struct GNNChain{T}
41+
layers::T
42+
43+
GNNChain(xs...) = new{typeof(xs)}(xs)
44+
45+
function GNNChain(; kw...)
46+
:layers in Base.keys(kw) && throw(ArgumentError("a GNNChain cannot have a named layer called `layers`"))
47+
isempty(kw) && return new{Tuple{}}(())
48+
new{typeof(values(kw))}(values(kw))
49+
end
50+
end
51+
52+
@forward GNNChain.layers Base.getindex, Base.length, Base.first, Base.last,
53+
Base.iterate, Base.lastindex, Base.keys
54+
55+
functor(::Type{<:GNNChain}, c) = c.layers, ls -> GNNChain(ls...)
56+
57+
applychain(::Tuple{}, x) = x
58+
applychain(fs::Tuple, x) = applychain(tail(fs), first(fs)(x))
59+
60+
(c::GNNChain)(x) = applychain(Tuple(c.layers), x)
61+
62+
Base.getindex(c::GNNChain, i::AbstractArray) = GNNChain(c.layers[i]...)
63+
Base.getindex(c::GNNChain{<:NamedTuple}, i::AbstractArray) =
64+
GNNChain(; NamedTuple{Base.keys(c)[i]}(Tuple(c.layers)[i])...)
65+
66+
function Base.show(io::IO, c::GNNChain)
67+
print(io, "GNNChain(")
68+
_show_layers(io, c.layers)
69+
print(io, ")")
70+
end
71+
_show_layers(io, layers::Tuple) = join(io, layers, ", ")
72+
_show_layers(io, layers::NamedTuple) = join(io, ["$k = $v" for (k, v) in pairs(layers)], ", ")

src/layers/msgpass.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ and [`message`](@ref) function, then call
2323
this method in the forward pass:
2424
2525
```julia
26-
function (l::GNNLayer)(g, X)
26+
function (l::MyLayer)(g, X)
2727
... some prepocessing if needed ...
2828
E = nothing
2929
u = nothing

0 commit comments

Comments
 (0)