Skip to content

Commit df5b508

Browse files
authored
Typecast GNNGraph.num_nodes to Int (#109)
1 parent ce0396b commit df5b508

File tree

3 files changed

+12
-6
lines changed

3 files changed

+12
-6
lines changed

src/GNNGraphs/convert.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
function to_coo(coo::COO_T; dir=:out, num_nodes=nothing, weighted=true)
44
s, t, val = coo
5-
num_nodes = isnothing(num_nodes) ? max(maximum(s), maximum(t)) : num_nodes
5+
num_nodes::Int = isnothing(num_nodes) ? max(maximum(s), maximum(t)) : num_nodes
66
@assert isnothing(val) || length(val) == length(s)
77
@assert length(s) == length(t)
88
if !isempty(s)
@@ -114,7 +114,7 @@ function to_dense(coo::COO_T, T=nothing; dir=:out, num_nodes=nothing, weighted=t
114114
# `dir` will be ignored since the input `coo` is always in source -> target format.
115115
# The output will always be a adjmat in :out format (e.g. A[i,j] denotes from i to j)
116116
s, t, val = coo
117-
n = isnothing(num_nodes) ? max(maximum(s), maximum(t)) : num_nodes
117+
n::Int = isnothing(num_nodes) ? max(maximum(s), maximum(t)) : num_nodes
118118
val = isnothing(val) ? eltype(s)(1) : val
119119
T = T === nothing ? eltype(val) : T
120120
if !weighted
@@ -164,9 +164,9 @@ function to_sparse(coo::COO_T, T=nothing; dir=:out, num_nodes=nothing, weighted=
164164
eweight = fill!(similar(s, T), 1)
165165
end
166166

167-
num_nodes = isnothing(num_nodes) ? max(maximum(s), maximum(t)) : num_nodes
167+
num_nodes::Int = isnothing(num_nodes) ? max(maximum(s), maximum(t)) : num_nodes
168168
A = sparse(s, t, eweight, num_nodes, num_nodes)
169-
num_edges = nnz(A)
169+
num_edges::Int = nnz(A)
170170
if eltype(A) != T
171171
A = T.(A)
172172
end

src/GNNGraphs/gnngraph.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,8 @@ function GNNGraph(g::AbstractGraph; kws...)
168168
# add reverse edges since GNNGraph is directed
169169
s, t = [s; t], [t; s]
170170
end
171-
GNNGraph((s, t); num_nodes=Graphs.nv(g), kws...)
171+
num_nodes::Int = Graphs.nv(g)
172+
GNNGraph((s, t); num_nodes=num_nodes, kws...)
172173
end
173174

174175

test/GNNGraphs/gnngraph.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,12 @@
177177
for e in Graphs.edges(lg)
178178
i, j = src(e), dst(e)
179179
@test has_edge(g, i, j)
180-
@test has_edge(g, j, i)
180+
@test has_edge(g, j, i)
181+
end
182+
183+
@testset "SimpleGraph{Int32}" begin
184+
g = GNNGraph(SimpleGraph{Int32}(6), graph_type=GRAPH_T)
185+
@test g.num_nodes == 6
181186
end
182187
end
183188

0 commit comments

Comments
 (0)