Skip to content

Commit e2b3f2c

Browse files
fix
1 parent b9e8571 commit e2b3f2c

File tree

4 files changed

+31
-11
lines changed

4 files changed

+31
-11
lines changed

src/gnngraph.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -149,14 +149,14 @@ GNNGraph((s, t)::NTuple{2}; kws...) = GNNGraph((s, t, nothing); kws...)
149149

150150
# GNNGraph(g::AbstractGraph; kws...) = GNNGraph(adjacency_matrix(g, dir=:out); kws...)
151151

152-
function GNNGraph(g::AbstractGraph; edata=(;), kws...)
152+
function GNNGraph(g::AbstractGraph; kws...)
153153
s = LightGraphs.src.(LightGraphs.edges(g))
154154
t = LightGraphs.dst.(LightGraphs.edges(g))
155155
if !LightGraphs.is_directed(g)
156156
# add reverse edges since GNNGraph are directed
157157
s, t = [s; t], [t; s]
158158
end
159-
GNNGraph((s, t); edata, num_nodes=LightGraphs.nv(g), kws...)
159+
GNNGraph((s, t); num_nodes=LightGraphs.nv(g), kws...)
160160
end
161161

162162

src/utils.jl

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,23 +29,41 @@ normalize_graphdata(data::Nothing; kws...) = NamedTuple()
2929
normalize_graphdata(data; default_name::Symbol, kws...) =
3030
normalize_graphdata(NamedTuple{(default_name,)}((data,)); default_name, kws...)
3131

32-
function normalize_graphdata(data::NamedTuple; default_name=:z, n, duplicate_if_needed=false)
32+
function normalize_graphdata(data::NamedTuple; default_name, n, duplicate_if_needed=false)
33+
# This had to workaround two Zygote bugs with NamedTuples
34+
# https://github.com/FluxML/Zygote.jl/issues/1071
35+
# https://github.com/FluxML/Zygote.jl/issues/1072
36+
37+
if n == 1
38+
# If last array dimension is not 1, add a new dimension.
39+
# This is mostly usefule to reshape globale feature vectors
40+
# of size D to Dx1 matrices.
41+
function unsqz(v)
42+
if v isa AbstractArray && size(v)[end] != 1
43+
v = reshape(v, size(v)..., 1)
44+
end
45+
v
46+
end
47+
48+
data = NamedTuple{keys(data)}(unsqz.(values(data)))
49+
end
50+
3351
sz = map(x -> x isa AbstractArray ? size(x)[end] : 0, data)
3452

35-
if duplicate_if_needed # used to copy edge features on reverse edges
53+
if duplicate_if_needed
54+
# Used to copy edge features on reverse edges
3655
@assert all(s -> s == 0 || s == n || s == n÷2, sz)
3756

38-
function replace(k, v)
57+
function duplicate(v)
3958
if v isa AbstractArray && size(v)[end] == n÷2
4059
v = cat(v, v, dims=ndims(v))
4160
end
42-
k => v
61+
v
4362
end
4463

45-
data = NamedTuple(replace(k,v) for (k,v) in pairs(data))
64+
data = NamedTuple{keys(data)}(duplicate.(values(data)))
4665
else
4766
@assert all(s -> s == 0 || s == n, sz)
4867
end
4968
return data
5069
end
51-

test/layers/conv.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,10 @@
3434
for g in test_graphs
3535
test_layer(l, g, rtol=1e-5)
3636
end
37-
end
3837

38+
l = GCNConv(in_channel => out_channel, add_self_loops=false)
39+
test_layer(l, g1, rtol=1e-5)
40+
end
3941

4042
@testset "ChebConv" begin
4143
k = 6

test/msgpass.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ import GraphNeuralNetworks: compute_message, update_node, update_edge, propagate
101101
@test all(adjacency_matrix(g_) .== adj)
102102
@test size(node_features(g_)) == (2*out_channel, num_V)
103103
@test size(edge_features(g_)) == (out_channel, num_E)
104-
@test size(graph_features(g_)) == (in_channel,)
104+
@test size(graph_features(g_)) == (in_channel, 1)
105105
end
106106

107107
@testset "message and update with weights" begin
@@ -124,7 +124,7 @@ import GraphNeuralNetworks: compute_message, update_node, update_edge, propagate
124124
@test adjacency_matrix(g_) == adj
125125
@test size(node_features(g_)) == (out_channel, num_V)
126126
@test edge_features(g_) === E
127-
@test graph_features(g_) === U
127+
@test vec(graph_features(g_)) U
128128
end
129129

130130
@testset "NamedTuples" begin

0 commit comments

Comments
 (0)