Skip to content

Commit 83b6b7e

Browse files
[GNNLux] more layers pt. 3 (#471)
* more layer more layers stuff * fixes
1 parent 0a23ffa commit 83b6b7e

File tree

6 files changed

+108
-39
lines changed

6 files changed

+108
-39
lines changed

GNNLux/src/GNNLux.jl

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
module GNNLux
22
using ConcreteStructs: @concrete
33
using NNlib: NNlib, sigmoid, relu, swish
4-
using LuxCore: LuxCore, AbstractExplicitLayer, AbstractExplicitContainerLayer
5-
using Lux: Lux, Chain, Dense, glorot_uniform, zeros32, StatefulLuxLayer
4+
using LuxCore: LuxCore, AbstractExplicitLayer, AbstractExplicitContainerLayer, parameterlength, statelength, outputsize,
5+
initialparameters, initialstates
6+
using Lux: Lux, Chain, Dense, GRUCell,
7+
glorot_uniform, zeros32,
8+
StatefulLuxLayer
69
using Reexport: @reexport
710
using Random: AbstractRNG
811
using GNNlib: GNNlib
@@ -22,9 +25,9 @@ export AGNNConv,
2225
DConv,
2326
GATConv,
2427
GATv2Conv,
25-
# GatedGraphConv,
28+
GatedGraphConv,
2629
GCNConv,
27-
# GINConv,
30+
GINConv,
2831
# GMMConv,
2932
GraphConv,
3033
# MEGNetConv,

GNNLux/src/layers/conv.jl

Lines changed: 64 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,6 @@ function LuxCore.initialparameters(rng::AbstractRNG, l::GCNConv)
3838
end
3939

4040
LuxCore.parameterlength(l::GCNConv) = l.use_bias ? l.in_dims * l.out_dims + l.out_dims : l.in_dims * l.out_dims
41-
LuxCore.statelength(d::GCNConv) = 0
4241
LuxCore.outputsize(d::GCNConv) = (d.out_dims,)
4342

4443
function Base.show(io::IO, l::GCNConv)
@@ -549,7 +548,6 @@ function LuxCore.initialparameters(rng::AbstractRNG, l::SGConv)
549548
end
550549

551550
LuxCore.parameterlength(l::SGConv) = l.use_bias ? l.in_dims * l.out_dims + l.out_dims : l.in_dims * l.out_dims
552-
LuxCore.statelength(d::SGConv) = 0
553551
LuxCore.outputsize(d::SGConv) = (d.out_dims,)
554552

555553
function Base.show(io::IO, l::SGConv)
@@ -561,14 +559,72 @@ function Base.show(io::IO, l::SGConv)
561559
print(io, ")")
562560
end
563561

564-
(l::SGConv)(g, x, ps, st; conv_weight=nothing, edge_weight=nothing) =
565-
l(g, x, edge_weight, ps, st; conv_weight)
566-
567-
function (l::SGConv)(g, x, edge_weight, ps, st;
568-
conv_weight=nothing, )
562+
(l::SGConv)(g, x, ps, st) = l(g, x, nothing, ps, st)
569563

564+
function (l::SGConv)(g, x, edge_weight, ps, st)
570565
m = (; ps.weight, bias = _getbias(ps),
571566
l.add_self_loops, l.use_edge_weight, l.k)
572567
y = GNNlib.sg_conv(m, g, x, edge_weight)
573568
return y, st
574-
end
569+
end
570+
571+
@concrete struct GatedGraphConv <: GNNLayer
572+
gru
573+
init_weight
574+
dims::Int
575+
num_layers::Int
576+
aggr
577+
end
578+
579+
580+
function GatedGraphConv(dims::Int, num_layers::Int;
581+
aggr = +, init_weight = glorot_uniform)
582+
gru = GRUCell(dims => dims)
583+
return GatedGraphConv(gru, init_weight, dims, num_layers, aggr)
584+
end
585+
586+
LuxCore.outputsize(l::GatedGraphConv) = (l.dims,)
587+
588+
function LuxCore.initialparameters(rng::AbstractRNG, l::GatedGraphConv)
589+
gru = LuxCore.initialparameters(rng, l.gru)
590+
weight = l.init_weight(rng, l.dims, l.dims, l.num_layers)
591+
return (; gru, weight)
592+
end
593+
594+
LuxCore.parameterlength(l::GatedGraphConv) = parameterlength(l.gru) + l.dims^2*l.num_layers
595+
596+
597+
function (l::GatedGraphConv)(g, x, ps, st)
598+
gru = StatefulLuxLayer{true}(l.gru, ps.gru, _getstate(st, :gru))
599+
fgru = (h, x) -> gru((x, (h,))) # make the forward compatible with Flux.GRUCell style
600+
m = (; gru=fgru, ps.weight, l.num_layers, l.aggr, l.dims)
601+
return GNNlib.gated_graph_conv(m, g, x), st
602+
end
603+
604+
function Base.show(io::IO, l::GatedGraphConv)
605+
print(io, "GatedGraphConv($(l.dims), $(l.num_layers)")
606+
print(io, ", aggr=", l.aggr)
607+
print(io, ")")
608+
end
609+
610+
@concrete struct GINConv <: GNNContainerLayer{(:nn,)}
611+
nn <: AbstractExplicitLayer
612+
ϵ <: Real
613+
aggr
614+
end
615+
616+
GINConv(nn, ϵ; aggr = +) = GINConv(nn, ϵ, aggr)
617+
618+
function (l::GINConv)(g, x, ps, st)
619+
nn = StatefulLuxLayer{true}(l.nn, ps, st)
620+
m = (; nn, l.ϵ, l.aggr)
621+
y = GNNlib.gin_conv(m, g, x)
622+
stnew = _getstate(nn)
623+
return y, stnew
624+
end
625+
626+
function Base.show(io::IO, l::GINConv)
627+
print(io, "GINConv($(l.nn)")
628+
print(io, ", $(l.ϵ)")
629+
print(io, ")")
630+
end

GNNLux/test/layers/conv_tests.jl

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,5 +82,15 @@
8282
l = SGConv(in_dims => out_dims, 2)
8383
test_lux_layer(rng, l, g, x, outputsize=(out_dims,))
8484
end
85-
end
8685

86+
@testset "GatedGraphConv" begin
87+
l = GatedGraphConv(in_dims, 3)
88+
test_lux_layer(rng, l, g, x, outputsize=(in_dims,))
89+
end
90+
91+
@testset "GINConv" begin
92+
nn = Chain(Dense(in_dims => out_dims, relu), Dense(out_dims => out_dims))
93+
l = GINConv(nn, 0.5)
94+
test_lux_layer(rng, l, g, x, sizey=(out_dims,g.num_nodes), container=true)
95+
end
96+
end

GNNLux/test/shared_testsetup.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ function test_lux_layer(rng::AbstractRNG, l, g::GNNGraph, x;
2828
@test LuxCore.statelength(l) == LuxCore.statelength(st)
2929

3030
y, st′ = l(g, x, ps, st)
31+
@test eltype(y) == eltype(x)
3132
if outputsize !== nothing
3233
@test LuxCore.outputsize(l) == outputsize
3334
end

GNNlib/src/layers/conv.jl

Lines changed: 17 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ function gcn_conv(l, g::AbstractGNNGraph, x, edge_weight::EW, norm_fn::F, conv_w
2828
if edge_weight !== nothing
2929
# Pad weights with ones
3030
# TODO for ADJMAT_T the new edges are not generally at the end
31-
edge_weight = [edge_weight; fill!(similar(edge_weight, g.num_nodes), 1)]
31+
edge_weight = [edge_weight; ones_like(edge_weight, g.num_nodes)]
3232
@assert length(edge_weight) == g.num_edges
3333
end
3434
end
@@ -215,23 +215,22 @@ end
215215

216216
####################### GatedGraphConv ######################################
217217

218-
# TODO PIRACY! remove after https://github.com/JuliaDiff/ChainRules.jl/pull/521
219-
@non_differentiable fill!(x...)
220-
221-
function gated_graph_conv(l, g::GNNGraph, H::AbstractMatrix{S}) where {S <: Real}
222-
check_num_nodes(g, H)
223-
m, n = size(H)
224-
@assert (m<=l.out_ch) "number of input features must less or equals to output features."
225-
if m < l.out_ch
226-
Hpad = similar(H, S, l.out_ch - m, n)
227-
H = vcat(H, fill!(Hpad, 0))
218+
function gated_graph_conv(l, g::GNNGraph, x::AbstractMatrix)
219+
check_num_nodes(g, x)
220+
m, n = size(x)
221+
@assert m <= l.dims "number of input features must be less or equal to output features."
222+
if m < l.dims
223+
xpad = zeros_like(x, (l.dims - m, n))
224+
x = vcat(x, xpad)
228225
end
226+
h = x
229227
for i in 1:(l.num_layers)
230-
M = view(l.weight, :, :, i) * H
231-
M = propagate(copy_xj, g, l.aggr; xj = M)
232-
H, _ = l.gru(H, M)
228+
m = view(l.weight, :, :, i) * h
229+
m = propagate(copy_xj, g, l.aggr; xj = m)
230+
# in gru forward, hidden state is first argument, input is second
231+
h, _ = l.gru(h, m)
233232
end
234-
return H
233+
return h
235234
end
236235

237236
####################### EdgeConv ######################################
@@ -419,7 +418,7 @@ function sgc_conv(l, g::GNNGraph, x::AbstractMatrix{T},
419418
if l.add_self_loops
420419
g = add_self_loops(g)
421420
if edge_weight !== nothing
422-
edge_weight = [edge_weight; fill!(similar(edge_weight, g.num_nodes), 1)]
421+
edge_weight = [edge_weight; ones_like(edge_weight, g.num_nodes)]
423422
@assert length(edge_weight) == g.num_edges
424423
end
425424
end
@@ -512,7 +511,7 @@ function sg_conv(l, g::GNNGraph, x::AbstractMatrix{T},
512511
if l.add_self_loops
513512
g = add_self_loops(g)
514513
if edge_weight !== nothing
515-
edge_weight = [edge_weight; fill!(similar(edge_weight, g.num_nodes), 1)]
514+
edge_weight = [edge_weight; ones_like(edge_weight, g.num_nodes)]
516515
@assert length(edge_weight) == g.num_edges
517516
end
518517
end
@@ -644,7 +643,7 @@ function tag_conv(l, g::GNNGraph, x::AbstractMatrix{T},
644643
if l.add_self_loops
645644
g = add_self_loops(g)
646645
if edge_weight !== nothing
647-
edge_weight = [edge_weight; fill!(similar(edge_weight, g.num_nodes), 1)]
646+
edge_weight = [edge_weight; ones_like(edge_weight, g.num_nodes)]
648647
@assert length(edge_weight) == g.num_edges
649648
end
650649
end

src/layers/conv.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -486,7 +486,7 @@ where ``\mathbf{h}^{(l)}_i`` denotes the ``l``-th hidden variables passing throu
486486
# Arguments
487487
488488
- `out`: The dimension of output features.
489-
- `num_layers`: The number of gated recurrent unit.
489+
- `num_layers`: The number of recursion steps.
490490
- `aggr`: Aggregation operator for the incoming messages (e.g. `+`, `*`, `max`, `min`, and `mean`).
491491
- `init`: Weight initialization function.
492492
@@ -510,25 +510,25 @@ y = l(g, x)
510510
struct GatedGraphConv{W <: AbstractArray{<:Number, 3}, R, A} <: GNNLayer
511511
weight::W
512512
gru::R
513-
out_ch::Int
513+
dims::Int
514514
num_layers::Int
515515
aggr::A
516516
end
517517

518518
@functor GatedGraphConv
519519

520-
function GatedGraphConv(out_ch::Int, num_layers::Int;
520+
function GatedGraphConv(dims::Int, num_layers::Int;
521521
aggr = +, init = glorot_uniform)
522-
w = init(out_ch, out_ch, num_layers)
523-
gru = GRUCell(out_ch, out_ch)
524-
GatedGraphConv(w, gru, out_ch, num_layers, aggr)
522+
w = init(dims, dims, num_layers)
523+
gru = GRUCell(dims => dims)
524+
GatedGraphConv(w, gru, dims, num_layers, aggr)
525525
end
526526

527527

528528
(l::GatedGraphConv)(g, H) = GNNlib.gated_graph_conv(l, g, H)
529529

530530
function Base.show(io::IO, l::GatedGraphConv)
531-
print(io, "GatedGraphConv(($(l.out_ch) => $(l.out_ch))^$(l.num_layers)")
531+
print(io, "GatedGraphConv($(l.dims), $(l.num_layers)")
532532
print(io, ", aggr=", l.aggr)
533533
print(io, ")")
534534
end
@@ -1201,7 +1201,7 @@ function SGConv(ch::Pair{Int, Int}, k = 1;
12011201
in, out = ch
12021202
W = init(out, in)
12031203
b = bias ? Flux.create_bias(W, true, out) : false
1204-
SGConv(W, b, k, add_self_loops, use_edge_weight)
1204+
return SGConv(W, b, k, add_self_loops, use_edge_weight)
12051205
end
12061206

12071207
(l::SGConv)(g, x, edge_weight = nothing) = GNNlib.sg_conv(l, g, x, edge_weight)

0 commit comments

Comments
 (0)