Skip to content

Commit 87c062b

Browse files
fix tests
1 parent 1f80cb6 commit 87c062b

File tree

4 files changed

+29
-27
lines changed

4 files changed

+29
-27
lines changed

GNNGraphs/src/convert.jl

Lines changed: 17 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -4,27 +4,24 @@ function to_coo(data::EDict; num_nodes = nothing, kws...)
44
graph = EDict{COO_T}()
55
_num_nodes = NDict{Int}()
66
num_edges = EDict{Int}()
7-
if !isempty(data)
8-
for k in keys(data)
9-
d = data[k]
10-
@assert d isa Tuple
11-
if length(d) == 2
12-
d = (d..., nothing)
13-
end
14-
if num_nodes !== nothing
15-
n1 = get(num_nodes, k[1], nothing)
16-
n2 = get(num_nodes, k[3], nothing)
17-
else
18-
n1 = nothing
19-
n2 = nothing
20-
end
21-
g, nnodes, nedges = to_coo(d; hetero = true, num_nodes = (n1, n2), kws...)
22-
graph[k] = g
23-
num_edges[k] = nedges
24-
_num_nodes[k[1]] = max(get(_num_nodes, k[1], 0), nnodes[1])
25-
_num_nodes[k[3]] = max(get(_num_nodes, k[3], 0), nnodes[2])
7+
for k in keys(data)
8+
d = data[k]
9+
@assert d isa Tuple
10+
if length(d) == 2
11+
d = (d..., nothing)
2612
end
27-
graph = Dict([k => v for (k, v) in pairs(graph)]...) # try to restrict the key/value types
13+
if num_nodes !== nothing
14+
n1 = get(num_nodes, k[1], nothing)
15+
n2 = get(num_nodes, k[3], nothing)
16+
else
17+
n1 = nothing
18+
n2 = nothing
19+
end
20+
g, nnodes, nedges = to_coo(d; hetero = true, num_nodes = (n1, n2), kws...)
21+
graph[k] = g
22+
num_edges[k] = nedges
23+
_num_nodes[k[1]] = max(get(_num_nodes, k[1], 0), nnodes[1])
24+
_num_nodes[k[3]] = max(get(_num_nodes, k[3], 0), nnodes[2])
2825
end
2926
return graph, _num_nodes, num_edges
3027
end

GNNGraphs/src/generate.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ function rand_graph(rng::AbstractRNG, n::Integer, m::Integer; bidirected = true,
5959
else
6060
s, t, _ = _rand_edges(rng, n, m; directed=true, self_loops=false)
6161
end
62-
return GNNGraph((s, t, edge_weight); kws...)
62+
return GNNGraph((s, t, edge_weight); num_nodes=n, kws...)
6363
end
6464

6565
"""

GNNGraphs/src/transform.jl

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,8 @@ then all new self loops will have no weight.
5757
If `edge_t` is not passed as argument, for the entire graph self-loop is added to each node for every edge type in the graph where the source and destination node types are the same.
5858
This iterates over all edge types present in the graph, applying the self-loop addition logic to each applicable edge type.
5959
"""
60-
function add_self_loops(g::GNNHeteroGraph{Tuple{T, T, V}}, edge_t::EType) where {T <: AbstractVector{<:Integer}, V}
60+
function add_self_loops(g::GNNHeteroGraph{<:COO_T}, edge_t::EType)
61+
6162
function get_edge_weight_nullable(g::GNNHeteroGraph{<:COO_T}, edge_t::EType)
6263
get(g.graph, edge_t, (nothing, nothing, nothing))[3]
6364
end
@@ -69,13 +70,17 @@ function add_self_loops(g::GNNHeteroGraph{Tuple{T, T, V}}, edge_t::EType) where
6970
n = get(g.num_nodes, src_t, 0)
7071

7172
if haskey(g.graph, edge_t)
72-
x = g.graph[edge_t]
73-
s, t = x[1:2]
73+
s, t = g.graph[edge_t][1:2]
7474
nodes = convert(typeof(s), [1:n;])
7575
s = [s; nodes]
7676
t = [t; nodes]
7777
else
78-
nodes = convert(T, [1:n;])
78+
if !isempty(g.graph)
79+
T = typeof(first(values(g.graph))[1])
80+
nodes = convert(T, [1:n;])
81+
else
82+
nodes = [1:n;]
83+
end
7984
s = nodes
8085
t = nodes
8186
end

GNNGraphs/test/utils.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@
9393
end
9494
end
9595

96-
@testset "color_refinment" begin
96+
@testset "color_refinement" begin
9797
rng = MersenneTwister(17)
9898
g = rand_graph(rng, 10, 20, graph_type = GRAPH_T)
9999
x0 = ones(Int, 10)
@@ -104,4 +104,4 @@ end
104104

105105
x2, _, _ = color_refinement(g)
106106
@test x2 == x
107-
end
107+
end

0 commit comments

Comments
 (0)