Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 21 additions & 6 deletions GNNGraphs/src/gnngraph.jl
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,12 @@ function GNNGraph(data::D;
@assert graph_type ∈ [:coo, :dense, :sparse] "Invalid graph_type $graph_type requested"
@assert dir ∈ [:in, :out]

if ndata !== nothing && num_nodes === nothing
# Infer num_nodes from ndata
# Should be more robust than inferring from data
num_nodes = numobs(ndata)
end

if graph_type == :coo
graph, num_nodes, num_edges = to_coo(data; num_nodes, dir)
elseif graph_type == :dense
Expand All @@ -143,15 +149,24 @@ function GNNGraph(data::D;

# don't force the shape of the data when there is only one graph
gdata = normalize_graphdata(gdata, default_name = :u,
n = num_graphs > 1 ? num_graphs : -1)
n = num_graphs > 1 ? num_graphs : -1, glob=true)

GNNGraph(graph,
num_nodes, num_edges, num_graphs,
graph_indicator,
ndata, edata, gdata)
end

GNNGraph(; kws...) = GNNGraph(0; kws...)
function GNNGraph(; num_nodes = nothing, ndata = nothing, kws...)
if num_nodes === nothing
if ndata === nothing
num_nodes = 0
else
num_nodes = numobs(ndata)
end
end
return GNNGraph(num_nodes; ndata, kws...)
end

function (::Type{<:GNNGraph})(num_nodes::T; kws...) where {T <: Integer}
s, t = T[], T[]
Expand Down Expand Up @@ -186,10 +201,10 @@ end

function GNNGraph(g::GNNGraph; ndata = g.ndata, edata = g.edata, gdata = g.gdata,
graph_type = nothing)
ndata = normalize_graphdata(ndata, default_name = :x, n = g.num_nodes)
edata = normalize_graphdata(edata, default_name = :e, n = g.num_edges,
duplicate_if_needed = true)
gdata = normalize_graphdata(gdata, default_name = :u, n = g.num_graphs)
ndata = normalize_graphdata(ndata, default_name=:x, n=g.num_nodes)
edata = normalize_graphdata(edata, default_name=:e, n=g.num_edges,
duplicate_if_needed=true)
gdata = normalize_graphdata(gdata, default_name=:u, n=g.num_graphs, glob=true)

if !isnothing(graph_type)
if graph_type == :coo
Expand Down
2 changes: 1 addition & 1 deletion GNNGraphs/src/gnnheterograph/gnnheterograph.jl
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ function GNNHeteroGraph(data::EDict;
edata = normalize_heterographdata(edata, default_name = :e, ns = num_edges,
duplicate_if_needed = true)
gdata = normalize_graphdata(gdata, default_name = :u,
n = num_graphs > 1 ? num_graphs : -1)
n = num_graphs > 1 ? num_graphs : -1, glob=true)
end

return GNNHeteroGraph(graph,
Expand Down
6 changes: 3 additions & 3 deletions GNNGraphs/src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -129,16 +129,16 @@ function normalize_graphdata(data; default_name::Symbol, kws...)
normalize_graphdata(NamedTuple{(default_name,)}((data,)); default_name, kws...)
end

function normalize_graphdata(data::NamedTuple; default_name, n, duplicate_if_needed = false)
function normalize_graphdata(data::NamedTuple; default_name, n, duplicate_if_needed = false, glob=false)
# This had to workaround two Zygote bugs with NamedTuples
# https://github.com/FluxML/Zygote.jl/issues/1071
# https://github.com/FluxML/Zygote.jl/issues/1072
# https://github.com/FluxML/Zygote.jl/issues/1072 # TODO this is fixed

if n > 1
@assert all(x -> x isa AbstractArray, data) "Non-array features provided."
end

if n <= 1
if n <= 1 && glob
# If last array dimension is not 1, add a new dimension.
# This is mostly useful to reshape global feature vectors
# of size D to Dx1 matrices.
Expand Down
28 changes: 28 additions & 0 deletions GNNGraphs/test/gnngraph.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,34 @@ end
end
end

@testitem "Constructor: empty" setup=[GraphsTestModule] begin
g = GNNGraph(ndata=ones(2, 1))
@test g.num_nodes == 1
@test g.num_edges == 0
@test g.ndata.x == ones(2, 1)

g = GNNGraph(num_nodes=1)
@test g.num_nodes == 1
@test g.num_edges == 0
@test isempty(g.ndata)

g = GNNGraph((Int[], Int[]); ndata=(;a=[1]))
@test g.num_nodes == 1
@test g.num_edges == 0
@test g.ndata.a == [1]

g = GNNGraph((Int[], Int[]); ndata=(;a=[1]), edata=(;b=Int[]), num_nodes=1)
@test g.num_nodes == 1
@test g.num_edges == 0
@test g.ndata.a == [1]
@test g.edata.b == Int[]

g = GNNGraph(; edata=(;b=Int[]))
@test g.num_nodes == 0
@test g.num_edges == 0
@test g.edata.b == Int[]
end

@testitem "symmetric graph" setup=[GraphsTestModule] tags=[:gpu] begin
using .GraphsTestModule
dev = gpu_device()
Expand Down
Loading