Skip to content

Commit 5715b26

Browse files
[GNNLux] updates for Lux v1.0 (#490)
* updates for Lux 1.0 * naming
1 parent bd5e2f2 commit 5715b26

File tree

7 files changed

+40
-38
lines changed

7 files changed

+40
-38
lines changed

GNNLux/Project.toml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,16 @@ LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
1212
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
1313
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1414
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
15+
Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"
1516
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1617

1718
[compat]
1819
ConcreteStructs = "0.2.3"
19-
Lux = "0.5.61"
20-
LuxCore = "0.1.20"
20+
Lux = "1.0"
21+
LuxCore = "1.0"
2122
NNlib = "0.9.21"
2223
Reexport = "1.2"
24+
Static = "1.1"
2325
julia = "1.10"
2426

2527
[extras]

GNNLux/src/GNNLux.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,15 @@ module GNNLux
22
using ConcreteStructs: @concrete
33
using NNlib: NNlib, sigmoid, relu, swish
44
using Statistics: mean
5-
using LuxCore: LuxCore, AbstractExplicitLayer, AbstractExplicitContainerLayer, parameterlength, statelength, outputsize,
5+
using LuxCore: LuxCore, AbstractLuxLayer, AbstractLuxContainerLayer, parameterlength, statelength, outputsize,
66
initialparameters, initialstates, parameterlength, statelength
77
using Lux: Lux, Chain, Dense, GRUCell,
88
glorot_uniform, zeros32,
99
StatefulLuxLayer
1010
using Reexport: @reexport
1111
using Random: AbstractRNG
1212
using GNNlib: GNNlib
13+
using Static
1314
@reexport using GNNGraphs
1415

1516
include("layers/basic.jl")

GNNLux/src/layers/basic.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
"""
2-
abstract type GNNLayer <: AbstractExplicitLayer end
2+
abstract type GNNLayer <: AbstractLuxLayer end
33
44
An abstract type from which graph neural network layers are derived.
5-
It is Derived from Lux's `AbstractExplicitLayer` type.
5+
It is Derived from Lux's `AbstractLuxLayer` type.
66
77
See also [`GNNChain`](@ref GNNLux.GNNChain).
88
"""
9-
abstract type GNNLayer <: AbstractExplicitLayer end
9+
abstract type GNNLayer <: AbstractLuxLayer end
1010

11-
abstract type GNNContainerLayer{T} <: AbstractExplicitContainerLayer{T} end
11+
abstract type GNNContainerLayer{T} <: AbstractLuxContainerLayer{T} end
1212

1313
@concrete struct GNNChain <: GNNContainerLayer{(:layers,)}
1414
layers <: NamedTuple
@@ -24,7 +24,7 @@ function GNNChain(; kw...)
2424
return GNNChain(nt)
2525
end
2626

27-
_wrapforchain(l::AbstractExplicitLayer) = l
27+
_wrapforchain(l::AbstractLuxLayer) = l
2828
_wrapforchain(l) = Lux.WrappedFunction(l)
2929

3030
Base.keys(c::GNNChain) = Base.keys(getfield(c, :layers))
@@ -44,7 +44,7 @@ Base.firstindex(c::GNNChain) = firstindex(c.layers)
4444

4545
LuxCore.outputsize(c::GNNChain) = LuxCore.outputsize(c.layers[end])
4646

47-
(c::GNNChain)(g::GNNGraph, x, ps, st) = _applychain(c.layers, g, x, ps, st)
47+
(c::GNNChain)(g::GNNGraph, x, ps, st) = _applychain(c.layers, g, x, ps.layers, st.layers)
4848

4949
function _applychain(layers, g::GNNGraph, x, ps, st) # type-unstable path, helps compile times
5050
newst = (;)
@@ -56,6 +56,6 @@ function _applychain(layers, g::GNNGraph, x, ps, st) # type-unstable path, help
5656
end
5757

5858
_applylayer(l, g::GNNGraph, x, ps, st) = l(x), (;)
59-
_applylayer(l::AbstractExplicitLayer, g::GNNGraph, x, ps, st) = l(x, ps, st)
59+
_applylayer(l::AbstractLuxLayer, g::GNNGraph, x, ps, st) = l(x, ps, st)
6060
_applylayer(l::GNNLayer, g::GNNGraph, x, ps, st) = l(g, x, ps, st)
6161
_applylayer(l::GNNContainerLayer, g::GNNGraph, x, ps, st) = l(g, x, ps, st)

GNNLux/src/layers/conv.jl

Lines changed: 16 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
_getbias(ps) = hasproperty(ps, :bias) ? getproperty(ps, :bias) : false
22
_getstate(st, name) = hasproperty(st, name) ? getproperty(st, name) : NamedTuple()
33
_getstate(s::StatefulLuxLayer{true}) = s.st
4+
_getstate(s::StatefulLuxLayer{Static.True}) = s.st
45
_getstate(s::StatefulLuxLayer{false}) = s.st_any
6+
_getstate(s::StatefulLuxLayer{Static.False}) = s.st_any
57

68

79
@concrete struct GCNConv <: GNNLayer
@@ -20,10 +22,9 @@ function GCNConv(ch::Pair{Int, Int}, σ = identity;
2022
init_bias = zeros32,
2123
use_bias::Bool = true,
2224
add_self_loops::Bool = true,
23-
use_edge_weight::Bool = false,
24-
allow_fast_activation::Bool = true)
25+
use_edge_weight::Bool = false)
2526
in_dims, out_dims = ch
26-
σ = allow_fast_activation ? NNlib.fast_act(σ) : σ
27+
σ = NNlib.fast_act(σ)
2728
return GCNConv(in_dims, out_dims, use_bias, add_self_loops, use_edge_weight, init_weight, init_bias, σ)
2829
end
2930

