Skip to content

Commit cb2ad5f

Browse files
better inference for num_nodes (#588)
1 parent dc2c44a commit cb2ad5f

File tree

2 files changed

+36
-4
lines changed

2 files changed

+36
-4
lines changed

GNNGraphs/src/gnngraph.jl

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
#===================================
1+
#============================================
22
Define GNNGraph type as a subtype of Graphs.AbstractGraph.
33
For the core methods to be implemented by any AbstractGraph, see
44
https://juliagraphs.org/Graphs.jl/latest/types/#AbstractGraph-Type
@@ -103,7 +103,6 @@ source, target = edge_index(g)
103103
```
104104
A `GNNGraph` can be sent to the GPU, for example by using Flux.jl's `gpu` function
105105
or MLDataDevices.jl's utilities.
106-
```
107106
"""
108107
struct GNNGraph{T <: Union{COO_T, ADJMAT_T}} <: AbstractGNNGraph{T}
109108
graph::T
@@ -127,6 +126,12 @@ function GNNGraph(data::D;
127126
@assert graph_type [:coo, :dense, :sparse] "Invalid graph_type $graph_type requested"
128127
@assert dir [:in, :out]
129128

129+
if num_nodes === nothing && ndata !== nothing
130+
# Infer num_nodes from ndata
131+
# Should be more robust than inferring from data
132+
num_nodes = numobs(ndata)
133+
end
134+
130135
if graph_type == :coo
131136
graph, num_nodes, num_edges = to_coo(data; num_nodes, dir)
132137
elseif graph_type == :dense
@@ -151,7 +156,16 @@ function GNNGraph(data::D;
151156
ndata, edata, gdata)
152157
end
153158

154-
GNNGraph(; kws...) = GNNGraph(0; kws...)
159+
function GNNGraph(; num_nodes=nothing, ndata=nothing, kws...)
160+
if num_nodes === nothing
161+
if ndata === nothing
162+
num_nodes = 0
163+
else
164+
num_nodes = numobs(ndata)
165+
end
166+
end
167+
return GNNGraph(num_nodes; ndata, kws...)
168+
end
155169

156170
function (::Type{<:GNNGraph})(num_nodes::T; kws...) where {T <: Integer}
157171
s, t = T[], T[]
@@ -162,7 +176,7 @@ Base.zero(::Type{G}) where {G <: GNNGraph} = G(0)
162176

163177
# COO convenience constructors
164178
function GNNGraph(s::AbstractVector, t::AbstractVector, v = nothing; kws...)
165-
GNNGraph((s, t, v); kws...)
179+
return GNNGraph((s, t, v); kws...)
166180
end
167181
GNNGraph((s, t)::NTuple{2}; kws...) = GNNGraph((s, t, nothing); kws...)
168182

GNNGraphs/test/gnngraph.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,24 @@ end
213213
end
214214
end
215215

216+
@testitem "Constructor: empty" setup=[GraphsTestModule] begin
217+
g = GNNGraph(ndata=ones(2, 1))
218+
@test g.num_nodes == 1
219+
@test g.num_edges == 0
220+
@test g.ndata.x == ones(2, 1)
221+
222+
g = GNNGraph(num_nodes=1)
223+
@test g.num_nodes == 1
224+
@test g.num_edges == 0
225+
@test isempty(g.ndata)
226+
227+
g = GNNGraph((Int[], Int[]); ndata=(; a=[1]))
228+
@test g.num_nodes == 1
229+
@test g.num_edges == 0
230+
@test g.ndata.a == [1]
231+
end
232+
233+
216234
@testitem "Features" setup=[GraphsTestModule] begin
217235
using .GraphsTestModule
218236
for GRAPH_T in GRAPH_TYPES

0 commit comments

Comments
 (0)