Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 18 additions & 4 deletions GNNGraphs/src/gnngraph.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#===================================
#============================================
Define GNNGraph type as a subtype of Graphs.AbstractGraph.
For the core methods to be implemented by any AbstractGraph, see
https://juliagraphs.org/Graphs.jl/latest/types/#AbstractGraph-Type
Expand Down Expand Up @@ -103,7 +103,6 @@ source, target = edge_index(g)
```
A `GNNGraph` can be sent to the GPU, for example by using Flux.jl's `gpu` function
or MLDataDevices.jl's utilities.
```
"""
struct GNNGraph{T <: Union{COO_T, ADJMAT_T}} <: AbstractGNNGraph{T}
graph::T
Expand All @@ -127,6 +126,12 @@ function GNNGraph(data::D;
@assert graph_type ∈ [:coo, :dense, :sparse] "Invalid graph_type $graph_type requested"
@assert dir ∈ [:in, :out]

if num_nodes === nothing && ndata !== nothing
# Infer num_nodes from ndata
# Should be more robust than inferring from data
num_nodes = numobs(ndata)
end

if graph_type == :coo
graph, num_nodes, num_edges = to_coo(data; num_nodes, dir)
elseif graph_type == :dense
Expand All @@ -151,7 +156,16 @@ function GNNGraph(data::D;
ndata, edata, gdata)
end

GNNGraph(; kws...) = GNNGraph(0; kws...)
function GNNGraph(; num_nodes=nothing, ndata=nothing, kws...)
if num_nodes === nothing
if ndata === nothing
num_nodes = 0
else
num_nodes = numobs(ndata)
end
end
return GNNGraph(num_nodes; ndata, kws...)
end

function (::Type{<:GNNGraph})(num_nodes::T; kws...) where {T <: Integer}
s, t = T[], T[]
Expand All @@ -162,7 +176,7 @@ Base.zero(::Type{G}) where {G <: GNNGraph} = G(0)

# COO convenience constructors
function GNNGraph(s::AbstractVector, t::AbstractVector, v = nothing; kws...)
GNNGraph((s, t, v); kws...)
return GNNGraph((s, t, v); kws...)
end
GNNGraph((s, t)::NTuple{2}; kws...) = GNNGraph((s, t, nothing); kws...)

Expand Down
18 changes: 18 additions & 0 deletions GNNGraphs/test/gnngraph.jl
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,24 @@ end
end
end

@testitem "Constructor: empty" setup=[GraphsTestModule] begin
g = GNNGraph(ndata=ones(2, 1))
@test g.num_nodes == 1
@test g.num_edges == 0
@test g.ndata.x == ones(2, 1)

g = GNNGraph(num_nodes=1)
@test g.num_nodes == 1
@test g.num_edges == 0
@test isempty(g.ndata)

g = GNNGraph((Int[], Int[]); ndata=(; a=[1]))
@test g.num_nodes == 1
@test g.num_edges == 0
@test g.ndata.a == [1]
end


@testitem "Features" setup=[GraphsTestModule] begin
using .GraphsTestModule
for GRAPH_T in GRAPH_TYPES
Expand Down