Skip to content

Commit 6ed3794

Browse files
create GNNLux
1 parent cafc1bc commit 6ed3794

File tree

9 files changed

+233
-2
lines changed

9 files changed

+233
-2
lines changed

.github/workflows/test_GNNLux.yml

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
name: GNNLux
2+
on:
3+
pull_request:
4+
branches:
5+
- master
6+
push:
7+
branches:
8+
- master
9+
jobs:
10+
test:
11+
name: Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }}
12+
runs-on: ${{ matrix.os }}
13+
strategy:
14+
fail-fast: false
15+
matrix:
16+
version:
17+
- '1.10' # Replace this with the minimum Julia version that your package supports.
18+
# - '1' # '1' will automatically expand to the latest stable 1.x release of Julia.
19+
# - 'pre'
20+
os:
21+
- ubuntu-latest
22+
arch:
23+
- x64
24+
25+
steps:
26+
- uses: actions/checkout@v4
27+
- uses: julia-actions/setup-julia@v2
28+
with:
29+
version: ${{ matrix.version }}
30+
arch: ${{ matrix.arch }}
31+
- uses: julia-actions/cache@v2
32+
- uses: julia-actions/julia-buildpkg@v1
33+
- name: Install Julia dependencies and run tests
34+
shell: julia --project=monorepo {0}
35+
run: |
36+
using Pkg
37+
# dev mono repo versions
38+
pkg"registry up"
39+
Pkg.update()
40+
pkg"dev ./GNNGraphs /GNNlib ./GNNLux"
41+
Pkg.test("GNNLux"; coverage=true)
42+
- uses: julia-actions/julia-processcoverage@v1
43+
with:
44+
# directories: ./GNNLux/src, ./GNNLux/ext
45+
directories: ./GNNLux/src
46+
- uses: codecov/codecov-action@v4
47+
with:
48+
files: lcov.info

GNNLux/LICENSE

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
MIT License
2+
3+
Copyright (c) 2024 Carlo Lucibello <[email protected]> and contributors
4+
5+
Permission is hereby granted, free of charge, to any person obtaining a copy
6+
of this software and associated documentation files (the "Software"), to deal
7+
in the Software without restriction, including without limitation the rights
8+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9+
copies of the Software, and to permit persons to whom the Software is
10+
furnished to do so, subject to the following conditions:
11+
12+
The above copyright notice and this permission notice shall be included in all
13+
copies or substantial portions of the Software.
14+
15+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21+
SOFTWARE.

GNNLux/Project.toml

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
name = "GNNLux"
2+
uuid = "e8545f4d-a905-48ac-a8c4-ca114b98986d"
3+
authors = ["Carlo Lucibello and contributors"]
4+
version = "0.1.0"
5+
6+
[deps]
7+
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
8+
GNNGraphs = "aed8fd31-079b-4b5a-b342-a13352159b8c"
9+
GNNlib = "a6a84749-d869-43f8-aacc-be26a1996e48"
10+
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
11+
LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
12+
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
13+
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
14+
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
15+
16+
[extensions]
17+
18+
[compat]
19+
julia = "1.10"
20+
21+
[extras]
22+
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
23+
24+
[targets]
25+
test = ["Test"]

GNNLux/README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
# GNNLux.jl
2+

GNNLux/src/GNNLux.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
module GNNLux
2+
using ConcreteStructs: @concrete
3+
using NNlib: NNlib
4+
using LuxCore: LuxCore, AbstractExplicitLayer
5+
using Lux: glorot_uniform, zeros32
6+
using Reexport: @reexport
7+
using Random: AbstractRNG
8+
9+
@reexport using GNNGraphs
10+
11+
include("layers/conv.jl")
12+
export GraphConv
13+
14+
end #module
15+

