@@ -31,9 +31,14 @@ julia> using Lux, GNNLux, Random
3131
3232julia> rng = Random.default_rng();
3333
34- julia> m = GNNChain(GCNConv(2=>5),
35- x -> relu.(x),
36- Dense(5=>4))
34+ julia> m = GNNChain(GCNConv(2 => 5, relu), Dense(5 => 4))
35+ GNNChain(
36+ layers = NamedTuple(
37+ layer_1 = GCNConv(2 => 5, relu), # 15 parameters
38+ layer_2 = Dense(5 => 4), # 24 parameters
39+ ),
40+ ) # Total: 39 parameters,
41+ # plus 0 states.
3742
3843julia> x = randn(rng, Float32, 2, 3);
3944
@@ -44,8 +49,10 @@ GNNGraph:
4449
4550julia> ps, st = LuxCore.setup(rng, m);
4651
47- julia> m(g, x, ps, st) # First entry is the output, second entry is the state of the model
48- (Float32[-0.15594329 -0.15594329 -0.15594329; 0.93431795 0.93431795 0.93431795; 0.27568763 0.27568763 0.27568763; 0.12568939 0.12568939 0.12568939], (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple()))
52+ julia> y, st = m(g, x, ps, st); # First entry is the output, second entry is the state of the model
53+
54+ julia> size(y)
55+ (4, 3)
4956```
5057"""
5158@concrete struct GNNChain <: GNNContainerLayer{(:layers,)}
0 commit comments