Skip to content

Commit 4b4477e

Browse files
[GNNLux] more layers (#469)
* layers
* fixes
1 parent fc67808 commit 4b4477e

File tree

5 files changed

+341
-33
lines changed

5 files changed

+341
-33
lines changed

GNNLux/src/GNNLux.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
module GNNLux
22
using ConcreteStructs: @concrete
3-
using NNlib: NNlib, sigmoid, relu
3+
using NNlib: NNlib, sigmoid, relu, swish
44
using LuxCore: LuxCore, AbstractExplicitLayer, AbstractExplicitContainerLayer
5-
using Lux: Lux, Dense, glorot_uniform, zeros32, StatefulLuxLayer
5+
using Lux: Lux, Chain, Dense, glorot_uniform, zeros32, StatefulLuxLayer
66
using Reexport: @reexport
77
using Random: AbstractRNG
88
using GNNlib: GNNlib
@@ -18,10 +18,10 @@ export AGNNConv,
1818
CGConv,
1919
ChebConv,
2020
EdgeConv,
21-
# EGNNConv,
22-
# DConv,
23-
# GATConv,
24-
# GATv2Conv,
21+
EGNNConv,
22+
DConv,
23+
GATConv,
24+
GATv2Conv,
2525
# GatedGraphConv,
2626
GCNConv,
2727
# GINConv,

GNNLux/src/layers/conv.jl

Lines changed: 261 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -255,3 +255,264 @@ function (l::EdgeConv)(g::AbstractGNNGraph, x, ps, st)
255255
end
256256

257257

258+
# E(n)-equivariant GNN convolution layer: a container of three MLPs that are
# wired together in the forward pass via `GNNlib.egnn_conv`.
@concrete struct EGNNConv <: GNNContainerLayer{(:ϕe, :ϕx, :ϕh)}
    ϕe             # edge/message network
    ϕx             # coordinate-update network (scalar output per edge)
    ϕh             # node-feature update network
    num_features   # named tuple (in = ..., edge = ..., out = ..., hidden = ...)
    residual::Bool # if true, output node features are added to the input
                   # (constructor requires in == out in that case)
end
265+
266+
# Convenience constructor for the no-edge-feature case: wraps `in => out`
# as `(in, 0) => out` and forwards to the full constructor.
function EGNNConv(ch::Pair{Int, Int}, hidden_size = 2 * ch[1]; residual = false)
    in_dims, out_dims = ch
    return EGNNConv((in_dims, 0) => out_dims; hidden_size, residual)
end
269+
270+
# Follows the reference implementation at
# https://github.com/vgsatorras/egnn/blob/main/models/egnn_clean/egnn_clean.py
#
# Builds the three sub-networks of an EGNN layer:
#   ϕe — message MLP, input width `2in + ein + 1`
#   ϕh — node-feature update MLP, input width `in + hidden`
#   ϕx — coordinate-update MLP producing a single scalar per edge
function EGNNConv(ch::Pair{NTuple{2, Int}, Int}; hidden_size::Int = 2 * ch[1][1],
                  residual = false)
    (in_size, edge_feat_size), out_size = ch

    # Validate up front, before building any sub-network, and use a proper
    # exception instead of `@assert` (asserts may be compiled away).
    if residual && in_size != out_size
        throw(ArgumentError("Residual connection only possible if in_size == out_size"))
    end

    act_fn = swish

    # +1 for the radial feature: ||x_i - x_j||^2
    ϕe = Chain(Dense(in_size * 2 + edge_feat_size + 1 => hidden_size, act_fn),
               Dense(hidden_size => hidden_size, act_fn))

    # Use `act_fn` consistently (previously `swish` was hard-coded here).
    ϕh = Chain(Dense(in_size + hidden_size => hidden_size, act_fn),
               Dense(hidden_size => out_size))

    ϕx = Chain(Dense(hidden_size => hidden_size, act_fn),
               Dense(hidden_size => 1, use_bias = false))

    num_features = (in = in_size, edge = edge_feat_size, out = out_size,
                    hidden = hidden_size)
    return EGNNConv(ϕe, ϕx, ϕh, num_features, residual)
end
293+
294+
# Node-feature output size of the layer.
function LuxCore.outputsize(l::EGNNConv)
    return (l.num_features.out,)
end
295+
296+
# Forward without edge features: delegate with `e = nothing`.
function (l::EGNNConv)(g, h, x, ps, st)
    return l(g, h, x, nothing, ps, st)
end
297+
298+
function (l::EGNNConv)(g, h, x, e, ps, st)
    # Bundle each sub-network with its parameters/state so GNNlib can invoke
    # it as a plain callable.
    nets = (; ϕe = StatefulLuxLayer{true}(l.ϕe, ps.ϕe, _getstate(st, :ϕe)),
              ϕx = StatefulLuxLayer{true}(l.ϕx, ps.ϕx, _getstate(st, :ϕx)),
              ϕh = StatefulLuxLayer{true}(l.ϕh, ps.ϕh, _getstate(st, :ϕh)))
    m = (; nets..., l.residual, l.num_features)
    return GNNlib.egnn_conv(m, g, h, x, e), st
end
305+
306+
function Base.show(io::IO, l::EGNNConv)
    nf = l.num_features
    print(io, "EGNNConv((", nf.in, ", ", nf.edge, ") => ", nf.out,
          "; hidden_size=", nf.hidden)
    l.residual && print(io, ", residual=true")
    print(io, ")")
end
317+
318+
# Diffusion convolution layer. Holds only hyper-parameters and initializers;
# the actual weights live in the Lux parameter tree (see `initialparameters`).
@concrete struct DConv <: GNNLayer
    in_dims::Int   # input feature dimension
    out_dims::Int  # output feature dimension
    k::Int         # length of the second weight axis — presumably the
                   # diffusion order; weights have shape (2, k, out, in)
    init_weight    # weight initializer, e.g. glorot_uniform
    init_bias      # bias initializer, e.g. zeros32
    use_bias::Bool # whether a bias vector is created
end
326+
327+
# Build a DConv from an `in => out` channel pair and diffusion order `k`.
function DConv(ch::Pair{Int, Int}, k::Int;
               init_weight = glorot_uniform,
               init_bias = zeros32,
               use_bias = true)
    in_dims, out_dims = ch
    return DConv(in_dims, out_dims, k, init_weight, init_bias, use_bias)
end
334+
335+
function LuxCore.initialparameters(rng::AbstractRNG, l::DConv)
    # Weight tensor of shape (2, k, out, in).
    weights = l.init_weight(rng, 2, l.k, l.out_dims, l.in_dims)
    l.use_bias || return (; weights)
    return (; weights, bias = l.init_bias(rng, l.out_dims))
end
344+
345+
LuxCore.outputsize(l::DConv) = (l.out_dims,)

# Total scalar parameter count: the (2, k, out, in) weight tensor plus,
# when enabled, an `out`-length bias.
function LuxCore.parameterlength(l::DConv)
    n = 2 * l.in_dims * l.out_dims * l.k
    return l.use_bias ? n + l.out_dims : n
end
348+
349+
function (l::DConv)(g, x, ps, st)
    # Pack everything GNNlib.d_conv needs into a named tuple.
    conv_args = (; ps.weights, bias = _getbias(ps), l.k)
    return GNNlib.d_conv(conv_args, g, x), st
end
353+
354+
function Base.show(io::IO, l::DConv)
    print(io, "DConv(", l.in_dims, " => ", l.out_dims, ", k=", l.k, ")")
end
357+
358+
# Graph attention (GAT) convolution layer.
@concrete struct GATConv <: GNNLayer
    dense_x        # node-feature projection (Dense, no bias)
    dense_e        # edge-feature projection (Dense, no bias) or `nothing`
    init_weight    # initializer for the attention vector `a`
    init_bias      # initializer for the optional output bias
    use_bias::Bool # whether an output bias is created in `initialparameters`
    σ              # output activation
    negative_slope # leaky-ReLU slope used in attention (stored as Float32)
    channel::Pair{NTuple{2, Int}, Int} # (in, ein) => out, with `out` per head
    heads::Int     # number of attention heads
    concat::Bool   # concatenate head outputs (out*heads) vs. reduce to `out`
    add_self_loops::Bool # add self loops before message passing — handled in GNNlib
    dropout        # attention dropout probability
end
372+
373+
374+
# Convenience constructor for the no-edge-feature case (ein = 0).
function GATConv(ch::Pair{Int, Int}, args...; kws...)
    return GATConv((ch[1], 0) => ch[2], args...; kws...)
end
375+
376+
# Main GATConv constructor. `ch` is `(in, ein) => out` with `out` the
# per-head output size; `ein = 0` means no edge features.
function GATConv(ch::Pair{NTuple{2, Int}, Int}, σ = identity;
                 heads::Int = 1, concat::Bool = true, negative_slope = 0.2,
                 init_weight = glorot_uniform, init_bias = zeros32,
                 use_bias::Bool = true,
                 add_self_loops = true, dropout = 0.0)
    (in, ein), out = ch
    # Proper input validation instead of `@assert` (asserts may be disabled).
    if add_self_loops && ein != 0
        throw(ArgumentError("Using edge features and setting add_self_loops=true at the same time is not yet supported."))
    end

    # Pass the user-supplied initializer to the Dense layers so a custom
    # `init_weight` is honored (previously it was silently ignored here,
    # unlike in GATv2Conv). Default is glorot_uniform either way.
    dense_x = Dense(in => out * heads; use_bias = false, init_weight)
    dense_e = ein > 0 ? Dense(ein => out * heads; use_bias = false, init_weight) :
              nothing
    negative_slope = convert(Float32, negative_slope)
    return GATConv(dense_x, dense_e, init_weight, init_bias, use_bias,
                   σ, negative_slope, ch, heads, concat, add_self_loops, dropout)
end
392+
393+
# Concatenated heads give `out * heads` features; otherwise just `out`.
function LuxCore.outputsize(l::GATConv)
    out = l.channel[2]
    return l.concat ? (out * l.heads,) : (out,)
end
##TODO: parameterlength
395+
396+
# Build the parameter tree: projection weights, attention vector `a`,
# optional edge projection and optional output bias.
function LuxCore.initialparameters(rng::AbstractRNG, l::GATConv)
    (in, ein), out = l.channel
    dense_x = LuxCore.initialparameters(rng, l.dense_x)
    # Attention vector: 3out rows when edge features are used, 2out otherwise.
    # Pass `rng` explicitly — previously it was omitted, so `a` was drawn from
    # the global RNG, breaking reproducibility of the supplied `rng`.
    a = l.init_weight(rng, ein > 0 ? 3out : 2out, l.heads)
    ps = (; dense_x, a)
    if ein > 0
        ps = (ps..., dense_e = LuxCore.initialparameters(rng, l.dense_e))
    end
    if l.use_bias
        ps = (ps..., bias = l.init_bias(rng, l.concat ? out * l.heads : out))
    end
    return ps
end
409+
410+
# Forward without edge features: delegate with `e = nothing`.
function (l::GATConv)(g, x, ps, st)
    return l(g, x, nothing, ps, st)
end
411+
412+
function (l::GATConv)(g, x, e, ps, st)
    # Wrap the projections with their parameters/state; the edge projection
    # only exists when the layer was built with edge features.
    dense_x = StatefulLuxLayer{true}(l.dense_x, ps.dense_x, _getstate(st, :dense_x))
    dense_e = if l.dense_e === nothing
        nothing
    else
        StatefulLuxLayer{true}(l.dense_e, ps.dense_e, _getstate(st, :dense_e))
    end

    m = (; l.add_self_loops, l.channel, l.heads, l.concat, l.dropout, l.σ,
         ps.a, bias = _getbias(ps), dense_x, dense_e, l.negative_slope)
    return GNNlib.gat_conv(m, g, x, e), st
end
421+
422+
function Base.show(io::IO, l::GATConv)
    (in, ein), out = l.channel
    # `channel` stores the user-supplied `(in, ein) => out`, where `out` is
    # already the per-head size (see the constructor and `outputsize`), so it
    # is printed as-is; the previous `out ÷ l.heads` misreported the layer for
    # heads > 1.
    print(io, "GATConv(", ein == 0 ? in : (in, ein), " => ", out)
    l.σ == identity || print(io, ", ", l.σ)
    print(io, ", negative_slope=", l.negative_slope)
    print(io, ")")
end
429+
430+
# GATv2 attention convolution layer (dynamic attention variant of GAT).
@concrete struct GATv2Conv <: GNNLayer
    dense_i        # projection of the target node features (carries the bias)
    dense_j        # projection of the source node features (no bias)
    dense_e        # edge-feature projection (no bias) or `nothing`
    init_weight    # initializer for the attention vector `a`
    init_bias      # initializer for the optional output bias
    use_bias::Bool # whether an output bias is created in `initialparameters`
    σ              # output activation
    negative_slope # leaky-ReLU slope used in attention
    channel::Pair{NTuple{2, Int}, Int} # (in, ein) => out, with `out` per head
    heads::Int     # number of attention heads
    concat::Bool   # concatenate head outputs (out*heads) vs. reduce to `out`
    add_self_loops::Bool # add self loops before message passing — handled in GNNlib
    dropout        # attention dropout probability
end
445+
446+
# Convenience constructor for the no-edge-feature case (ein = 0).
GATv2Conv(ch::Pair{Int, Int}, args...; kws...) = GATv2Conv((ch[1], 0) => ch[2], args...; kws...)
449+
450+
# Main GATv2Conv constructor. `ch` is `(in, ein) => out` with `out` the
# per-head output size; `ein = 0` means no edge features.
function GATv2Conv(ch::Pair{NTuple{2, Int}, Int},
                   σ = identity;
                   heads::Int = 1,
                   concat::Bool = true,
                   negative_slope = 0.2,
                   init_weight = glorot_uniform,
                   init_bias = zeros32,
                   use_bias::Bool = true,
                   add_self_loops = true,
                   dropout = 0.0)
    (in, ein), out = ch

    # Proper input validation instead of `@assert` (asserts may be disabled).
    if add_self_loops && ein != 0
        throw(ArgumentError("Using edge features and setting add_self_loops=true at the same time is not yet supported."))
    end

    # dense_i carries the (optional) Dense bias; dense_j and dense_e are
    # bias-free, matching the GATv2 formulation.
    dense_i = Dense(in => out * heads; use_bias, init_weight, init_bias)
    dense_j = Dense(in => out * heads; use_bias = false, init_weight)
    dense_e = ein > 0 ? Dense(ein => out * heads; use_bias = false, init_weight) :
              nothing
    return GATv2Conv(dense_i, dense_j, dense_e,
                     init_weight, init_bias, use_bias,
                     σ, negative_slope,
                     ch, heads, concat, add_self_loops, dropout)
end
479+
480+
481+
# Concatenated heads give `out * heads` features; otherwise just `out`.
function LuxCore.outputsize(l::GATv2Conv)
    out = l.channel[2]
    return l.concat ? (out * l.heads,) : (out,)
end
##TODO: parameterlength
483+
484+
# Build the parameter tree: both node projections, attention vector `a`,
# optional edge projection and optional output bias.
function LuxCore.initialparameters(rng::AbstractRNG, l::GATv2Conv)
    (in, ein), out = l.channel
    dense_i = LuxCore.initialparameters(rng, l.dense_i)
    dense_j = LuxCore.initialparameters(rng, l.dense_j)
    # Pass `rng` explicitly — previously it was omitted, so `a` was drawn from
    # the global RNG, breaking reproducibility of the supplied `rng`.
    a = l.init_weight(rng, out, l.heads)
    ps = (; dense_i, dense_j, a)
    if ein > 0
        ps = (ps..., dense_e = LuxCore.initialparameters(rng, l.dense_e))
    end
    if l.use_bias
        ps = (ps..., bias = l.init_bias(rng, l.concat ? out * l.heads : out))
    end
    return ps
end
498+
499+
# Forward without edge features: delegate with `e = nothing`.
function (l::GATv2Conv)(g, x, ps, st)
    return l(g, x, nothing, ps, st)
end
500+
501+
function (l::GATv2Conv)(g, x, e, ps, st)
    # Wrap the projections with their parameters/state; the edge projection
    # only exists when the layer was built with edge features.
    dense_i = StatefulLuxLayer{true}(l.dense_i, ps.dense_i, _getstate(st, :dense_i))
    dense_j = StatefulLuxLayer{true}(l.dense_j, ps.dense_j, _getstate(st, :dense_j))
    dense_e = if l.dense_e === nothing
        nothing
    else
        StatefulLuxLayer{true}(l.dense_e, ps.dense_e, _getstate(st, :dense_e))
    end

    m = (; l.add_self_loops, l.channel, l.heads, l.concat, l.dropout, l.σ,
         ps.a, bias = _getbias(ps), dense_i, dense_j, dense_e, l.negative_slope)
    return GNNlib.gatv2_conv(m, g, x, e), st
end
511+
512+
function Base.show(io::IO, l::GATv2Conv)
    (in, ein), out = l.channel
    # `channel` stores the user-supplied `(in, ein) => out`, where `out` is
    # already the per-head size (see the constructor and `outputsize`), so it
    # is printed as-is; the previous `out ÷ l.heads` misreported the layer for
    # heads > 1.
    print(io, "GATv2Conv(", ein == 0 ? in : (in, ein), " => ", out)
    l.σ == identity || print(io, ", ", l.σ)
    print(io, ", negative_slope=", l.negative_slope)
    print(io, ")")
end

GNNLux/test/layers/conv_tests.jl

Lines changed: 57 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,36 +1,81 @@
11
@testitem "layers/conv" setup=[SharedTestSetup] begin
22
rng = StableRNG(1234)
33
g = rand_graph(10, 40, seed=1234)
4-
x = randn(rng, Float32, 3, 10)
4+
in_dims = 3
5+
out_dims = 5
6+
x = randn(rng, Float32, in_dims, 10)
57

68
@testset "GCNConv" begin
7-
l = GCNConv(3 => 5, relu)
8-
test_lux_layer(rng, l, g, x, outputsize=(5,))
9+
l = GCNConv(in_dims => out_dims, relu)
10+
test_lux_layer(rng, l, g, x, outputsize=(out_dims,))
911
end
1012

1113
@testset "ChebConv" begin
12-
l = ChebConv(3 => 5, 2)
13-
test_lux_layer(rng, l, g, x, outputsize=(5,))
14+
l = ChebConv(in_dims => out_dims, 2)
15+
test_lux_layer(rng, l, g, x, outputsize=(out_dims,))
1416
end
1517

1618
@testset "GraphConv" begin
17-
l = GraphConv(3 => 5, relu)
18-
test_lux_layer(rng, l, g, x, outputsize=(5,))
19+
l = GraphConv(in_dims => out_dims, relu)
20+
test_lux_layer(rng, l, g, x, outputsize=(out_dims,))
1921
end
2022

2123
@testset "AGNNConv" begin
2224
l = AGNNConv(init_beta=1.0f0)
23-
test_lux_layer(rng, l, g, x, sizey=(3,10))
25+
test_lux_layer(rng, l, g, x, sizey=(in_dims, 10))
2426
end
2527

2628
@testset "EdgeConv" begin
27-
nn = Chain(Dense(6 => 5, relu), Dense(5 => 5))
29+
nn = Chain(Dense(2*in_dims => 5, relu), Dense(5 => out_dims))
2830
l = EdgeConv(nn, aggr = +)
29-
test_lux_layer(rng, l, g, x, sizey=(5,10), container=true)
31+
test_lux_layer(rng, l, g, x, sizey=(out_dims,10), container=true)
3032
end
3133

3234
@testset "CGConv" begin
33-
l = CGConv(3 => 3, residual = true)
34-
test_lux_layer(rng, l, g, x, outputsize=(3,), container=true)
35+
l = CGConv(in_dims => in_dims, residual = true)
36+
test_lux_layer(rng, l, g, x, outputsize=(in_dims,), container=true)
37+
end
38+
39+
@testset "DConv" begin
40+
l = DConv(in_dims => out_dims, 2)
41+
test_lux_layer(rng, l, g, x, outputsize=(5,))
42+
end
43+
44+
@testset "EGNNConv" begin
45+
hin = 6
46+
hout = 7
47+
hidden = 8
48+
l = EGNNConv(hin => hout, hidden)
49+
ps = LuxCore.initialparameters(rng, l)
50+
st = LuxCore.initialstates(rng, l)
51+
h = randn(rng, Float32, hin, g.num_nodes)
52+
(hnew, xnew), stnew = l(g, h, x, ps, st)
53+
@test size(hnew) == (hout, g.num_nodes)
54+
@test size(xnew) == (in_dims, g.num_nodes)
55+
end
56+
57+
@testset "GATConv" begin
58+
x = randn(rng, Float32, 6, 10)
59+
60+
l = GATConv(6 => 8, heads=2)
61+
test_lux_layer(rng, l, g, x, outputsize=(16,))
62+
63+
l = GATConv(6 => 8, heads=2, concat=false, dropout=0.5)
64+
test_lux_layer(rng, l, g, x, outputsize=(8,))
65+
66+
#TODO test edge
67+
end
68+
69+
@testset "GATv2Conv" begin
70+
x = randn(rng, Float32, 6, 10)
71+
72+
l = GATv2Conv(6 => 8, heads=2)
73+
test_lux_layer(rng, l, g, x, outputsize=(16,))
74+
75+
l = GATv2Conv(6 => 8, heads=2, concat=false, dropout=0.5)
76+
test_lux_layer(rng, l, g, x, outputsize=(8,))
77+
78+
#TODO test edge
3579
end
3680
end
81+

0 commit comments

Comments
 (0)