
Commit 80c672a

[GNNLux] more layers (#463)
* move to MLDataDevices
* cg_conv
* edgeconv working
* cleanup
1 parent 3b42087 commit 80c672a

File tree: 12 files changed (+221 −81 lines)

GNNGraphs/Project.toml

Lines changed: 2 additions & 2 deletions
```diff
@@ -10,7 +10,7 @@ Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
 Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
 KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
-LuxDeviceUtils = "34f89e08-e1d5-43b4-8944-0b49ac560553"
+MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
 MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
 NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
 NearestNeighbors = "b8a86587-4115-5ab1-83bc-aa920d37bbce"
@@ -35,7 +35,7 @@ Functors = "0.4.1"
 Graphs = "1.4"
 KrylovKit = "0.8"
 LinearAlgebra = "1"
-LuxDeviceUtils = "0.1.24"
+MLDataDevices = "1.0"
 MLDatasets = "0.7"
 MLUtils = "0.4"
 NNlib = "0.9"
```

GNNGraphs/src/GNNGraphs.jl

Lines changed: 1 addition & 1 deletion
```diff
@@ -14,7 +14,7 @@ using LinearAlgebra, Random, Statistics
 import MLUtils
 using MLUtils: getobs, numobs, ones_like, zeros_like, chunk, batch, rand_like
 import Functors
-using LuxDeviceUtils: get_device, cpu_device, LuxCPUDevice
+using MLDataDevices: get_device, cpu_device, CPUDevice

 include("chainrules.jl") # hacks for differentiability
```
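This diff and the test diffs below are a mechanical migration: LuxDeviceUtils is replaced by MLDataDevices, and `LuxCPUDevice`/`LuxCUDADevice` become `CPUDevice`/`CUDADevice`. A minimal sketch of the renamed API as GNNGraphs uses it — the toy graph is illustrative, and `gpu_device()` falls back to the CPU device when no GPU is functional:

```julia
using GNNGraphs
using MLDataDevices: cpu_device, gpu_device, get_device, CPUDevice

g = rand_graph(10, 20; ndata = rand(Float32, 5, 10))  # toy graph with node features

dev = gpu_device()             # CUDADevice() if CUDA is functional, else CPUDevice()
g_dev = g |> dev               # move the graph's arrays to the device, as in the tests
g_cpu = g_dev |> cpu_device()  # and back
@assert get_device(g_cpu.ndata.x) isa CPUDevice
```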

GNNGraphs/test/gnngraph.jl

Lines changed: 2 additions & 2 deletions
```diff
@@ -57,7 +57,7 @@ end
 # core functionality
 g = GNNGraph(s, t; graph_type = GRAPH_T)
 if TEST_GPU
-    dev = LuxCUDADevice() #TODO replace with gpu_device()
+    dev = CUDADevice()
     g_gpu = g |> dev
 end
@@ -141,7 +141,7 @@ end
 # core functionality
 g = GNNGraph(s, t; graph_type = GRAPH_T)
 if TEST_GPU
-    dev = LuxCUDADevice() #TODO replace with `gpu_device()`
+    dev = CUDADevice() #TODO replace with `gpu_device()`
     g_gpu = g |> dev
 end
```

GNNGraphs/test/query.jl

Lines changed: 2 additions & 2 deletions
```diff
@@ -59,7 +59,7 @@ end
 @test eltype(degree(g, Float32)) == Float32

 if TEST_GPU
-    dev = LuxCUDADevice() #TODO replace with `gpu_device()`
+    dev = CUDADevice() #TODO replace with `gpu_device()`
     g_gpu = g |> dev
     d = degree(g)
     d_gpu = degree(g_gpu)
@@ -87,7 +87,7 @@ end
 @test degree(g, edge_weight = 2 * eweight) ≈ [4.4, 2.4, 2.0, 0.0] broken = (GRAPH_T != :coo)

 if TEST_GPU
-    dev = LuxCUDADevice() #TODO replace with `gpu_device()`
+    dev = CUDADevice() #TODO replace with `gpu_device()`
     g_gpu = g |> dev
     d = degree(g)
     d_gpu = degree(g_gpu)
```

GNNGraphs/test/runtests.jl

Lines changed: 2 additions & 2 deletions
```diff
@@ -13,8 +13,8 @@ using Test
 using MLDatasets
 using InlineStrings # not used but with the import we test #98 and #104
 using SimpleWeightedGraphs
-using LuxDeviceUtils: gpu_device, cpu_device, get_device
-using LuxDeviceUtils: LuxCUDADevice # remove after https://github.com/LuxDL/LuxDeviceUtils.jl/pull/58
+using MLDataDevices: gpu_device, cpu_device, get_device
+using MLDataDevices: CUDADevice

 CUDA.allowscalar(false)
```
GNNGraphs/test/temporalsnapshotsgnngraph.jl

Lines changed: 1 addition & 1 deletion
```diff
@@ -107,7 +107,7 @@ if TEST_GPU
 snapshots = [rand_graph(10, 20; ndata = rand(5,10)) for i in 1:5]
 tsg = TemporalSnapshotsGNNGraph(snapshots)
 tsg.tgdata.x = rand(5)
-dev = LuxCUDADevice() #TODO replace with `gpu_device()`
+dev = CUDADevice() #TODO replace with `gpu_device()`
 tsg = tsg |> dev
 @test tsg.snapshots[1].ndata.x isa CuArray
 @test tsg.snapshots[end].ndata.x isa CuArray
```

GNNLux/Project.toml

Lines changed: 2 additions & 1 deletion
```diff
@@ -26,10 +26,11 @@ julia = "1.10"
 ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
 Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
 LuxTestUtils = "ac9de150-d08f-4546-94fb-7472b5760531"
+MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
 ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823"
 StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

 [targets]
-test = ["Test", "ComponentArrays", "Functors", "LuxTestUtils", "ReTestItems", "StableRNGs", "Zygote"]
+test = ["Test", "MLDataDevices", "ComponentArrays", "Functors", "LuxTestUtils", "ReTestItems", "StableRNGs", "Zygote"]
```

GNNLux/src/GNNLux.jl

Lines changed: 21 additions & 3 deletions
```diff
@@ -1,8 +1,8 @@
 module GNNLux
 using ConcreteStructs: @concrete
-using NNlib: NNlib
+using NNlib: NNlib, sigmoid, relu
 using LuxCore: LuxCore, AbstractExplicitLayer, AbstractExplicitContainerLayer
-using Lux: Lux, glorot_uniform, zeros32
+using Lux: Lux, Dense, glorot_uniform, zeros32, StatefulLuxLayer
 using Reexport: @reexport
 using Random: AbstractRNG
 using GNNlib: GNNlib
@@ -14,9 +14,27 @@ export GNNLayer,
        GNNChain

 include("layers/conv.jl")
-export GCNConv,
+export AGNNConv,
+       CGConv,
        ChebConv,
+       EdgeConv,
+       # EGNNConv,
+       # DConv,
+       # GATConv,
+       # GATv2Conv,
+       # GatedGraphConv,
+       GCNConv,
+       # GINConv,
+       # GMMConv,
        GraphConv
+       # MEGNetConv,
+       # NNConv,
+       # ResGatedGraphConv,
+       # SAGEConv,
+       # SGConv,
+       # TAGConv,
+       # TransformerConv
+

 end #module
```

GNNLux/src/layers/conv.jl

Lines changed: 121 additions & 35 deletions
```diff
@@ -1,22 +1,7 @@
-# Missing Layers
-
-# | Layer |Sparse Ops|Edge Weight|Edge Features| Heterograph | TemporalSnapshotsGNNGraphs |
-# | :-------- | :---: |:---: |:---: | :---: | :---: |
-# | [`AGNNConv`](@ref) | | | ✓ | | |
-# | [`CGConv`](@ref) | | | ✓ | ✓ | ✓ |
-# | [`EGNNConv`](@ref) | | | ✓ | | |
-# | [`EdgeConv`](@ref) | | | | ✓ | |
-# | [`GATConv`](@ref) | | | ✓ | ✓ | ✓ |
-# | [`GATv2Conv`](@ref) | | | ✓ | ✓ | ✓ |
-# | [`GatedGraphConv`](@ref) | ✓ | | | | ✓ |
-# | [`GINConv`](@ref) | ✓ | | | ✓ | ✓ |
-# | [`GMMConv`](@ref) | | | ✓ | | |
-# | [`MEGNetConv`](@ref) | | | ✓ | | |
-# | [`NNConv`](@ref) | | | ✓ | | |
-# | [`ResGatedGraphConv`](@ref) | | | | ✓ | ✓ |
-# | [`SAGEConv`](@ref) | ✓ | | | ✓ | ✓ |
-# | [`SGConv`](@ref) | ✓ | | | | ✓ |
-# | [`TransformerConv`](@ref) | | | ✓ | | |
+_getbias(ps) = hasproperty(ps, :bias) ? getproperty(ps, :bias) : false
+_getstate(st, name) = hasproperty(st, name) ? getproperty(st, name) : NamedTuple()
+_getstate(s::StatefulLuxLayer{true}) = s.st
+_getstate(s::StatefulLuxLayer{false}) = s.st_any


 @concrete struct GCNConv <: GNNLayer
@@ -65,13 +50,18 @@ function Base.show(io::IO, l::GCNConv)
     print(io, ")")
 end

-# TODO norm_fn should be keyword argument only
 (l::GCNConv)(g, x, ps, st; conv_weight=nothing, edge_weight=nothing, norm_fn= d -> 1 ./ sqrt.(d)) =
-    l(g, x, edge_weight, norm_fn, ps, st; conv_weight)
-(l::GCNConv)(g, x, edge_weight, ps, st; conv_weight=nothing, norm_fn = d -> 1 ./ sqrt.(d)) =
-    l(g, x, edge_weight, norm_fn, ps, st; conv_weight)
-(l::GCNConv)(g, x, edge_weight, norm_fn, ps, st; conv_weight=nothing) =
-    GNNlib.gcn_conv(l, g, x, edge_weight, norm_fn, conv_weight, ps), st
+    l(g, x, edge_weight, ps, st; conv_weight, norm_fn)
+
+function (l::GCNConv)(g, x, edge_weight, ps, st;
+                      norm_fn = d -> 1 ./ sqrt.(d),
+                      conv_weight = nothing)
+    m = (; ps.weight, bias = _getbias(ps),
+           l.add_self_loops, l.use_edge_weight, l.σ)
+    y = GNNlib.gcn_conv(m, g, x, edge_weight, norm_fn, conv_weight)
+    return y, st
+end

 @concrete struct ChebConv <: GNNLayer
     in_dims::Int
@@ -80,17 +70,14 @@ end
     k::Int
     init_weight
     init_bias
-    σ
 end

-function ChebConv(ch::Pair{Int, Int}, k::Int, σ = identity;
+function ChebConv(ch::Pair{Int, Int}, k::Int;
                   init_weight = glorot_uniform,
                   init_bias = zeros32,
-                  use_bias::Bool = true,
-                  allow_fast_activation::Bool = true)
+                  use_bias::Bool = true)
     in_dims, out_dims = ch
-    σ = allow_fast_activation ? NNlib.fast_act(σ) : σ
-    return ChebConv(in_dims, out_dims, use_bias, k, init_weight, init_bias, σ)
+    return ChebConv(in_dims, out_dims, use_bias, k, init_weight, init_bias)
 end

 function LuxCore.initialparameters(rng::AbstractRNG, l::ChebConv)
@@ -109,13 +96,17 @@ LuxCore.statelength(d::ChebConv) = 0
 LuxCore.outputsize(d::ChebConv) = (d.out_dims,)

 function Base.show(io::IO, l::ChebConv)
-    print(io, "ChebConv(", l.in_dims, " => ", l.out_dims, ", K=", l.K)
-    l.σ == identity || print(io, ", ", l.σ)
+    print(io, "ChebConv(", l.in_dims, " => ", l.out_dims, ", k=", l.k)
     l.use_bias || print(io, ", use_bias=false")
     print(io, ")")
 end

-(l::ChebConv)(g, x, ps, st) = GNNlib.cheb_conv(l, g, x, ps), st
+function (l::ChebConv)(g, x, ps, st)
+    m = (; ps.weight, bias = _getbias(ps), l.k)
+    y = GNNlib.cheb_conv(m, g, x)
+    return y, st
+end

 @concrete struct GraphConv <: GNNLayer
     in_dims::Int
@@ -168,4 +159,99 @@ function Base.show(io::IO, l::GraphConv)
     print(io, ")")
 end

-(l::GraphConv)(g, x, ps, st) = GNNlib.graph_conv(l, g, x, ps), st
+function (l::GraphConv)(g, x, ps, st)
+    m = (; ps.weight1, ps.weight2, bias = _getbias(ps),
+           l.σ, l.aggr)
+    return GNNlib.graph_conv(m, g, x), st
+end
+
+
+@concrete struct AGNNConv <: GNNLayer
+    init_beta <: AbstractVector
+    add_self_loops::Bool
+    trainable::Bool
+end
+
+function AGNNConv(; init_beta = 1.0f0, add_self_loops = true, trainable = true)
+    return AGNNConv([init_beta], add_self_loops, trainable)
+end
+
+function LuxCore.initialparameters(rng::AbstractRNG, l::AGNNConv)
+    if l.trainable
+        return (; β = l.init_beta)
+    else
+        return (;)
+    end
+end
+
+LuxCore.parameterlength(l::AGNNConv) = l.trainable ? 1 : 0
+LuxCore.statelength(d::AGNNConv) = 0
+
+function Base.show(io::IO, l::AGNNConv)
+    print(io, "AGNNConv(", l.init_beta)
+    l.add_self_loops || print(io, ", add_self_loops=false")
+    l.trainable || print(io, ", trainable=false")
+    print(io, ")")
+end
+
+function (l::AGNNConv)(g, x::AbstractMatrix, ps, st)
+    β = l.trainable ? ps.β : l.init_beta
+    m = (; β, l.add_self_loops)
+    return GNNlib.agnn_conv(m, g, x), st
+end
+
+@concrete struct CGConv <: GNNContainerLayer{(:dense_f, :dense_s)}
+    in_dims::NTuple{2, Int}
+    out_dims::Int
+    dense_f
+    dense_s
+    residual::Bool
+    init_weight
+    init_bias
+end
+
+CGConv(ch::Pair{Int, Int}, args...; kws...) = CGConv((ch[1], 0) => ch[2], args...; kws...)
+
+function CGConv(ch::Pair{NTuple{2, Int}, Int}, act = identity; residual = false,
+                use_bias = true, init_weight = glorot_uniform, init_bias = zeros32,
+                allow_fast_activation = true)
+    (nin, ein), out = ch
+    dense_f = Dense(2nin + ein => out, sigmoid; use_bias, init_weight, init_bias, allow_fast_activation)
+    dense_s = Dense(2nin + ein => out, act; use_bias, init_weight, init_bias, allow_fast_activation)
+    return CGConv((nin, ein), out, dense_f, dense_s, residual, init_weight, init_bias)
+end
+
+LuxCore.outputsize(l::CGConv) = (l.out_dims,)
+
+(l::CGConv)(g, x, ps, st) = l(g, x, nothing, ps, st)
+
+function (l::CGConv)(g, x, e, ps, st)
+    dense_f = StatefulLuxLayer{true}(l.dense_f, ps.dense_f, _getstate(st, :dense_f))
+    dense_s = StatefulLuxLayer{true}(l.dense_s, ps.dense_s, _getstate(st, :dense_s))
+    m = (; dense_f, dense_s, l.residual)
+    return GNNlib.cg_conv(m, g, x, e), st
+end
+
+@concrete struct EdgeConv <: GNNContainerLayer{(:nn,)}
+    nn <: AbstractExplicitLayer
+    aggr
+end
+
+EdgeConv(nn; aggr = max) = EdgeConv(nn, aggr)
+
+function Base.show(io::IO, l::EdgeConv)
+    print(io, "EdgeConv(", l.nn)
+    print(io, ", aggr=", l.aggr)
+    print(io, ")")
+end
+
+function (l::EdgeConv)(g::AbstractGNNGraph, x, ps, st)
+    nn = StatefulLuxLayer{true}(l.nn, ps, st)
+    m = (; nn, l.aggr)
+    y = GNNlib.edge_conv(m, g, x)
+    stnew = _getstate(nn)
+    return y, stnew
+end
```
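The new layers follow the LuxCore convention: parameters and state live outside the layer struct, and each forward pass repackages them into a lightweight named tuple `m` that is handed to the functional kernels in GNNlib. A minimal forward-pass sketch for two of the added layers — graph size and feature dimensions are illustrative, not from the commit:

```julia
using Lux, GNNLux, GNNGraphs, Random

rng = Random.default_rng()
g = rand_graph(10, 40)             # 10 nodes, 40 edges (toy graph)
x = randn(Float32, 3, 10)          # 3 input features per node

# CGConv: concatenated [x_i; x_j] is gated by dense_f (sigmoid) and transformed by dense_s
l = CGConv(3 => 5, residual = false)
ps = Lux.initialparameters(rng, l)
st = Lux.initialstates(rng, l)
y, st = l(g, x, ps, st)            # explicit (ps, st), returns (output, state)
@assert size(y) == (5, 10)

# EdgeConv: applies the inner Lux network to [x_i; x_j - x_i] on each edge, then aggregates
nn = Chain(Dense(2 * 3 => 5, relu), Dense(5 => 5))
l2 = EdgeConv(nn; aggr = +)
ps2 = Lux.initialparameters(rng, l2)
st2 = Lux.initialstates(rng, l2)
y2, st2 = l2(g, x, ps2, st2)
@assert size(y2) == (5, 10)
```

Note the design choice visible in the diff: `EdgeConv` threads the inner network's state back out through `_getstate(nn)`, while the stateless layers simply return `st` unchanged.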

GNNLux/test/layers/conv_tests.jl

Lines changed: 44 additions & 1 deletion
```diff
@@ -19,7 +19,7 @@
 end

 @testset "ChebConv" begin
-    l = ChebConv(3 => 5, 2, relu)
+    l = ChebConv(3 => 5, 2)
     @test l isa GNNLayer
     ps = Lux.initialparameters(rng, l)
     st = Lux.initialstates(rng, l)
@@ -47,4 +47,47 @@
     loss = (x, ps) -> sum(first(l(g, x, ps, st)))
     @eval @test_gradients $loss $x $ps atol=1.0f-3 rtol=1.0f-3 skip_tracker=true
 end
+
+@testset "AGNNConv" begin
+    l = AGNNConv(init_beta = 1.0f0)
+    @test l isa GNNLayer
+    ps = Lux.initialparameters(rng, l)
+    st = Lux.initialstates(rng, l)
+    @test Lux.parameterlength(ps) == 1
+    @test Lux.parameterlength(l) == Lux.parameterlength(ps)
+    @test Lux.statelength(l) == Lux.statelength(st)
+
+    y, _ = l(g, x, ps, st)
+    @test size(y) == size(x)
+    loss = (x, ps) -> sum(first(l(g, x, ps, st)))
+    @eval @test_gradients $loss $x $ps atol=1.0f-3 rtol=1.0f-3 skip_tracker=true skip_reverse_diff=true
+end
+
+@testset "EdgeConv" begin
+    nn = Chain(Dense(6 => 5, relu), Dense(5 => 5))
+    l = EdgeConv(nn, aggr = +)
+    @test l isa GNNContainerLayer
+    ps = Lux.initialparameters(rng, l)
+    st = Lux.initialstates(rng, l)
+    @test Lux.parameterlength(l) == Lux.parameterlength(ps)
+    @test Lux.statelength(l) == Lux.statelength(st)
+    y, st′ = l(g, x, ps, st)
+    @test size(y) == (5, 10)
+    loss = (x, ps) -> sum(first(l(g, x, ps, st)))
+    @eval @test_gradients $loss $x $ps atol=1.0f-3 rtol=1.0f-3 skip_tracker=true skip_reverse_diff=true
+end
+
+@testset "CGConv" begin
+    l = CGConv(3 => 5, residual = true)
+    @test l isa GNNContainerLayer
+    ps = Lux.initialparameters(rng, l)
+    st = Lux.initialstates(rng, l)
+    @test Lux.parameterlength(l) == Lux.parameterlength(ps)
+    @test Lux.statelength(l) == Lux.statelength(st)
+    y, st′ = l(g, x, ps, st)
+    @test size(y) == (5, 10)
+    @test Lux.outputsize(l) == (5,)
+    loss = (x, ps) -> sum(first(l(g, x, ps, st)))
+    @eval @test_gradients $loss $x $ps atol=1.0f-3 rtol=1.0f-3 skip_tracker=true
+end
 end
```
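For context, these testsets rely on `rng`, `g`, and `x` fixtures defined earlier in the file and not shown in this diff. A plausible minimal version, reconstructed from the test bodies (the seed and edge count are illustrative; 10 nodes and 3 input features follow from `size(y) == (5, 10)` and `ChebConv(3 => 5, 2)`):

```julia
using GNNLux, GNNGraphs, Lux, LuxTestUtils, StableRNGs

rng = StableRNG(17)               # deterministic rng; seed is an assumption
g = rand_graph(10, 40)            # 10 nodes so layer outputs are (5, 10)
x = randn(rng, Float32, 3, 10)    # 3 input features per node
```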
