diff --git a/GNNGraphs/src/gnngraph.jl b/GNNGraphs/src/gnngraph.jl index d7391f3c8..1c826d748 100644 --- a/GNNGraphs/src/gnngraph.jl +++ b/GNNGraphs/src/gnngraph.jl @@ -1,4 +1,4 @@ -#=================================== +#============================================ Define GNNGraph type as a subtype of Graphs.AbstractGraph. For the core methods to be implemented by any AbstractGraph, see https://juliagraphs.org/Graphs.jl/latest/types/#AbstractGraph-Type @@ -103,7 +103,6 @@ source, target = edge_index(g) ``` A `GNNGraph` can be sent to the GPU, for example by using Flux.jl's `gpu` function or MLDataDevices.jl's utilities. -``` """ struct GNNGraph{T <: Union{COO_T, ADJMAT_T}} <: AbstractGNNGraph{T} graph::T @@ -127,6 +126,12 @@ function GNNGraph(data::D; @assert graph_type ∈ [:coo, :dense, :sparse] "Invalid graph_type $graph_type requested" @assert dir ∈ [:in, :out] + if num_nodes === nothing && ndata !== nothing + # Infer num_nodes from ndata + # Should be more robust than inferring from data + num_nodes = numobs(ndata) + end + if graph_type == :coo graph, num_nodes, num_edges = to_coo(data; num_nodes, dir) elseif graph_type == :dense @@ -151,7 +156,16 @@ function GNNGraph(data::D; ndata, edata, gdata) end -GNNGraph(; kws...) = GNNGraph(0; kws...) +function GNNGraph(; num_nodes=nothing, ndata=nothing, kws...) + if num_nodes === nothing + if ndata === nothing + num_nodes = 0 + else + num_nodes = numobs(ndata) + end + end + return GNNGraph(num_nodes; ndata, kws...) +end function (::Type{<:GNNGraph})(num_nodes::T; kws...) where {T <: Integer} s, t = T[], T[] @@ -162,7 +176,7 @@ Base.zero(::Type{G}) where {G <: GNNGraph} = G(0) # COO convenience constructors function GNNGraph(s::AbstractVector, t::AbstractVector, v = nothing; kws...) - GNNGraph((s, t, v); kws...) + return GNNGraph((s, t, v); kws...) end GNNGraph((s, t)::NTuple{2}; kws...) = GNNGraph((s, t, nothing); kws...) diff --git a/GNNGraphs/test/gnngraph.jl b/GNNGraphs/test/gnngraph.jl index 1e78349dc..252d2f5c0 100644 --- a/GNNGraphs/test/gnngraph.jl +++ b/GNNGraphs/test/gnngraph.jl @@ -213,6 +213,24 @@ end end end +@testitem "Constructor: empty" setup=[GraphsTestModule] begin + g = GNNGraph(ndata=ones(2, 1)) + @test g.num_nodes == 1 + @test g.num_edges == 0 + @test g.ndata.x == ones(2, 1) + + g = GNNGraph(num_nodes=1) + @test g.num_nodes == 1 + @test g.num_edges == 0 + @test isempty(g.ndata) + + g = GNNGraph((Int[], Int[]); ndata=(; a=[1])) + @test g.num_nodes == 1 + @test g.num_edges == 0 + @test g.ndata.a == [1] +end + + @testitem "Features" setup=[GraphsTestModule] begin using .GraphsTestModule for GRAPH_T in GRAPH_TYPES