Skip to content

Commit f2add0e

Browse files
create GNNLux.jl
1 parent 6ed3794 commit f2add0e

File tree

8 files changed

+67
-29
lines changed

8 files changed

+67
-29
lines changed

GNNLux/Project.toml

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,24 @@ LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
1212
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
1313
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1414
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
15-
16-
[extensions]
15+
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1716

1817
[compat]
18+
ConcreteStructs = "0.2.3"
19+
Lux = "0.5.61"
20+
LuxCore = "0.1.20"
21+
NNlib = "0.9.21"
22+
Reexport = "1.2"
1923
julia = "1.10"
2024

2125
[extras]
26+
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
27+
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
28+
LuxTestUtils = "ac9de150-d08f-4546-94fb-7472b5760531"
29+
ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823"
30+
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
2231
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
32+
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
2333

2434
[targets]
25-
test = ["Test"]
35+
test = ["Test", "ComponentArrays", "Functors", "LuxTestUtils", "ReTestItems", "StableRNGs", "Zygote"]

GNNLux/src/GNNLux.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ using LuxCore: LuxCore, AbstractExplicitLayer
55
using Lux: glorot_uniform, zeros32
66
using Reexport: @reexport
77
using Random: AbstractRNG
8-
8+
using GNNlib: GNNlib
99
@reexport using GNNGraphs
1010

1111
include("layers/conv.jl")

GNNLux/src/layers/conv.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,8 +79,8 @@ function LuxCore.parameterlength(l::GraphConv)
7979
end
8080
end
8181

82-
statelength(d::GraphConv) = 0
83-
outputsize(d::GraphConv) = (d.out_dims,)
82+
LuxCore.statelength(d::GraphConv) = 0
83+
LuxCore.outputsize(d::GraphConv) = (d.out_dims,)
8484

8585
function Base.show(io::IO, l::GraphConv)
8686
print(io, "GraphConv(", l.in_dims, " => ", l.out_dims)

GNNLux/test/layers/conv.jl

Lines changed: 0 additions & 7 deletions
This file was deleted.

GNNLux/test/layers/conv_tests.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
@testitem "layers/conv" setup=[SharedTestSetup] begin
2+
rng = StableRNG(1234)
3+
g = rand_graph(10, 30, seed=1234)
4+
x = randn(rng, Float32, 3, 10)
5+
6+
@testset "GraphConv" begin
7+
l = GraphConv(3 => 5, relu)
8+
ps = Lux.initialparameters(rng, l)
9+
st = Lux.initialstates(rng, l)
10+
@test Lux.parameterlength(l) == Lux.parameterlength(ps)
11+
@test Lux.statelength(l) == Lux.statelength(st)
12+
13+
y, _ = l(g, x, ps, st)
14+
@test Lux.outputsize(l) == (5,)
15+
@test size(y) == (5, 10)
16+
loss = (x, ps) -> sum(first(l(g, x, ps, st)))
17+
@eval @test_gradients $loss $x $ps atol=1.0f-3 rtol=1.0f-3
18+
end
19+
end

GNNLux/test/runtests.jl

Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,18 +3,8 @@ using Lux
33
using GNNLux
44
using Random, Statistics
55

6+
using ReTestItems
7+
# using Pkg, Preferences, Test
8+
# using InteractiveUtils, Hwloc
69

7-
tests = [
8-
# "utils",
9-
# "msgpass",
10-
# "layers/basic",
11-
"layers/conv",
12-
# "layers/heteroconv",
13-
# "layers/temporalconv",
14-
# "layers/pool",
15-
# "examples/node_classification_cora",
16-
]
17-
18-
@testset "$t" for t in tests
19-
include("$t.jl")
20-
end
10+
runtests(GNNLux)

GNNLux/test/shared_testsetup.jl

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
@testsetup module SharedTestSetup
2+
3+
import Reexport: @reexport
4+
5+
@reexport using Lux, Functors
6+
@reexport using ComponentArrays, LuxCore, LuxTestUtils, Random, StableRNGs, Test,
7+
Zygote, Statistics
8+
@reexport using LuxTestUtils: @jet, @test_gradients, check_approx
9+
10+
# Some Helper Functions
11+
function get_default_rng(mode::String)
12+
dev = mode == "cpu" ? LuxCPUDevice() :
13+
mode == "cuda" ? LuxCUDADevice() : mode == "amdgpu" ? LuxAMDGPUDevice() : nothing
14+
rng = default_device_rng(dev)
15+
return rng isa TaskLocalRNG ? copy(rng) : deepcopy(rng)
16+
end
17+
18+
export get_default_rng
19+
20+
# export BACKEND_GROUP, MODES, cpu_testing, cuda_testing, amdgpu_testing, get_default_rng,
21+
# StableRNG, maybe_rewrite_to_crosscor
22+
23+
end

GNNlib/src/layers/conv.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,8 +94,11 @@ function graph_conv(l, g::AbstractGNNGraph, x, ps)
9494
check_num_nodes(g, x)
9595
xj, xi = expand_srcdst(g, x)
9696
m = propagate(copy_xj, g, l.aggr, xj = xj)
97-
x = l.σ.(ps.weight1 * xi .+ ps.weight2 * m .+ ps.bias)
98-
return x
97+
x = ps.weight1 * xi .+ ps.weight2 * m
98+
if l.use_bias
99+
x = x .+ ps.bias
100+
end
101+
return l.σ.(x)
99102
end
100103

101104
function gat_conv(l, g::AbstractGNNGraph, x, e::Union{Nothing, AbstractMatrix} = nothing)

0 commit comments

Comments (0)