
Commit 025d783

more works
1 parent 68251a9 commit 025d783

7 files changed: 105 additions, 97 deletions

GNNLux/src/GNNLux.jl (6 additions, 6 deletions)

```diff
@@ -41,12 +41,12 @@ export AGNNConv,
        # TransformerConv
 
 include("layers/temporalconv.jl")
-export TGCN,
-       A3TGCN,
-       GConvGRU,
-       GConvLSTM,
-       DCGRU,
-       EvolveGCNO
+export GNNRecurrence,
+       GConvGRU, GConvGRUCell,
+       GConvLSTM, GConvLSTMCell,
+       DCGRU, DCGRUCell,
+       EvolveGCNO, EvolveGCNOCell,
+       TGCN, TGCNCell
 
 include("layers/pool.jl")
 export GlobalPool,
```
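The temporal layers are re-exported under a cell/recurrence scheme: each recurrent layer now has a matching `*Cell` variant, plus the `GNNRecurrence` wrapper (`A3TGCN` drops out of the list). A minimal sketch of how the cell-level API is driven, assuming the conventions visible elsewhere in this commit (the `DCGRUCell` constructor and call signatures come from the `temporalconv.jl` diff below; `rand_graph` is assumed to be reexported from GNNGraphs):

```julia
using GNNLux, Lux, Random

rng = Random.default_rng()
g = rand_graph(rng, 5, 10)          # 5 nodes, 10 edges
x = rand(Float32, 2, 5)             # 2 features per node

cell = DCGRUCell(2 => 3, 2)         # in => out features, k = 2 diffusion steps
ps, st = Lux.setup(rng, cell)

# The matrix method initializes the hidden state; the tuple method threads it.
(y, state), st = cell(g, x, ps, st)
(y, state), st = cell(g, (x, state), ps, st)
@assert size(y) == (3, 5)
```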

GNNLux/src/layers/basic.jl (21 additions, 0 deletions)

```diff
@@ -10,6 +10,8 @@ abstract type GNNLayer <: AbstractLuxLayer end
 
 abstract type GNNContainerLayer{T} <: AbstractLuxContainerLayer{T} end
 
+const AbstractGNNLayer = Union{GNNLayer, GNNContainerLayer}
+
 """
     GNNChain(layers...)
     GNNChain(name = layer, ...)
@@ -104,3 +106,22 @@ _applylayer(l, g::GNNGraph, x, ps, st) = l(x), (;)
 _applylayer(l::AbstractLuxLayer, g::GNNGraph, x, ps, st) = l(x, ps, st)
 _applylayer(l::GNNLayer, g::GNNGraph, x, ps, st) = l(g, x, ps, st)
 _applylayer(l::GNNContainerLayer, g::GNNGraph, x, ps, st) = l(g, x, ps, st)
+
+
+# Facilitate using GNNlib functions with Lux layers
+# by returning a StatefulLuxLayer when accessing properties
+function Base.getproperty(l::StatefulLuxLayer{ST,<:AbstractGNNLayer}, name::Symbol) where ST
+    hasfield(typeof(l), name) && return getfield(l, name)
+    f = getproperty(l.model, name)
+    if f isa AbstractLuxLayer
+        stf = getproperty(Lux.get_state(l), name)
+        psf = getproperty(l.ps, name)
+        if ST === Static.True
+            return StatefulLuxLayer{true}(f, psf, stf)
+        else
+            return StatefulLuxLayer{false}(f, psf, stf)
+        end
+    else
+        return f
+    end
+end
```
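This `getproperty` overload is what lets the Lux frontend reuse the shared forward functions in GNNlib, which are written against Flux-style layers that reach into sub-layers as plain fields (`l.dconv_u(g, x)`). Wrapping a Lux layer in a `StatefulLuxLayer` and forwarding property access makes each sub-layer come back as its own `StatefulLuxLayer`, already bound to the matching slices of `ps` and `st`. A hypothetical illustration (the `dconv_u` field belongs to `DCGRUCell`, shown in the diff below):

```julia
using GNNLux, Lux, Random

rng = Random.default_rng()
cell = DCGRUCell(2 => 3, 2)
ps, st = Lux.setup(rng, cell)

m = StatefulLuxLayer{true}(cell, ps, st)

# Without the overload this would throw, since StatefulLuxLayer has no
# field `dconv_u`; with it, we get a callable wrapper around the sub-layer
# carrying ps.dconv_u and st.dconv_u, exactly what GNNlib code expects.
sub = m.dconv_u
@assert sub isa StatefulLuxLayer
```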

GNNLux/src/layers/temporalconv.jl (1 addition, 17 deletions)

```diff
@@ -421,22 +421,6 @@ function DCGRUCell(ch::Pair{Int, Int}, k::Int; use_bias = true, init_weight = gl
     return DCGRUCell(in_dims, out_dims, k, dconv_u, dconv_r, dconv_c, init_state)
 end
 
-# function (l::DCGRUCell)(g, (x, h), ps, st)
-#     if h === nothing
-#         h = l.init_state(l.out_dims, g.num_nodes)
-#     end
-#     h̃ = vcat(x, h)
-#     z, st_dconv_u = l.dconv_u(g, h̃, ps.dconv_u, st.dconv_u)
-#     z = NNlib.sigmoid_fast.(z)
-#     r, st_dconv_r = l.dconv_r(g, h̃, ps.dconv_r, st.dconv_r)
-#     r = NNlib.sigmoid_fast.(r)
-#     ĥ = vcat(x, h .* r)
-#     c, st_dconv_c = l.dconv_c(g, ĥ, ps.dconv_c, st.dconv_c)
-#     c = NNlib.tanh_fast.(c)
-#     h = z.* h + (1 .- z).* c
-#     return (h, h), (dconv_u = st_dconv_u, dconv_r = st_dconv_r, dconv_c = st_dconv_c)
-# end
-
 
 function (l::DCGRUCell)(g, x::AbstractMatrix, ps, st)
     h = l.init_state(l.out_dims, g.num_nodes)
@@ -445,7 +429,7 @@ end
 
 function (l::DCGRUCell)(g, (x, (h,))::Tuple, ps, st)
     m = StatefulLuxLayer{true}(l, ps, st)
-    h, _ = dcgrucell_frwd(m, g, x, h)
+    h, _ = GNNlib.dcgrucell_frwd(m, g, x, h)
    return (h, (h,)), _getstate(m)
 end
```
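The hand-written forward (left above as a deleted comment block) goes away because the update rule now lives in `GNNlib.dcgrucell_frwd`, shared with the Flux frontend. For reference, the rule encoded in that deleted block is the standard DCGRU update, with each gate computed by its own diffusion convolution:

$$
\begin{aligned}
z  &= \sigma\big(\mathrm{DConv}_u([x;\, h])\big) \\
r  &= \sigma\big(\mathrm{DConv}_r([x;\, h])\big) \\
c  &= \tanh\big(\mathrm{DConv}_c([x;\, h \odot r])\big) \\
h' &= z \odot h + (1 - z) \odot c
\end{aligned}
$$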

GNNLux/test/Project.toml (1 addition, 0 deletions)

```diff
@@ -13,6 +13,7 @@ NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
 Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
+Revise = "295af30f-e4ad-537b-8983-00126c2a3abe"
 StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
 Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"
 Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
```

GNNLux/test/test_module.jl (1 addition, 73 deletions)

```diff
@@ -1,77 +1,5 @@
 @testmodule TestModuleLux begin
 
-# ... GPU-backend setup and the `test_lux_layer` helper (73 lines), moved
-# verbatim to the new file GNNLux/test/test_utils.jl shown in full below ...
+include("test_utils.jl")
 
 end
```

GNNLux/test/test_utils.jl (new file, 74 additions)

```diff
@@ -0,0 +1,74 @@
+using Pkg
+
+## Uncomment below to change the default test settings
+# ENV["GNN_TEST_CUDA"] = "true"
+# ENV["GNN_TEST_AMDGPU"] = "true"
+# ENV["GNN_TEST_Metal"] = "true"
+
+to_test(backend) = get(ENV, "GNN_TEST_$(backend)", "false") == "true"
+has_dependecies(pkgs) = all(pkg -> haskey(Pkg.project().dependencies, pkg), pkgs)
+deps_dict = Dict(:CUDA => ["CUDA", "cuDNN"], :AMDGPU => ["AMDGPU"], :Metal => ["Metal"])
+
+for (backend, deps) in deps_dict
+    if to_test(backend)
+        if !has_dependecies(deps)
+            Pkg.add(deps)
+        end
+        @eval using $backend
+        if backend == :CUDA
+            @eval using cuDNN
+        end
+        @eval $backend.allowscalar(false)
+    end
+end
+
+using Reexport: @reexport
+
+@reexport using Test
+@reexport using GNNLux
+@reexport using Lux
+@reexport using StableRNGs
+@reexport using Random, Statistics
+
+using LuxTestUtils: test_gradients, AutoReverseDiff, AutoTracker, AutoForwardDiff, AutoEnzyme
+
+export test_lux_layer
+
+function test_lux_layer(rng::AbstractRNG, l, g::GNNGraph, x;
+                        outputsize=nothing, sizey=nothing, container=false,
+                        atol=1.0f-2, rtol=1.0f-2, e=nothing)
+
+    if container
+        @test l isa GNNContainerLayer
+    else
+        @test l isa GNNLayer
+    end
+
+    ps = LuxCore.initialparameters(rng, l)
+    st = LuxCore.initialstates(rng, l)
+    @test LuxCore.parameterlength(l) == LuxCore.parameterlength(ps)
+    @test LuxCore.statelength(l) == LuxCore.statelength(st)
+
+    if e !== nothing
+        y, st′ = l(g, x, e, ps, st)
+    else
+        y, st′ = l(g, x, ps, st)
+    end
+    @test eltype(y) == eltype(x)
+    if outputsize !== nothing
+        @test LuxCore.outputsize(l) == outputsize
+    end
+    if sizey !== nothing
+        @test size(y) == sizey
+    elseif outputsize !== nothing
+        @test size(y) == (outputsize..., g.num_nodes)
+    end
+
+    if e !== nothing
+        loss = (x, ps) -> sum(first(l(g, x, e, ps, st)))
+    else
+        loss = (x, ps) -> sum(first(l(g, x, ps, st)))
+    end
+    test_gradients(loss, x, ps; atol, rtol, skip_backends=[AutoReverseDiff(), AutoTracker(), AutoForwardDiff(), AutoEnzyme()])
+end
+
```
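A hedged usage sketch of `test_lux_layer`, based only on the signature above (`GCNConv` and `rand_graph` are assumed to be exported by GNNLux, and `outputsize = (3,)` assumes the layer defines `LuxCore.outputsize`):

```julia
using GNNLux, StableRNGs

rng = StableRNG(17)
g = rand_graph(rng, 10, 40)        # random graph: 10 nodes, 40 edges
x = randn(rng, Float32, 2, 10)     # 2 input features per node

# Checks layer type, parameter/state lengths, output eltype and size,
# then runs test_gradients over the non-skipped AD backends.
test_lux_layer(rng, GCNConv(2 => 3), g, x; outputsize = (3,))
```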

GraphNeuralNetworks/src/layers/temporalconv.jl (1 addition, 1 deletion)

````diff
@@ -516,7 +516,7 @@ julia> size(y) # (d_out, num_nodes)
 (3, 5)
 ```
 """
-struct DCGRUCell
+struct DCGRUCell <: GNNLayer
     in::Int
     out::Int
     k::Int
````
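On the Flux side, `DCGRUCell` now participates in the `GNNLayer` hierarchy, so code dispatching on `GNNLayer` (e.g. `GNNChain` handling and printing) treats the cell like any other GNN layer. A minimal check, assuming the `DCGRUCell(in => out, k)` constructor implied by the docstring context above:

```julia
using GraphNeuralNetworks

cell = DCGRUCell(2 => 3, 2)
@assert cell isa GNNLayer   # previously a plain struct outside the hierarchy
```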
