diff --git a/GNNGraphs/src/gnngraph.jl b/GNNGraphs/src/gnngraph.jl index d7391f3c8..2a85d9dc8 100644 --- a/GNNGraphs/src/gnngraph.jl +++ b/GNNGraphs/src/gnngraph.jl @@ -127,6 +127,12 @@ function GNNGraph(data::D; @assert graph_type ∈ [:coo, :dense, :sparse] "Invalid graph_type $graph_type requested" @assert dir ∈ [:in, :out] + if ndata !== nothing && num_nodes === nothing + # Infer num_nodes from ndata + # Should be more robust than inferring from data + num_nodes = numobs(ndata) + end + if graph_type == :coo graph, num_nodes, num_edges = to_coo(data; num_nodes, dir) elseif graph_type == :dense @@ -143,7 +149,7 @@ function GNNGraph(data::D; # don't force the shape of the data when there is only one graph gdata = normalize_graphdata(gdata, default_name = :u, - n = num_graphs > 1 ? num_graphs : -1) + n = num_graphs > 1 ? num_graphs : -1, glob=true) GNNGraph(graph, num_nodes, num_edges, num_graphs, @@ -151,7 +157,16 @@ function GNNGraph(data::D; ndata, edata, gdata) end -GNNGraph(; kws...) = GNNGraph(0; kws...) +function GNNGraph(; num_nodes = nothing, ndata = nothing, kws...) + if num_nodes === nothing + if ndata === nothing + num_nodes = 0 + else + num_nodes = numobs(ndata) + end + end + return GNNGraph(num_nodes; ndata, kws...) +end function (::Type{<:GNNGraph})(num_nodes::T; kws...) where {T <: Integer} s, t = T[], T[] @@ -186,10 +201,10 @@ end function GNNGraph(g::GNNGraph; ndata = g.ndata, edata = g.edata, gdata = g.gdata, graph_type = nothing) - ndata = normalize_graphdata(ndata, default_name = :x, n = g.num_nodes) - edata = normalize_graphdata(edata, default_name = :e, n = g.num_edges, - duplicate_if_needed = true) - gdata = normalize_graphdata(gdata, default_name = :u, n = g.num_graphs) + ndata = normalize_graphdata(ndata, default_name=:x, n=g.num_nodes) + edata = normalize_graphdata(edata, default_name=:e, n=g.num_edges, + duplicate_if_needed=true) + gdata = normalize_graphdata(gdata, default_name=:u, n=g.num_graphs, glob=true) if !isnothing(graph_type) if graph_type == :coo diff --git a/GNNGraphs/src/gnnheterograph/gnnheterograph.jl b/GNNGraphs/src/gnnheterograph/gnnheterograph.jl index 4a3f8b924..f939c92f4 100644 --- a/GNNGraphs/src/gnnheterograph/gnnheterograph.jl +++ b/GNNGraphs/src/gnnheterograph/gnnheterograph.jl @@ -145,7 +145,7 @@ function GNNHeteroGraph(data::EDict; edata = normalize_heterographdata(edata, default_name = :e, ns = num_edges, duplicate_if_needed = true) gdata = normalize_graphdata(gdata, default_name = :u, - n = num_graphs > 1 ? num_graphs : -1) + n = num_graphs > 1 ? num_graphs : -1, glob=true) end return GNNHeteroGraph(graph, diff --git a/GNNGraphs/src/utils.jl b/GNNGraphs/src/utils.jl index a6e96a3ab..64b6d2d73 100644 --- a/GNNGraphs/src/utils.jl +++ b/GNNGraphs/src/utils.jl @@ -129,16 +129,16 @@ function normalize_graphdata(data; default_name::Symbol, kws...) normalize_graphdata(NamedTuple{(default_name,)}((data,)); default_name, kws...) end -function normalize_graphdata(data::NamedTuple; default_name, n, duplicate_if_needed = false) +function normalize_graphdata(data::NamedTuple; default_name, n, duplicate_if_needed = false, glob=false) # This had to workaround two Zygote bugs with NamedTuples # https://github.com/FluxML/Zygote.jl/issues/1071 - # https://github.com/FluxML/Zygote.jl/issues/1072 + # https://github.com/FluxML/Zygote.jl/issues/1072 # TODO this is fixed if n > 1 @assert all(x -> x isa AbstractArray, data) "Non-array features provided." end - if n <= 1 + if n <= 1 && glob # If last array dimension is not 1, add a new dimension. # This is mostly useful to reshape global feature vectors # of size D to Dx1 matrices. diff --git a/GNNGraphs/test/gnngraph.jl b/GNNGraphs/test/gnngraph.jl index 1e78349dc..e3b49b1ce 100644 --- a/GNNGraphs/test/gnngraph.jl +++ b/GNNGraphs/test/gnngraph.jl @@ -53,6 +53,34 @@ end end end +@testitem "Constructor: empty" setup=[GraphsTestModule] begin + g = GNNGraph(ndata=ones(2, 1)) + @test g.num_nodes == 1 + @test g.num_edges == 0 + @test g.ndata.x == ones(2, 1) + + g = GNNGraph(num_nodes=1) + @test g.num_nodes == 1 + @test g.num_edges == 0 + @test isempty(g.ndata) + + g = GNNGraph((Int[], Int[]); ndata=(;a=[1])) + @test g.num_nodes == 1 + @test g.num_edges == 0 + @test g.ndata.a == [1] + + g = GNNGraph((Int[], Int[]); ndata=(;a=[1]), edata=(;b=Int[]), num_nodes=1) + @test g.num_nodes == 1 + @test g.num_edges == 0 + @test g.ndata.a == [1] + @test g.edata.b == Int[] + + g = GNNGraph(; edata=(;b=Int[])) + @test g.num_nodes == 0 + @test g.num_edges == 0 + @test g.edata.b == Int[] +end + @testitem "symmetric graph" setup=[GraphsTestModule] tags=[:gpu] begin using .GraphsTestModule dev = gpu_device()