
Commit 79515e9

create GNNLux.jl package (#460)
* create GNNLux
* create GNNLux.jl
* fix ci
1 parent cafc1bc commit 79515e9

10 files changed, +272 −3 lines changed

.github/workflows/test_GNNLux.yml

Lines changed: 48 additions & 0 deletions
@@ -0,0 +1,48 @@
name: GNNLux
on:
  pull_request:
    branches:
      - master
  push:
    branches:
      - master
jobs:
  test:
    name: Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }}
    runs-on: ${{ matrix.os }}
    strategy:
      fail-fast: false
      matrix:
        version:
          - '1.10' # Replace this with the minimum Julia version that your package supports.
          # - '1'  # '1' will automatically expand to the latest stable 1.x release of Julia.
          # - 'pre'
        os:
          - ubuntu-latest
        arch:
          - x64

    steps:
      - uses: actions/checkout@v4
      - uses: julia-actions/setup-julia@v2
        with:
          version: ${{ matrix.version }}
          arch: ${{ matrix.arch }}
      - uses: julia-actions/cache@v2
      - uses: julia-actions/julia-buildpkg@v1
      - name: Install Julia dependencies and run tests
        shell: julia --project=monorepo {0}
        run: |
          using Pkg
          # dev mono repo versions
          pkg"registry up"
          Pkg.update()
          pkg"dev ./GNNGraphs ./GNNlib ./GNNLux"
          Pkg.test("GNNLux"; coverage=true)
      - uses: julia-actions/julia-processcoverage@v1
        with:
          # directories: ./GNNLux/src, ./GNNLux/ext
          directories: ./GNNLux/src
      - uses: codecov/codecov-action@v4
        with:
          files: lcov.info
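
For local development, the `run:` script above can be reproduced outside CI. Below is a minimal sketch of a hypothetical local equivalent, assuming it is run from the repository root with the three monorepo packages checked out at the paths used by the workflow:

using Pkg

# throwaway environment, playing the role of the workflow's --project=monorepo
Pkg.activate(temp = true)
Pkg.develop([PackageSpec(path = "./GNNGraphs"),
             PackageSpec(path = "./GNNlib"),
             PackageSpec(path = "./GNNLux")])
Pkg.test("GNNLux"; coverage = true)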

GNNLux/LICENSE

Lines changed: 21 additions & 0 deletions
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2024 Carlo Lucibello <[email protected]> and contributors

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

GNNLux/Project.toml

Lines changed: 35 additions & 0 deletions
@@ -0,0 +1,35 @@
name = "GNNLux"
uuid = "e8545f4d-a905-48ac-a8c4-ca114b98986d"
authors = ["Carlo Lucibello and contributors"]
version = "0.1.0"

[deps]
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
GNNGraphs = "aed8fd31-079b-4b5a-b342-a13352159b8c"
GNNlib = "a6a84749-d869-43f8-aacc-be26a1996e48"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[compat]
ConcreteStructs = "0.2.3"
Lux = "0.5.61"
LuxCore = "0.1.20"
NNlib = "0.9.21"
Reexport = "1.2"
julia = "1.10"

[extras]
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
LuxTestUtils = "ac9de150-d08f-4546-94fb-7472b5760531"
ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Test", "ComponentArrays", "Functors", "LuxTestUtils", "ReTestItems", "StableRNGs", "Zygote"]

GNNLux/README.md

Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
# GNNLux.jl

GNNLux/src/GNNLux.jl

Lines changed: 15 additions & 0 deletions
@@ -0,0 +1,15 @@
module GNNLux
using ConcreteStructs: @concrete
using NNlib: NNlib
using LuxCore: LuxCore, AbstractExplicitLayer
using Lux: glorot_uniform, zeros32
using Reexport: @reexport
using Random: AbstractRNG
using GNNlib: GNNlib
@reexport using GNNGraphs

include("layers/conv.jl")
export GraphConv

end #module

GNNLux/src/layers/conv.jl

Lines changed: 93 additions & 0 deletions
@@ -0,0 +1,93 @@

@doc raw"""
    GraphConv(in => out, σ = identity; aggr = +, init_weight = glorot_uniform,
              init_bias = zeros32, use_bias = true)

Graph convolution layer from the paper [Weisfeiler and Leman Go Neural: Higher-order Graph Neural Networks](https://arxiv.org/abs/1810.02244).

Performs:
```math
\mathbf{x}_i' = W_1 \mathbf{x}_i + \square_{j \in \mathcal{N}(i)} W_2 \mathbf{x}_j
```

where the aggregation type is selected by `aggr`.

# Arguments

- `in`: The dimension of input features.
- `out`: The dimension of output features.
- `σ`: Activation function.
- `aggr`: Aggregation operator for the incoming messages (e.g. `+`, `*`, `max`, `min`, and `mean`).
- `init_weight`: Weights' initializer.
- `init_bias`: Bias initializer.
- `use_bias`: Add learnable bias.

# Examples

```julia
# create data
s = [1,1,2,3]
t = [2,3,1,1]
in_channel = 3
out_channel = 5
g = GNNGraph(s, t)
x = randn(Float32, 3, g.num_nodes)

# create layer
l = GraphConv(in_channel => out_channel, relu, use_bias = false, aggr = mean)

# setup layer (explicit parameters and state)
rng = Random.default_rng()
ps = LuxCore.initialparameters(rng, l)
st = LuxCore.initialstates(rng, l)

# forward pass
y, st = l(g, x, ps, st)
```
"""
@concrete struct GraphConv <: AbstractExplicitLayer
    in_dims::Int
    out_dims::Int
    use_bias::Bool
    init_weight::Function
    init_bias::Function
    σ
    aggr
end

function GraphConv(ch::Pair{Int, Int}, σ = identity;
                   aggr = +,
                   init_weight = glorot_uniform,
                   init_bias = zeros32,
                   use_bias::Bool = true,
                   allow_fast_activation::Bool = true)
    in_dims, out_dims = ch
    σ = allow_fast_activation ? NNlib.fast_act(σ) : σ
    return GraphConv(in_dims, out_dims, use_bias, init_weight, init_bias, σ, aggr)
end

function LuxCore.initialparameters(rng::AbstractRNG, l::GraphConv)
    weight1 = l.init_weight(rng, l.out_dims, l.in_dims)
    weight2 = l.init_weight(rng, l.out_dims, l.in_dims)
    if l.use_bias
        bias = l.init_bias(rng, l.out_dims)
    else
        bias = false
    end
    return (; weight1, weight2, bias)
end

function LuxCore.parameterlength(l::GraphConv)
    if l.use_bias
        return 2 * l.in_dims * l.out_dims + l.out_dims
    else
        return 2 * l.in_dims * l.out_dims
    end
end

LuxCore.statelength(d::GraphConv) = 0
LuxCore.outputsize(d::GraphConv) = (d.out_dims,)

function Base.show(io::IO, l::GraphConv)
    print(io, "GraphConv(", l.in_dims, " => ", l.out_dims)
    (l.σ == identity) || print(io, ", ", l.σ)
    (l.aggr == +) || print(io, ", aggr=", l.aggr)
    l.use_bias || print(io, ", use_bias=false")
    print(io, ")")
end

(l::GraphConv)(g::GNNGraph, x, ps, st) = GNNlib.graph_conv(l, g, x, ps), st
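
The layer struct itself carries only hyperparameters and initializers; the trainable arrays (`weight1`, `weight2`, `bias`) live in the parameter NamedTuple returned by `LuxCore.initialparameters`. Below is a minimal usage sketch of this explicit-parameter interface, including a Zygote gradient with respect to the parameters, mirroring the test file added below (the graph and feature sizes are illustrative, not part of the commit):

using GNNLux, Lux, Random, Zygote

rng = Random.default_rng()
g = rand_graph(10, 30)              # rand_graph comes from GNNGraphs, re-exported by GNNLux
x = randn(Float32, 3, 10)           # 3 input features for each of the 10 nodes

l = GraphConv(3 => 5, relu)
ps = Lux.initialparameters(rng, l)  # (; weight1, weight2, bias)
st = Lux.initialstates(rng, l)      # empty state for this layer

y, st = l(g, x, ps, st)             # y is a 5×10 feature matrix
@assert size(y) == (5, 10)

# gradients are taken with respect to the parameter NamedTuple, not the layer struct
grads = Zygote.gradient(p -> sum(first(l(g, x, p, st))), ps)[1]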

GNNLux/test/layers/conv_tests.jl

Lines changed: 19 additions & 0 deletions
@@ -0,0 +1,19 @@
@testitem "layers/conv" setup=[SharedTestSetup] begin
    rng = StableRNG(1234)
    g = rand_graph(10, 30, seed=1234)
    x = randn(rng, Float32, 3, 10)

    @testset "GraphConv" begin
        l = GraphConv(3 => 5, relu)
        ps = Lux.initialparameters(rng, l)
        st = Lux.initialstates(rng, l)
        @test Lux.parameterlength(l) == Lux.parameterlength(ps)
        @test Lux.statelength(l) == Lux.statelength(st)

        y, _ = l(g, x, ps, st)
        @test Lux.outputsize(l) == (5,)
        @test size(y) == (5, 10)
        loss = (x, ps) -> sum(first(l(g, x, ps, st)))
        @eval @test_gradients $loss $x $ps atol=1.0f-3 rtol=1.0f-3
    end
end

GNNLux/test/runtests.jl

Lines changed: 10 additions & 0 deletions
@@ -0,0 +1,10 @@
using Test
using Lux
using GNNLux
using Random, Statistics

using ReTestItems
# using Pkg, Preferences, Test
# using InteractiveUtils, Hwloc

runtests(GNNLux)

GNNLux/test/shared_testsetup.jl

Lines changed: 23 additions & 0 deletions
@@ -0,0 +1,23 @@
@testsetup module SharedTestSetup

import Reexport: @reexport

@reexport using Lux, Functors
@reexport using ComponentArrays, LuxCore, LuxTestUtils, Random, StableRNGs, Test,
                Zygote, Statistics
@reexport using LuxTestUtils: @jet, @test_gradients, check_approx

# Some Helper Functions
function get_default_rng(mode::String)
    dev = mode == "cpu" ? LuxCPUDevice() :
          mode == "cuda" ? LuxCUDADevice() : mode == "amdgpu" ? LuxAMDGPUDevice() : nothing
    rng = default_device_rng(dev)
    return rng isa TaskLocalRNG ? copy(rng) : deepcopy(rng)
end

export get_default_rng

# export BACKEND_GROUP, MODES, cpu_testing, cuda_testing, amdgpu_testing, get_default_rng,
#        StableRNG, maybe_rewrite_to_crosscor

end

GNNlib/src/layers/conv.jl

Lines changed: 6 additions & 3 deletions
@@ -90,12 +90,15 @@ function cheb_conv(c, g::GNNGraph, X::AbstractMatrix{T}) where {T}
     return Y .+ c.bias
 end

-function graph_conv(l, g::AbstractGNNGraph, x)
+function graph_conv(l, g::AbstractGNNGraph, x, ps)
     check_num_nodes(g, x)
     xj, xi = expand_srcdst(g, x)
     m = propagate(copy_xj, g, l.aggr, xj = xj)
-    x = l.σ.(l.weight1 * xi .+ l.weight2 * m .+ l.bias)
-    return x
+    x = ps.weight1 * xi .+ ps.weight2 * m
+    if l.use_bias
+        x = x .+ ps.bias
+    end
+    return l.σ.(x)
 end

 function gat_conv(l, g::AbstractGNNGraph, x, e::Union{Nothing, AbstractMatrix} = nothing)
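
With this change the functional kernel no longer reads weights off the layer object: `l` only needs to provide `σ`, `aggr`, and `use_bias`, while the trainable arrays arrive through `ps`, so the kernel can be driven by any front end that supplies such a parameter NamedTuple. A hypothetical illustration of that contract follows; the stand-in struct and array shapes are for illustration only and are not part of the commit:

using GNNGraphs, GNNlib, NNlib

# minimal stand-in carrying just the fields graph_conv reads from `l`
struct DummyGraphConv
    σ
    aggr
    use_bias::Bool
end

g = rand_graph(10, 30)
x = randn(Float32, 3, 10)
ps = (weight1 = randn(Float32, 5, 3),
      weight2 = randn(Float32, 5, 3),
      bias = zeros(Float32, 5))

l = DummyGraphConv(NNlib.relu, +, true)
y = GNNlib.graph_conv(l, g, x, ps)   # 5×10 output, the same call the Lux wrapper above delegates to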