GNNLux/src/layers/conv.jl

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
2+
@doc raw"""
3+
GraphConv(in => out, σ=identity; aggr=+, bias=true, init=glorot_uniform)
4+
5+
Graph convolution layer from Reference: [Weisfeiler and Leman Go Neural: Higher-order Graph Neural Networks](https://arxiv.org/abs/1810.02244).
6+
7+
Performs:
8+
```math
9+
\mathbf{x}_i' = W_1 \mathbf{x}_i + \square_{j \in \mathcal{N}(i)} W_2 \mathbf{x}_j
10+
```
11+
12+
where the aggregation type is selected by `aggr`.
13+
14+
# Arguments
15+
16+
- `in`: The dimension of input features.
17+
- `out`: The dimension of output features.
18+
- `σ`: Activation function.
19+
- `aggr`: Aggregation operator for the incoming messages (e.g. `+`, `*`, `max`, `min`, and `mean`).
20+
- `bias`: Add learnable bias.
21+
- `init`: Weights' initializer.
22+
23+
# Examples
24+
25+
```julia
26+
# create data
27+
s = [1,1,2,3]
28+
t = [2,3,1,1]
29+
in_channel = 3
30+
out_channel = 5
31+
g = GNNGraph(s, t)
32+
x = randn(Float32, 3, g.num_nodes)
33+
34+
# create layer
35+
l = GraphConv(in_channel => out_channel, relu, bias = false, aggr = mean)
36+
37+
# forward pass
38+
y = l(g, x)
39+
```
40+
"""
41+
@concrete struct GraphConv <: AbstractExplicitLayer
42+
in_dims::Int
43+
out_dims::Int
44+
use_bias::Bool
45+
init_weight::Function
46+
init_bias::Function
47+
σ
48+
aggr
49+
end
50+
51+
52+
function GraphConv(ch::Pair{Int, Int}, σ = identity;
53+
aggr = +,
54+
init_weight = glorot_uniform,
55+
init_bias = zeros32,
56+
use_bias::Bool = true,
57+
allow_fast_activation::Bool = true)
58+
in_dims, out_dims = ch
59+
σ = allow_fast_activation ? NNlib.fast_act(σ) : σ
60+
return GraphConv(in_dims, out_dims, use_bias, init_weight, init_bias, σ, aggr)
61+
end
62+
63+
function LuxCore.initialparameters(rng::AbstractRNG, l::GraphConv)
64+
weight1 = l.init_weight(rng, l.out_dims, l.in_dims)
65+
weight2 = l.init_weight(rng, l.out_dims, l.in_dims)
66+
if l.use_bias
67+
bias = l.init_bias(rng, l.out_dims)
68+
else
69+
bias = false
70+
end
71+
return (; weight1, weight2, bias)
72+
end
73+
74+
function LuxCore.parameterlength(l::GraphConv)
75+
if l.use_bias
76+
return 2 * l.in_dims * l.out_dims + l.out_dims
77+
else
78+
return 2 * l.in_dims * l.out_dims
79+
end
80+
end
81+
82+
statelength(d::GraphConv) = 0
83+
outputsize(d::GraphConv) = (d.out_dims,)
84+
85+
function Base.show(io::IO, l::GraphConv)
86+
print(io, "GraphConv(", l.in_dims, " => ", l.out_dims)
87+
(l.σ == identity) || print(io, ", ", l.σ)
88+
(l.aggr == +) || print(io, ", aggr=", l.aggr)
89+
l.use_bias || print(io, ", use_bias=false")
90+
print(io, ")")
91+
end
92+
93+
(l::GraphConv)(g::GNNGraph, x, ps, st) = GNNlib.graph_conv(l, g, x, ps), st

GNNLux/test/layers/conv.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
@testset "GraphConv" begin
2+
rng = MersenneTwister(1234)
3+
g = rand_graph(10, 20, ndata= rand(Float32, 3, 10))
4+
l = GraphConv(3 => 5, relu)
5+
ps = LuxCore.initialparameters(rng, l)
6+
st = LuxCore.initialstates(rng, l)
7+
end

GNNLux/test/runtests.jl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
using Test
2+
using Lux
3+
using GNNLux
4+
using Random, Statistics
5+
6+
7+
tests = [
8+
# "utils",
9+
# "msgpass",
10+
# "layers/basic",
11+
"layers/conv",
12+
# "layers/heteroconv",
13+
# "layers/temporalconv",
14+
# "layers/pool",
15+
# "examples/node_classification_cora",
16+
]
17+
18+
@testset "$t" for t in tests
19+
include("$t.jl")
20+
end

GNNlib/src/layers/conv.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,11 +90,11 @@ function cheb_conv(c, g::GNNGraph, X::AbstractMatrix{T}) where {T}
9090
return Y .+ c.bias
9191
end
9292

93-
function graph_conv(l, g::AbstractGNNGraph, x)
93+
function graph_conv(l, g::AbstractGNNGraph, x, ps)
9494
check_num_nodes(g, x)
9595
xj, xi = expand_srcdst(g, x)
9696
m = propagate(copy_xj, g, l.aggr, xj = xj)
97-
x = l.σ.(l.weight1 * xi .+ l.weight2 * m .+ l.bias)
97+
x = l.σ.(ps.weight1 * xi .+ ps.weight2 * m .+ ps.bias)
9898
return x
9999
end
100100

0 commit comments

Comments
 (0)