Skip to content

Commit d781330

Browse files
fix GraphNeuralNetworks.jl
1 parent 50d51ef commit d781330

File tree

4 files changed

+38
-36
lines changed

4 files changed

+38
-36
lines changed

GraphNeuralNetworks/docs/make.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ using GraphNeuralNetworks
1212
using Flux, GNNGraphs, GNNlib, Graphs
1313
using DocumenterInterLinks
1414

15+
DocMeta.setdocmeta!(GNNGraphs, :DocTestSetup, :(using GNNGraphs, MLUtils); recursive = true)
1516
DocMeta.setdocmeta!(GraphNeuralNetworks, :DocTestSetup, :(using GraphNeuralNetworks); recursive = true)
1617

1718
mathengine = MathJax3(Dict(:loader => Dict("load" => ["[tex]/require", "[tex]/mathtools"]),

GraphNeuralNetworks/src/layers/basic.jl

Lines changed: 12 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -74,30 +74,22 @@ julia> using Flux, GraphNeuralNetworks
7474
julia> m = GNNChain(GCNConv(2=>5),
7575
BatchNorm(5),
7676
x -> relu.(x),
77-
Dense(5, 4))
78-
GNNChain(GCNConv(2 => 5), BatchNorm(5), #7, Dense(5 => 4))
77+
Dense(5, 4));
7978
8079
julia> x = randn(Float32, 2, 3);
8180
8281
julia> g = rand_graph(3, 6)
8382
GNNGraph:
84-
num_nodes = 3
85-
num_edges = 6
83+
num_nodes: 3
84+
num_edges: 6
8685
87-
julia> m(g, x)
88-
4×3 Matrix{Float32}:
89-
-0.795592 -0.795592 -0.795592
90-
-0.736409 -0.736409 -0.736409
91-
0.994925 0.994925 0.994925
92-
0.857549 0.857549 0.857549
86+
julia> m(g, x) |> size
87+
(4, 3)
9388
94-
julia> m2 = GNNChain(enc = m,
95-
dec = DotDecoder())
96-
GNNChain(enc = GNNChain(GCNConv(2 => 5), BatchNorm(5), #7, Dense(5 => 4)), dec = DotDecoder())
89+
julia> m2 = GNNChain(enc = m, dec = DotDecoder());
9790
98-
julia> m2(g, x)
99-
1×6 Matrix{Float32}:
100-
2.90053 2.90053 2.90053 2.90053 2.90053 2.90053
91+
julia> m2(g, x) |> size
92+
(1, 6)
10193
10294
julia> m2[:enc](g, x) == m(g, x)
10395
true
@@ -196,15 +188,14 @@ returns the dot product `x_i ⋅ xj` on each edge.
196188
```jldoctest
197189
julia> g = rand_graph(5, 6)
198190
GNNGraph:
199-
num_nodes = 5
200-
num_edges = 6
191+
num_nodes: 5
192+
num_edges: 6
201193
202194
julia> dotdec = DotDecoder()
203195
DotDecoder()
204196
205-
julia> dotdec(g, rand(2, 5))
206-
1×6 Matrix{Float64}:
207-
0.345098 0.458305 0.106353 0.345098 0.458305 0.106353
197+
julia> dotdec(g, rand(2, 5)) |> size
198+
(1, 6)
208199
```
209200
"""
210201
struct DotDecoder <: GNNLayer end

GraphNeuralNetworks/src/layers/heteroconv.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,12 @@ have to be aggregated using the `aggr` function. The default is to sum the outpu
2121
# Examples
2222
2323
```jldoctest
24-
julia> g = rand_bipartite_heterograph((10, 15), 20)
24+
julia> using GraphNeuralNetworks, Flux
25+
26+
julia> g = rand_bipartite_heterograph((10, 15), 80)
2527
GNNHeteroGraph:
2628
num_nodes: Dict(:A => 10, :B => 15)
27-
num_edges: Dict((:A, :to, :B) => 20, (:B, :to, :A) => 20)
29+
num_edges: Dict((:A, :to, :B) => 80, (:B, :to, :A) => 80)
2830
2931
julia> x = (A = rand(Float32, 64, 10), B = rand(Float32, 64, 15));
3032

GraphNeuralNetworks/src/layers/temporalconv.jl

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -59,16 +59,13 @@ Returns the updated node features:
5959
The following example considers a static graph and a time-varying node features.
6060
6161
```jldoctest
62-
julia> num_nodes, num_edges = 5, 10;
62+
julia> num_nodes, num_edges = 5, 16;
6363
6464
julia> d_in, d_out = 2, 3;
6565
6666
julia> timesteps = 5;
6767
6868
julia> g = rand_graph(num_nodes, num_edges);
69-
GNNGraph:
70-
num_nodes: 5
71-
num_edges: 10
7269
7370
julia> x = rand(Float32, d_in, timesteps, num_nodes);
7471
@@ -93,11 +90,15 @@ julia> timesteps = 5;
9390
9491
julia> num_nodes = [10, 10, 10, 10, 10];
9592
96-
julia> num_edges = [10, 12, 14, 16, 18];
93+
julia> num_edges = [20, 22, 24, 26, 28];
9794
9895
julia> snapshots = [rand_graph(n, m) for (n, m) in zip(num_nodes, num_edges)];
9996
10097
julia> tg = TemporalSnapshotsGNNGraph(snapshots)
98+
TemporalSnapshotsGNNGraph:
99+
num_nodes: [10, 10, 10, 10, 10]
100+
num_edges: [20, 22, 24, 26, 28]
101+
num_snapshots: 5
101102
102103
julia> x = [rand(Float32, d_in, n) for n in num_nodes];
103104
@@ -269,7 +270,7 @@ See [`GNNRecurrence`](@ref) for more details.
269270
# Examples
270271
271272
```jldoctest
272-
julia> num_nodes, num_edges = 5, 10;
273+
julia> num_nodes, num_edges = 5, 16;
273274
274275
julia> d_in, d_out = 2, 3;
275276
@@ -280,7 +281,7 @@ julia> g = rand_graph(num_nodes, num_edges);
280281
julia> x = rand(Float32, d_in, timesteps, num_nodes);
281282
282283
julia> layer = GConvGRU(d_in => d_out, 2)
283-
GConvGRU(
284+
GNNRecurrence(
284285
GConvGRUCell(2 => 3, 2), # 108 parameters
285286
) # Total: 12 arrays, 108 parameters, 1.148 KiB.
286287
@@ -326,9 +327,9 @@ where `output` is the updated hidden state `h` of the LSTM cell and `state` is t
326327
# Examples
327328
328329
```jldoctest
329-
julia> using GraphNeuralNetworks, Flux
330+
julia> using Flux
330331
331-
julia> num_nodes, num_edges = 5, 10;
332+
julia> num_nodes, num_edges = 5, 16;
332333
333334
julia> d_in, d_out = 2, 3;
334335
@@ -453,7 +454,7 @@ See [`GNNRecurrence`](@ref) for more details.
453454
# Examples
454455
455456
```jldoctest
456-
julia> num_nodes, num_edges = 5, 10;
457+
julia> num_nodes, num_edges = 5, 16;
457458
458459
julia> d_in, d_out = 2, 3;
459460
@@ -727,23 +728,27 @@ julia> timesteps = 5;
727728
728729
julia> num_nodes = [10, 10, 10, 10, 10];
729730
730-
julia> num_edges = [10, 12, 14, 16, 18];
731+
julia> num_edges = [60, 62, 64, 66, 68];
731732
732733
julia> snapshots = [rand_graph(n, m) for (n, m) in zip(num_nodes, num_edges)];
733734
734735
julia> tg = TemporalSnapshotsGNNGraph(snapshots)
736+
TemporalSnapshotsGNNGraph:
737+
num_nodes: [10, 10, 10, 10, 10]
738+
num_edges: [60, 62, 64, 66, 68]
739+
num_snapshots: 5
735740
736741
julia> x = [rand(Float32, d_in, n) for n in num_nodes];
737742
738-
julia> cell = EvolveGCNO(d_in => d_out)
743+
julia> layer = EvolveGCNO(d_in => d_out)
739744
GNNRecurrence(
740745
EvolveGCNOCell(2 => 3), # 321 parameters
741746
) # Total: 5 arrays, 321 parameters, 1.535 KiB.
742747
743748
julia> y = layer(tg, x);
744749
745750
julia> length(y) # timesteps
746-
5
751+
5
747752
748753
julia> size(y[end]) # (d_out, num_nodes[end])
749754
(3, 10)
@@ -874,6 +879,9 @@ julia> g = rand_graph(num_nodes, num_edges);
874879
julia> x = rand(Float32, d_in, timesteps, num_nodes);
875880
876881
julia> layer = TGCN(d_in => d_out)
882+
GNNRecurrence(
883+
TGCNCell(2 => 3), # 126 parameters
884+
) # Total: 18 arrays, 126 parameters, 1.469 KiB.
877885
878886
julia> y = layer(g, x);
879887

0 commit comments

Comments
 (0)