Skip to content

Commit 97c75be

Browse files
edges iterates over Graphs.Edge (#171)
* edges iterates over Graphs.Edge * add test for zero * fix test * fix test * fix io
1 parent 1c261d4 commit 97c75be

File tree

5 files changed

+25
-8
lines changed

5 files changed

+25
-8
lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
1010
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
1111
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
1212
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
13+
GLMakie = "e9467ef8-e4e7-5192-8a1a-b1aee30e663a"
1314
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
1415
KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77"
1516
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

src/GNNGraphs/gnngraph.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,8 @@ function (::Type{<:GNNGraph})(num_nodes::T; kws...) where {T<:Integer}
155155
return GNNGraph(s, t; num_nodes, kws...)
156156
end
157157

158+
Base.zero(::Type{G}) where G<:GNNGraph = G(0)
159+
158160
# COO convenience constructors
159161
GNNGraph(s::AbstractVector, t::AbstractVector, v = nothing; kws...) = GNNGraph((s, t, v); kws...)
160162
GNNGraph((s, t)::NTuple{2}; kws...) = GNNGraph((s, t, nothing); kws...)
@@ -209,7 +211,7 @@ function Base.show(io::IO, ::MIME"text/plain", g::GNNGraph)
209211
print(io, "GNNGraph:
210212
num_nodes = $(g.num_nodes)
211213
num_edges = $(g.num_edges)")
212-
g.num_graphs > 1 && print("\n num_graphs = $(g.num_graphs)")
214+
g.num_graphs > 1 && print(io, "\n num_graphs = $(g.num_graphs)")
213215
if !isempty(g.ndata)
214216
print(io, "\n ndata:")
215217
for k in keys(g.ndata)

src/GNNGraphs/query.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,9 @@ get_edge_weight(g::GNNGraph{<:COO_T}) = g.graph[3]
1717

1818
get_edge_weight(g::GNNGraph{<:ADJMAT_T}) = to_coo(g.graph, num_nodes=g.num_nodes)[1][3]
1919

20-
Graphs.edges(g::GNNGraph) = zip(edge_index(g)...)
20+
Graphs.edges(g::GNNGraph) = Graphs.Edge.(edge_index(g)...)
2121

22-
Graphs.edgetype(g::GNNGraph) = Tuple{Int, Int}
22+
Graphs.edgetype(g::GNNGraph) = Graphs.Edge{eltype(g)}
2323

2424
# """
2525
# eltype(g::GNNGraph)
@@ -42,9 +42,9 @@ end
4242

4343
Graphs.has_edge(g::GNNGraph{<:ADJMAT_T}, i::Integer, j::Integer) = g.graph[i,j] != 0
4444

45-
graph_type_symbol(g::GNNGraph{<:COO_T}) = :coo
46-
graph_type_symbol(g::GNNGraph{<:SPARSE_T}) = :sparse
47-
graph_type_symbol(g::GNNGraph{<:ADJMAT_T}) = :dense
45+
graph_type_symbol(::GNNGraph{<:COO_T}) = :coo
46+
graph_type_symbol(::GNNGraph{<:SPARSE_T}) = :sparse
47+
graph_type_symbol(::GNNGraph{<:ADJMAT_T}) = :dense
4848

4949
Graphs.nv(g::GNNGraph) = g.num_nodes
5050
Graphs.ne(g::GNNGraph) = g.num_edges

test/GNNGraphs/gnngraph.jl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@
6363
@test g.num_nodes == 4
6464
@test nv(g) == g.num_nodes
6565
@test ne(g) == g.num_edges
66-
@test collect(edges(g)) |> sort == collect(zip(s, t)) |> sort
66+
@test Tuple.(collect(edges(g))) |> sort == collect(zip(s, t)) |> sort
6767
@test sort(outneighbors(g, 1)) == [2, 4]
6868
@test sort(inneighbors(g, 1)) == [2, 4]
6969
@test is_directed(g) == true
@@ -150,7 +150,7 @@
150150

151151
@test g.num_edges == 4
152152
@test g.num_nodes == 4
153-
@test collect(edges(g)) |> sort == collect(zip(s, t)) |> sort
153+
@test length(edges(g)) == 4
154154
@test sort(outneighbors(g, 1)) == [2]
155155
@test sort(inneighbors(g, 1)) == [4]
156156
@test is_directed(g) == true
@@ -168,6 +168,12 @@
168168
@test adjacency_list(g, dir=:in) == adj_list_in
169169
end
170170

171+
@testset "zero" begin
172+
g = rand_graph(4, 6, graph_type=GRAPH_T)
173+
G = typeof(g)
174+
@test zero(G) == G(0)
175+
end
176+
171177
@testset "Graphs.jl constructor" begin
172178
lg = random_regular_graph(10, 4)
173179
@test !Graphs.is_directed(lg)

test/GNNGraphs/query.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,14 @@
2121
end
2222
end
2323

24+
@testset "edges" begin
25+
g = rand_graph(4, 10, graph_type=GRAPH_T)
26+
@test edgetype(g) <: Graphs.Edge
27+
for e in edges(g)
28+
@test e isa Graphs.Edge
29+
end
30+
end
31+
2432
@testset "has_self_loops" begin
2533
s = [1, 1, 2, 3]
2634
t = [2, 2, 2, 4]

0 commit comments

Comments
 (0)