1
- #= ==================================
1
+ #= ===========================================
2
2
Define GNNGraph type as a subtype of Graphs.AbstractGraph.
3
3
For the core methods to be implemented by any AbstractGraph, see
4
4
https://juliagraphs.org/Graphs.jl/latest/types/#AbstractGraph-Type
@@ -103,7 +103,6 @@ source, target = edge_index(g)
103
103
```
104
104
A `GNNGraph` can be sent to the GPU, for example by using Flux.jl's `gpu` function
105
105
or MLDataDevices.jl's utilities.
106
- ```
107
106
"""
108
107
struct GNNGraph{T <: Union{COO_T, ADJMAT_T} } <: AbstractGNNGraph{T}
109
108
graph:: T
@@ -127,6 +126,12 @@ function GNNGraph(data::D;
127
126
@assert graph_type ∈ [:coo , :dense , :sparse ] " Invalid graph_type $graph_type requested"
128
127
@assert dir ∈ [:in , :out ]
129
128
129
+ if num_nodes === nothing && ndata != = nothing
130
+ # Infer num_nodes from ndata
131
+ # Should be more robust than inferring from data
132
+ num_nodes = numobs (ndata)
133
+ end
134
+
130
135
if graph_type == :coo
131
136
graph, num_nodes, num_edges = to_coo (data; num_nodes, dir)
132
137
elseif graph_type == :dense
@@ -151,7 +156,16 @@ function GNNGraph(data::D;
151
156
ndata, edata, gdata)
152
157
end
153
158
154
- GNNGraph (; kws... ) = GNNGraph (0 ; kws... )
159
+ function GNNGraph (; num_nodes= nothing , ndata= nothing , kws... )
160
+ if num_nodes === nothing
161
+ if ndata === nothing
162
+ num_nodes = 0
163
+ else
164
+ num_nodes = numobs (ndata)
165
+ end
166
+ end
167
+ return GNNGraph (num_nodes; ndata, kws... )
168
+ end
155
169
156
170
function (:: Type{<:GNNGraph} )(num_nodes:: T ; kws... ) where {T <: Integer }
157
171
s, t = T[], T[]
@@ -162,7 +176,7 @@ Base.zero(::Type{G}) where {G <: GNNGraph} = G(0)
162
176
163
177
# COO convenience constructors
164
178
function GNNGraph (s:: AbstractVector , t:: AbstractVector , v = nothing ; kws... )
165
- GNNGraph ((s, t, v); kws... )
179
+ return GNNGraph ((s, t, v); kws... )
166
180
end
167
181
GNNGraph ((s, t):: NTuple{2} ; kws... ) = GNNGraph ((s, t, nothing ); kws... )
168
182
0 commit comments