Skip to content

Commit 2883214

Browse files
better inference of num_nodes
1 parent 58b2316 commit 2883214

File tree

1 file changed

+16
-1
lines changed

1 file changed

+16
-1
lines changed

GNNGraphs/src/gnngraph.jl

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,12 @@ function GNNGraph(data::D;
127127
@assert graph_type [:coo, :dense, :sparse] "Invalid graph_type $graph_type requested"
128128
@assert dir [:in, :out]
129129

130+
if ndata !== nothing && num_nodes === nothing
131+
# Infer num_nodes from ndata
132+
# Should be more robust than inferring from data
133+
num_nodes = numobs(ndata)
134+
end
135+
130136
if graph_type == :coo
131137
graph, num_nodes, num_edges = to_coo(data; num_nodes, dir)
132138
elseif graph_type == :dense
@@ -151,7 +157,16 @@ function GNNGraph(data::D;
151157
ndata, edata, gdata)
152158
end
153159

154-
GNNGraph(; kws...) = GNNGraph(0; kws...)
160+
function GNNGraph(; num_nodes = nothing, ndata = nothing, kws...)
161+
if num_nodes === nothing
162+
if ndata === nothing
163+
num_nodes = 0
164+
else
165+
num_nodes = numobs(ndata)
166+
end
167+
end
168+
return GNNGraph(num_nodes; ndata, kws...)
169+
end
155170

156171
function (::Type{<:GNNGraph})(num_nodes::T; kws...) where {T <: Integer}
157172
s, t = T[], T[]

0 commit comments

Comments
 (0)