@@ -121,10 +122,9 @@ function GraphConv(ch::Pair{Int, Int}, σ = identity;
121122
aggr = +,
122123
init_weight = glorot_uniform,
123124
init_bias = zeros32,
124-
use_bias::Bool = true,
125-
allow_fast_activation::Bool = true)
125+
use_bias::Bool = true)
126126
in_dims, out_dims = ch
127-
σ = allow_fast_activation ? NNlib.fast_act(σ) : σ
127+
σ = NNlib.fast_act(σ)
128128
return GraphConv(in_dims, out_dims, use_bias, init_weight, init_bias, σ, aggr)
129129
end
130130

@@ -212,11 +212,10 @@ end
212212
CGConv(ch::Pair{Int, Int}, args...; kws...) = CGConv((ch[1], 0) => ch[2], args...; kws...)
213213

214214
function CGConv(ch::Pair{NTuple{2, Int}, Int}, act = identity; residual = false,
215-
use_bias = true, init_weight = glorot_uniform, init_bias = zeros32,
216-
allow_fast_activation = true)
215+
use_bias = true, init_weight = glorot_uniform, init_bias = zeros32)
217216
(nin, ein), out = ch
218-
dense_f = Dense(2nin + ein => out, sigmoid; use_bias, init_weight, init_bias, allow_fast_activation)
219-
dense_s = Dense(2nin + ein => out, act; use_bias, init_weight, init_bias, allow_fast_activation)
217+
dense_f = Dense(2nin + ein => out, sigmoid; use_bias, init_weight, init_bias)
218+
dense_s = Dense(2nin + ein => out, act; use_bias, init_weight, init_bias)
220219
return CGConv((nin, ein), out, dense_f, dense_s, residual, init_weight, init_bias)
221220
end
222221

@@ -232,7 +231,7 @@ function (l::CGConv)(g, x, e, ps, st)
232231
end
233232

234233
@concrete struct EdgeConv <: GNNContainerLayer{(:nn,)}
235-
nn <: AbstractExplicitLayer
234+
nn <: AbstractLuxLayer
236235
aggr
237236
end
238237

@@ -246,10 +245,10 @@ end
246245

247246

248247
function (l::EdgeConv)(g::AbstractGNNGraph, x, ps, st)
249-
nn = StatefulLuxLayer{true}(l.nn, ps, st)
248+
nn = StatefulLuxLayer{true}(l.nn, ps.nn, st.nn)
250249
m = (; nn, l.aggr)
251250
y = GNNlib.edge_conv(m, g, x)
252-
stnew = _getstate(nn)
251+
stnew = (; nn = _getstate(nn)) # TODO: support also aggr state if present
253252
return y, stnew
254253
end
255254

@@ -608,18 +607,18 @@ function Base.show(io::IO, l::GatedGraphConv)
608607
end
609608

610609
@concrete struct GINConv <: GNNContainerLayer{(:nn,)}
611-
nn <: AbstractExplicitLayer
610+
nn <: AbstractLuxLayer
612611
ϵ <: Real
613612
aggr
614613
end
615614

616615
GINConv(nn, ϵ; aggr = +) = GINConv(nn, ϵ, aggr)
617616

618617
function (l::GINConv)(g, x, ps, st)
619-
nn = StatefulLuxLayer{true}(l.nn, ps, st)
618+
nn = StatefulLuxLayer{true}(l.nn, ps.nn, st.nn)
620619
m = (; nn, l.ϵ, l.aggr)
621620
y = GNNlib.gin_conv(m, g, x)
622-
stnew = _getstate(nn)
621+
stnew = (; nn = _getstate(nn))
623622
return y, stnew
624623
end
625624

