Skip to content

Commit 2b303b8

Browse files
fix bug in LightGraphs constructor
1 parent 3688442 commit 2b303b8

File tree

2 files changed

+19
-2
lines changed

2 files changed

+19
-2
lines changed

src/gnngraph.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,10 @@ GNNGraph((s, t)::NTuple{2}; kws...) = GNNGraph((s, t, nothing); kws...)
151151
function GNNGraph(g::AbstractGraph; kws...)
152152
s = LightGraphs.src.(LightGraphs.edges(g))
153153
t = LightGraphs.dst.(LightGraphs.edges(g))
154+
if !LightGraphs.is_directed(g)
155+
# add reverse edges since GNNGraph are directed
156+
s, t = [s; t], [t; s]
157+
end
154158
GNNGraph((s, t); num_nodes = LightGraphs.nv(g), kws...)
155159
end
156160

test/gnngraph.jl

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,19 @@
8484
end
8585
end
8686

87+
@testset "LightGraphs constructor" begin
88+
lg = random_regular_graph(10, 4)
89+
@test !LightGraphs.is_directed(lg)
90+
g = GNNGraph(lg)
91+
@test g.num_edges == 2*ne(lg) # g in undirected
92+
@test LightGraphs.is_directed(g)
93+
for e in LightGraphs.edges(lg)
94+
i, j = src(e), dst(e)
95+
@test has_edge(g, i, j)
96+
@test has_edge(g, j, i)
97+
end
98+
end
99+
87100
@testset "add self-loops" begin
88101
A = [1 1 0 0
89102
0 0 1 0
@@ -174,9 +187,9 @@
174187
@testset "LearnBase and DataLoader compat" begin
175188
n, m, num_graphs = 10, 30, 50
176189
X = rand(10, n)
177-
E = rand(10, m)
190+
E = rand(10, 2m)
178191
U = rand(10, 1)
179-
g = Flux.batch([GNNGraph(erdos_renyi(10, 30), ndata=rand(10, n), edata=rand(10, m), gdata=rand(10, 1))
192+
g = Flux.batch([GNNGraph(erdos_renyi(n, m), ndata=X, edata=E, gdata=U)
180193
for _ in 1:num_graphs])
181194

182195
@test LearnBase.getobs(g, 3) == getgraph(g, 3)[1]

0 commit comments

Comments
 (0)