1- #= ==================================
1+ #= ===========================================
22Define GNNGraph type as a subtype of Graphs.AbstractGraph.
33For the core methods to be implemented by any AbstractGraph, see
44https://juliagraphs.org/Graphs.jl/latest/types/#AbstractGraph-Type
@@ -103,7 +103,6 @@ source, target = edge_index(g)
103103```
104104A `GNNGraph` can be sent to the GPU, for example by using Flux.jl's `gpu` function
105105or MLDataDevices.jl's utilities.
106- ```
107106"""
108107struct GNNGraph{T <: Union{COO_T, ADJMAT_T} } <: AbstractGNNGraph{T}
109108 graph:: T
@@ -127,6 +126,12 @@ function GNNGraph(data::D;
127126 @assert graph_type ∈ [:coo , :dense , :sparse ] " Invalid graph_type $graph_type requested"
128127 @assert dir ∈ [:in , :out ]
129128
129+ if num_nodes === nothing && ndata != = nothing
130+ # Infer num_nodes from ndata
131+ # Should be more robust than inferring from data
132+ num_nodes = numobs (ndata)
133+ end
134+
130135 if graph_type == :coo
131136 graph, num_nodes, num_edges = to_coo (data; num_nodes, dir)
132137 elseif graph_type == :dense
@@ -151,7 +156,16 @@ function GNNGraph(data::D;
151156 ndata, edata, gdata)
152157end
153158
154- GNNGraph (; kws... ) = GNNGraph (0 ; kws... )
159+ function GNNGraph (; num_nodes= nothing , ndata= nothing , kws... )
160+ if num_nodes === nothing
161+ if ndata === nothing
162+ num_nodes = 0
163+ else
164+ num_nodes = numobs (ndata)
165+ end
166+ end
167+ return GNNGraph (num_nodes; ndata, kws... )
168+ end
155169
156170function (:: Type{<:GNNGraph} )(num_nodes:: T ; kws... ) where {T <: Integer }
157171 s, t = T[], T[]
@@ -162,7 +176,7 @@ Base.zero(::Type{G}) where {G <: GNNGraph} = G(0)
162176
163177# COO convenience constructors
164178function GNNGraph (s:: AbstractVector , t:: AbstractVector , v = nothing ; kws... )
165- GNNGraph ((s, t, v); kws... )
179+ return GNNGraph ((s, t, v); kws... )
166180end
167181GNNGraph ((s, t):: NTuple{2} ; kws... ) = GNNGraph ((s, t, nothing ); kws... )
168182
0 commit comments