@@ -669,4 +668,4 @@ function Base.show(io::IO, l::MEGNetConv)
669668
nout = l.out_dims
670669
print(io, "MEGNetConv(", nin, " => ", nout)
671670
print(io, ")")
672-
end
671+
end

GNNLux/src/layers/temporalconv.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
@concrete struct StatefulRecurrentCell <: AbstractExplicitContainerLayer{(:cell,)}
1+
@concrete struct StatefulRecurrentCell <: AbstractLuxContainerLayer{(:cell,)}
22
cell <: Union{<:Lux.AbstractRecurrentCell, <:GNNContainerLayer}
33
end
44

@@ -7,16 +7,16 @@ function LuxCore.initialstates(rng::AbstractRNG, r::GNNLux.StatefulRecurrentCell
77
end
88

99
function (r::StatefulRecurrentCell)(g, x::AbstractMatrix, ps, st::NamedTuple)
10-
(out, carry), st = applyrecurrentcell(r.cell, g, x, ps, st.cell, st.carry)
10+
(out, carry), st = applyrecurrentcell(r.cell, g, x, ps.cell, st.cell, st.carry)
1111
return out, (; cell=st, carry)
1212
end
1313

1414
function (r::StatefulRecurrentCell)(g, x::AbstractVector, ps, st::NamedTuple)
15-
st, carry = st.cell, st.carry
15+
stcell, carry = st.cell, st.carry
1616
for xᵢ in x
17-
(out, carry), st = applyrecurrentcell(r.cell, g, xᵢ, ps, st, carry)
17+
(out, carry), stcell = applyrecurrentcell(r.cell, g, xᵢ, ps.cell, stcell, carry)
1818
end
19-
return out, (; cell=st, carry)
19+
return out, (; cell=stcell, carry)
2020
end
2121

2222
function applyrecurrentcell(l, g, x, ps, st, carry)
@@ -35,7 +35,7 @@ end
3535

3636
function TGCNCell(ch::Pair{Int, Int}; use_bias = true, init_weight = glorot_uniform, init_state = zeros32, init_bias = zeros32, add_self_loops = false, use_edge_weight = true)
3737
in_dims, out_dims = ch
38-
conv = GCNConv(ch, sigmoid; init_weight, init_bias, use_bias, add_self_loops, use_edge_weight, allow_fast_activation= true)
38+
conv = GCNConv(ch, sigmoid; init_weight, init_bias, use_bias, add_self_loops, use_edge_weight)
3939
gru = Lux.GRUCell(out_dims => out_dims; use_bias, init_weight = (init_weight, init_weight, init_weight), init_bias = (init_bias, init_bias, init_bias), init_state = init_state)
4040
return TGCNCell(in_dims, out_dims, conv, gru, init_state)
4141
end

GNNLux/test/layers/basic_tests.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,16 @@
44
x = randn(rng, Float32, 3, 10)
55

66
@testset "GNNLayer" begin
7-
@test GNNLayer <: LuxCore.AbstractExplicitLayer
7+
@test GNNLayer <: LuxCore.AbstractLuxLayer
88
end
99

1010
@testset "GNNContainerLayer" begin
11-
@test GNNContainerLayer <: LuxCore.AbstractExplicitContainerLayer
11+
@test GNNContainerLayer <: LuxCore.AbstractLuxContainerLayer
1212
end
1313

1414
@testset "GNNChain" begin
15-
@test GNNChain <: LuxCore.AbstractExplicitContainerLayer{(:layers,)}
16-
c = GNNChain(GraphConv(3 => 5, relu), GCNConv(5 => 3))
15+
@test GNNChain <: LuxCore.AbstractLuxContainerLayer{(:layers,)}
16+
c = GNNChain(GraphConv(3 => 5, tanh), GCNConv(5 => 3))
1717
test_lux_layer(rng, c, g, x, outputsize=(3,), container=true)
1818
end
1919
end

GNNLux/test/layers/conv_tests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@
8989
end
9090

9191
@testset "GINConv" begin
92-
nn = Chain(Dense(in_dims => out_dims, relu), Dense(out_dims => out_dims))
92+
nn = Chain(Dense(in_dims => out_dims, tanh), Dense(out_dims => out_dims))
9393
l = GINConv(nn, 0.5)
9494
test_lux_layer(rng, l, g, x, sizey=(out_dims,g.num_nodes), container=true)
9595
end

0 commit comments

Comments
 (0)