Commit 3b42087

[GNNLux] GCNConv, ChebConv, GNNChain (#462)
* add GCNConv and ChebConv
* add GNNChain
1 parent 79515e9 commit 3b42087

8 files changed (+281 −64 lines)

GNNLux/src/GNNLux.jl

Lines changed: 10 additions & 3 deletions
@@ -1,15 +1,22 @@
 module GNNLux
 using ConcreteStructs: @concrete
 using NNlib: NNlib
-using LuxCore: LuxCore, AbstractExplicitLayer
-using Lux: glorot_uniform, zeros32
+using LuxCore: LuxCore, AbstractExplicitLayer, AbstractExplicitContainerLayer
+using Lux: Lux, glorot_uniform, zeros32
 using Reexport: @reexport
 using Random: AbstractRNG
 using GNNlib: GNNlib
 @reexport using GNNGraphs

+include("layers/basic.jl")
+export GNNLayer,
+       GNNContainerLayer,
+       GNNChain
+
 include("layers/conv.jl")
-export GraphConv
+export GCNConv,
+       ChebConv,
+       GraphConv

 end #module

GNNLux/src/layers/basic.jl

Lines changed: 61 additions & 0 deletions
@@ -0,0 +1,61 @@
+"""
+    abstract type GNNLayer <: AbstractExplicitLayer end
+
+An abstract type from which graph neural network layers are derived.
+It is derived from Lux's `AbstractExplicitLayer` type.
+
+See also [`GNNChain`](@ref GNNLux.GNNChain).
+"""
+abstract type GNNLayer <: AbstractExplicitLayer end
+
+abstract type GNNContainerLayer{T} <: AbstractExplicitContainerLayer{T} end
+
+@concrete struct GNNChain <: GNNContainerLayer{(:layers,)}
+    layers <: NamedTuple
+end
+
+GNNChain(xs...) = GNNChain(; (Symbol("layer_", i) => x for (i, x) in enumerate(xs))...)
+
+function GNNChain(; kw...)
+    :layers in Base.keys(kw) &&
+        throw(ArgumentError("a GNNChain cannot have a named layer called `layers`"))
+    nt = NamedTuple{keys(kw)}(values(kw))
+    nt = map(_wrapforchain, nt)
+    return GNNChain(nt)
+end
+
+_wrapforchain(l::AbstractExplicitLayer) = l
+_wrapforchain(l) = Lux.WrappedFunction(l)
+
+Base.keys(c::GNNChain) = Base.keys(getfield(c, :layers))
+Base.getindex(c::GNNChain, i::Int) = c.layers[i]
+Base.getindex(c::GNNChain, i::AbstractVector) = GNNChain(NamedTuple{keys(c)[i]}(Tuple(c.layers)[i]))
+
+function Base.getproperty(c::GNNChain, name::Symbol)
+    hasfield(typeof(c), name) && return getfield(c, name)
+    layers = getfield(c, :layers)
+    hasfield(typeof(layers), name) && return getfield(layers, name)
+    throw(ArgumentError("$(typeof(c)) has no field or layer $name"))
+end
+
+Base.length(c::GNNChain) = length(c.layers)
+Base.lastindex(c::GNNChain) = lastindex(c.layers)
+Base.firstindex(c::GNNChain) = firstindex(c.layers)
+
+LuxCore.outputsize(c::GNNChain) = LuxCore.outputsize(c.layers[end])
+
+(c::GNNChain)(g::GNNGraph, x, ps, st) = _applychain(c.layers, g, x, ps, st)
+
+function _applychain(layers, g::GNNGraph, x, ps, st)  # type-unstable path, helps compile times
+    newst = (;)
+    for (name, l) in pairs(layers)
+        x, s′ = _applylayer(l, g, x, getproperty(ps, name), getproperty(st, name))
+        newst = merge(newst, (; name => s′))
+    end
+    return x, newst
+end
+
+_applylayer(l, g::GNNGraph, x, ps, st) = l(x), (;)
+_applylayer(l::AbstractExplicitLayer, g::GNNGraph, x, ps, st) = l(x, ps, st)
+_applylayer(l::GNNLayer, g::GNNGraph, x, ps, st) = l(g, x, ps, st)
+_applylayer(l::GNNContainerLayer, g::GNNGraph, x, ps, st) = l(g, x, ps, st)
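
For context, `GNNChain` mirrors `Lux.Chain`: the constructor collects the layers into a `NamedTuple`, wraps plain functions in `Lux.WrappedFunction`, and the forward pass threads the graph together with Lux's explicit parameters and state through `_applylayer`. A minimal usage sketch (the layer sizes and the anonymous scaling function are illustrative, not part of this commit):

```julia
using GNNLux, Lux, Random

rng = Random.default_rng()
g = rand_graph(10, 40)             # GNNGraphs is reexported by GNNLux
x = randn(Float32, 3, 10)          # 3 features for each of the 10 nodes

# plain functions get wrapped in Lux.WrappedFunction by the constructor
model = GNNChain(GCNConv(3 => 8, relu),
                 x -> 2 .* x,
                 GCNConv(8 => 2))

ps = Lux.initialparameters(rng, model)
st = Lux.initialstates(rng, model)
y, st = model(g, x, ps, st)        # y has size (2, 10)
```

Note that `_applychain` loops over `pairs(layers)` on a deliberately type-unstable path; as the inline comment says, this trades some runtime dispatch for better compile times on long chains.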

GNNLux/src/layers/conv.jl

Lines changed: 115 additions & 37 deletions
@@ -1,54 +1,132 @@
+# Missing Layers
+
+# | Layer | Sparse Ops | Edge Weight | Edge Features | Heterograph | TemporalSnapshotsGNNGraphs |
+# | :-------- | :---: | :---: | :---: | :---: | :---: |
+# | [`AGNNConv`](@ref) | | | ✓ | | |
+# | [`CGConv`](@ref) | | | ✓ | ✓ | ✓ |
+# | [`EGNNConv`](@ref) | | | ✓ | | |
+# | [`EdgeConv`](@ref) | | | | ✓ | |
+# | [`GATConv`](@ref) | | | ✓ | ✓ | ✓ |
+# | [`GATv2Conv`](@ref) | | | ✓ | ✓ | ✓ |
+# | [`GatedGraphConv`](@ref) | ✓ | | | | ✓ |
+# | [`GINConv`](@ref) | ✓ | | | ✓ | ✓ |
+# | [`GMMConv`](@ref) | | | ✓ | | |
+# | [`MEGNetConv`](@ref) | | | ✓ | | |
+# | [`NNConv`](@ref) | | | ✓ | | |
+# | [`ResGatedGraphConv`](@ref) | | | | ✓ | ✓ |
+# | [`SAGEConv`](@ref) | ✓ | | | ✓ | ✓ |
+# | [`SGConv`](@ref) | ✓ | | | | ✓ |
+# | [`TransformerConv`](@ref) | | | ✓ | | |
+
+@concrete struct GCNConv <: GNNLayer
+    in_dims::Int
+    out_dims::Int
+    use_bias::Bool
+    add_self_loops::Bool
+    use_edge_weight::Bool
+    init_weight
+    init_bias
+    σ
+end

-@doc raw"""
-    GraphConv(in => out, σ=identity; aggr=+, bias=true, init=glorot_uniform)
+function GCNConv(ch::Pair{Int, Int}, σ = identity;
+                 init_weight = glorot_uniform,
+                 init_bias = zeros32,
+                 use_bias::Bool = true,
+                 add_self_loops::Bool = true,
+                 use_edge_weight::Bool = false,
+                 allow_fast_activation::Bool = true)
+    in_dims, out_dims = ch
+    σ = allow_fast_activation ? NNlib.fast_act(σ) : σ
+    return GCNConv(in_dims, out_dims, use_bias, add_self_loops, use_edge_weight, init_weight, init_bias, σ)
+end

-Graph convolution layer from Reference: [Weisfeiler and Leman Go Neural: Higher-order Graph Neural Networks](https://arxiv.org/abs/1810.02244).
+function LuxCore.initialparameters(rng::AbstractRNG, l::GCNConv)
+    weight = l.init_weight(rng, l.out_dims, l.in_dims)
+    if l.use_bias
+        bias = l.init_bias(rng, l.out_dims)
+        return (; weight, bias)
+    else
+        return (; weight)
+    end
+end

-Performs:
-```math
-\mathbf{x}_i' = W_1 \mathbf{x}_i + \square_{j \in \mathcal{N}(i)} W_2 \mathbf{x}_j
-```
+LuxCore.parameterlength(l::GCNConv) = l.use_bias ? l.in_dims * l.out_dims + l.out_dims : l.in_dims * l.out_dims
+LuxCore.statelength(d::GCNConv) = 0
+LuxCore.outputsize(d::GCNConv) = (d.out_dims,)

-where the aggregation type is selected by `aggr`.
+function Base.show(io::IO, l::GCNConv)
+    print(io, "GCNConv(", l.in_dims, " => ", l.out_dims)
+    l.σ == identity || print(io, ", ", l.σ)
+    l.use_bias || print(io, ", use_bias=false")
+    l.add_self_loops || print(io, ", add_self_loops=false")
+    !l.use_edge_weight || print(io, ", use_edge_weight=true")
+    print(io, ")")
+end

-# Arguments
+# TODO: norm_fn should be a keyword-only argument
+(l::GCNConv)(g, x, ps, st; conv_weight=nothing, edge_weight=nothing, norm_fn = d -> 1 ./ sqrt.(d)) =
+    l(g, x, edge_weight, norm_fn, ps, st; conv_weight)
+(l::GCNConv)(g, x, edge_weight, ps, st; conv_weight=nothing, norm_fn = d -> 1 ./ sqrt.(d)) =
+    l(g, x, edge_weight, norm_fn, ps, st; conv_weight)
+(l::GCNConv)(g, x, edge_weight, norm_fn, ps, st; conv_weight=nothing) =
+    GNNlib.gcn_conv(l, g, x, edge_weight, norm_fn, conv_weight, ps), st

-- `in`: The dimension of input features.
-- `out`: The dimension of output features.
-- `σ`: Activation function.
-- `aggr`: Aggregation operator for the incoming messages (e.g. `+`, `*`, `max`, `min`, and `mean`).
-- `bias`: Add learnable bias.
-- `init`: Weights' initializer.
+@concrete struct ChebConv <: GNNLayer
+    in_dims::Int
+    out_dims::Int
+    use_bias::Bool
+    k::Int
+    init_weight
+    init_bias
+    σ
+end

-# Examples
+function ChebConv(ch::Pair{Int, Int}, k::Int, σ = identity;
+                  init_weight = glorot_uniform,
+                  init_bias = zeros32,
+                  use_bias::Bool = true,
+                  allow_fast_activation::Bool = true)
+    in_dims, out_dims = ch
+    σ = allow_fast_activation ? NNlib.fast_act(σ) : σ
+    return ChebConv(in_dims, out_dims, use_bias, k, init_weight, init_bias, σ)
+end

-```julia
-# create data
-s = [1,1,2,3]
-t = [2,3,1,1]
-in_channel = 3
-out_channel = 5
-g = GNNGraph(s, t)
-x = randn(Float32, 3, g.num_nodes)
+function LuxCore.initialparameters(rng::AbstractRNG, l::ChebConv)
+    weight = l.init_weight(rng, l.out_dims, l.in_dims, l.k)
+    if l.use_bias
+        bias = l.init_bias(rng, l.out_dims)
+        return (; weight, bias)
+    else
+        return (; weight)
+    end
+end
+
+LuxCore.parameterlength(l::ChebConv) = l.use_bias ? l.in_dims * l.out_dims * l.k + l.out_dims :
+                                       l.in_dims * l.out_dims * l.k
+LuxCore.statelength(d::ChebConv) = 0
+LuxCore.outputsize(d::ChebConv) = (d.out_dims,)
+
+function Base.show(io::IO, l::ChebConv)
+    print(io, "ChebConv(", l.in_dims, " => ", l.out_dims, ", k=", l.k)
+    l.σ == identity || print(io, ", ", l.σ)
+    l.use_bias || print(io, ", use_bias=false")
+    print(io, ")")
+end

-# create layer
-l = GraphConv(in_channel => out_channel, relu, bias = false, aggr = mean)
+(l::ChebConv)(g, x, ps, st) = GNNlib.cheb_conv(l, g, x, ps), st

-# forward pass
-y = l(g, x)
-```
-"""
-@concrete struct GraphConv <: AbstractExplicitLayer
+@concrete struct GraphConv <: GNNLayer
     in_dims::Int
     out_dims::Int
     use_bias::Bool
-    init_weight::Function
-    init_bias::Function
+    init_weight
+    init_bias
     σ
     aggr
 end

-
 function GraphConv(ch::Pair{Int, Int}, σ = identity;
                    aggr = +,
                    init_weight = glorot_uniform,
@@ -65,10 +143,10 @@ function LuxCore.initialparameters(rng::AbstractRNG, l::GraphConv)
     weight2 = l.init_weight(rng, l.out_dims, l.in_dims)
     if l.use_bias
         bias = l.init_bias(rng, l.out_dims)
+        return (; weight1, weight2, bias)
     else
-        bias = false
+        return (; weight1, weight2)
     end
-    return (; weight1, weight2, bias)
 end

 function LuxCore.parameterlength(l::GraphConv)
@@ -90,4 +168,4 @@ function Base.show(io::IO, l::GraphConv)
     print(io, ")")
 end

-(l::GraphConv)(g::GNNGraph, x, ps, st) = GNNlib.graph_conv(l, g, x, ps), st
+(l::GraphConv)(g, x, ps, st) = GNNlib.graph_conv(l, g, x, ps), st
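
Both new layers follow the same explicit-parameter convention as `GraphConv`: construct the layer, initialize `ps` and `st`, then call `l(g, x, ps, st)`. `GCNConv` additionally accepts an edge-weight vector and a normalization function; the default `d -> 1 ./ sqrt.(d)` yields the usual symmetric normalization `D^{-1/2} A D^{-1/2}` of Kipf & Welling. A sketch of the call paths (sizes and the edge-weight vector are illustrative):

```julia
using GNNLux, Lux, Random

rng = Random.default_rng()
g = rand_graph(10, 40)
x = randn(Float32, 3, 10)

l = GCNConv(3 => 5, relu; use_edge_weight = true)
ps = Lux.initialparameters(rng, l)
st = Lux.initialstates(rng, l)

y, _ = l(g, x, ps, st)             # default path, edge_weight = nothing
ew = rand(Float32, 40)             # one weight per edge (illustrative)
y, _ = l(g, x, ew, ps, st)         # dispatches to the edge-weight method

c = ChebConv(3 => 5, 2)            # k = 2 Chebyshev basis terms
psc = Lux.initialparameters(rng, c)
stc = Lux.initialstates(rng, c)
yc, _ = c(g, x, psc, stc)          # size (5, 10)
```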

GNNLux/test/layers/basic_tests.jl

Lines changed: 24 additions & 0 deletions
@@ -0,0 +1,24 @@
+@testitem "layers/basic" setup=[SharedTestSetup] begin
+    rng = StableRNG(17)
+    g = rand_graph(10, 40, seed=17)
+    x = randn(rng, Float32, 3, 10)
+
+    @testset "GNNLayer" begin
+        @test GNNLayer <: LuxCore.AbstractExplicitLayer
+    end
+
+    @testset "GNNChain" begin
+        @test GNNChain <: LuxCore.AbstractExplicitContainerLayer{(:layers,)}
+        @test GNNChain <: GNNContainerLayer
+        c = GNNChain(GraphConv(3 => 5, relu), GCNConv(5 => 3))
+        ps = LuxCore.initialparameters(rng, c)
+        st = LuxCore.initialstates(rng, c)
+        @test LuxCore.parameterlength(c) == LuxCore.parameterlength(ps)
+        @test LuxCore.statelength(c) == LuxCore.statelength(st)
+        y, st′ = c(g, x, ps, st)
+        @test LuxCore.outputsize(c) == (3,)
+        @test size(y) == (3, 10)
+        loss = (x, ps) -> sum(first(c(g, x, ps, st)))
+        @eval @test_gradients $loss $x $ps atol=1.0f-3 rtol=1.0f-3 skip_tracker=true skip_reverse_diff=true
+    end
+end

GNNLux/test/layers/conv_tests.jl

Lines changed: 33 additions & 2 deletions
@@ -1,10 +1,41 @@
 @testitem "layers/conv" setup=[SharedTestSetup] begin
     rng = StableRNG(1234)
-    g = rand_graph(10, 30, seed=1234)
+    g = rand_graph(10, 40, seed=1234)
     x = randn(rng, Float32, 3, 10)

+    @testset "GCNConv" begin
+        l = GCNConv(3 => 5, relu)
+        @test l isa GNNLayer
+        ps = Lux.initialparameters(rng, l)
+        st = Lux.initialstates(rng, l)
+        @test Lux.parameterlength(l) == Lux.parameterlength(ps)
+        @test Lux.statelength(l) == Lux.statelength(st)
+
+        y, _ = l(g, x, ps, st)
+        @test Lux.outputsize(l) == (5,)
+        @test size(y) == (5, 10)
+        loss = (x, ps) -> sum(first(l(g, x, ps, st)))
+        @eval @test_gradients $loss $x $ps atol=1.0f-3 rtol=1.0f-3 skip_tracker=true
+    end
+
+    @testset "ChebConv" begin
+        l = ChebConv(3 => 5, 2, relu)
+        @test l isa GNNLayer
+        ps = Lux.initialparameters(rng, l)
+        st = Lux.initialstates(rng, l)
+        @test Lux.parameterlength(l) == Lux.parameterlength(ps)
+        @test Lux.statelength(l) == Lux.statelength(st)
+
+        y, _ = l(g, x, ps, st)
+        @test Lux.outputsize(l) == (5,)
+        @test size(y) == (5, 10)
+        loss = (x, ps) -> sum(first(l(g, x, ps, st)))
+        @eval @test_gradients $loss $x $ps atol=1.0f-3 rtol=1.0f-3 skip_tracker=true skip_reverse_diff=true
+    end
+
     @testset "GraphConv" begin
         l = GraphConv(3 => 5, relu)
+        @test l isa GNNLayer
         ps = Lux.initialparameters(rng, l)
         st = Lux.initialstates(rng, l)
         @test Lux.parameterlength(l) == Lux.parameterlength(ps)
@@ -14,6 +45,6 @@
         @test Lux.outputsize(l) == (5,)
         @test size(y) == (5, 10)
         loss = (x, ps) -> sum(first(l(g, x, ps, st)))
-        @eval @test_gradients $loss $x $ps atol=1.0f-3 rtol=1.0f-3
+        @eval @test_gradients $loss $x $ps atol=1.0f-3 rtol=1.0f-3 skip_tracker=true
     end
 end
