@@ -11,7 +11,7 @@ const ADJMAT_T = AbstractMatrix
11
11
const SPARSE_T = AbstractSparseMatrix # subset of ADJMAT_T
12
12
13
13
"""
14
- GNNGraph(data; [graph_type, nf, ef, gf, num_nodes, num_graphs, graph_indicator, dir])
14
+ GNNGraph(data; [graph_type, nf, ef, gf, num_nodes, graph_indicator, dir])
15
15
GNNGraph(g::GNNGraph; [nf, ef, gf])
16
16
17
17
A type representing a graph structure and storing also arrays
@@ -50,7 +50,6 @@ from the LightGraphs' graph library can be used on it.
50
50
- `dir`. The assumed edge direction when given adjacency matrix or adjacency list input data `g`.
51
51
Possible values are `:out` and `:in`. Default `:out`.
52
52
- `num_nodes`. The number of nodes. If not specified, inferred from `g`. Default `nothing`.
53
- - `num_graphs`. The number of graphs. Larger than 1 in case of batched graphs. Default `1`.
54
53
- `graph_indicator`. For batched graphs, a vector containeing the graph assigment of each node. Default `nothing`.
55
54
- `nf`: Node features. Either nothing, or an array whose last dimension has size num_nodes. Default `nothing`.
56
55
- `ef`: Edge features. Either nothing, or an array whose last dimension has size num_edges. Default `nothing`.
@@ -123,17 +122,17 @@ function GNNGraph(data;
123
122
124
123
@assert graph_type ∈ [:coo , :dense , :sparse ] " Invalid graph_type $graph_type requested"
125
124
@assert dir ∈ [:in , :out ]
125
+
126
126
if graph_type == :coo
127
127
g, num_nodes, num_edges = to_coo (data; num_nodes, dir)
128
128
elseif graph_type == :dense
129
129
g, num_nodes, num_edges = to_dense (data; dir)
130
130
elseif graph_type == :sparse
131
131
g, num_nodes, num_edges = to_sparse (data; dir)
132
132
end
133
- if num_graphs > 1
134
- @assert len (graph_indicator) = num_nodes " When batching multiple graphs `graph_indicator` should be filled with the nodes' memberships."
135
- end
136
-
133
+
134
+ num_graphs = ! isnothing (graph_indicator) ? maximum (graph_indicator) : 1
135
+
137
136
# # Possible future implementation of feature maps.
138
137
# # Currently this doesn't play well with zygote due to
139
138
# # https://github.com/FluxML/Zygote.jl/issues/717
@@ -154,8 +153,8 @@ GNNGraph((s, t)::NTuple{2}; kws...) = GNNGraph((s, t, nothing); kws...)
154
153
155
154
function GNNGraph (g:: AbstractGraph ; kws... )
156
155
s = LightGraphs. src .(LightGraphs. edges (g))
157
- t = LightGraphs. dst .(LightGraphs. edges (g))
158
- GNNGraph ((s, t); kws... )
156
+ t = LightGraphs. dst .(LightGraphs. edges (g))
157
+ GNNGraph ((s, t); num_nodes = nv (g), kws... )
159
158
end
160
159
161
160
function GNNGraph (g:: GNNGraph ;
@@ -436,36 +435,77 @@ function _catgraphs(g1::GNNGraph{<:COO_T}, g2::GNNGraph{<:COO_T})
436
435
)
437
436
end
438
437
439
- # Cat public interfaces
438
+ # ## Cat public interfaces #############
440
439
441
- ```
440
+ """
442
441
blockdiag(xs::GNNGraph...)
443
442
444
443
Batch togheter multiple `GNNGraph`s into a single one
445
444
containing the total number of nodes and edges of the original graphs.
446
445
447
446
Equivalent to [`Flux.batch`](@ref).
448
- ```
447
+ """
449
448
function SparseArrays. blockdiag (g1:: GNNGraph , gothers:: GNNGraph... )
450
- @assert length (gothers) >= 1
451
449
g = g1
452
450
for go in gothers
453
451
g = _catgraphs (g, go)
454
452
end
455
453
return g
456
454
end
457
455
458
- ```
456
+ """
459
457
batch(xs::Vector{<:GNNGraph})
460
458
461
459
Batch togheter multiple `GNNGraph`s into a single one
462
460
containing the total number of nodes and edges of the original graphs.
463
461
464
462
Equivalent to [`SparseArrays.blockdiag`](@ref).
465
- ```
463
+ """
466
464
Flux. batch (xs:: Vector{<:GNNGraph} ) = blockdiag (xs... )
467
465
# ########################
468
466
467
+ """
468
+ subgraph(g::GNNGraph, i)
469
+
470
+ Return the subgraph of `g` induced by those nodes `v`
471
+ for which `g.graph_indicator[v] ∈ i`. In other words, it
472
+ extract the component graphs from a batched graph.
473
+
474
+ It also returns a vector `nodes` mapping the new nodes to the old ones.
475
+ The node `i` in the subgraph corresponds to the node `nodes[i]` in `g`.
476
+ """
477
+ subgraph (g:: GNNGraph , i:: Int ) = subgraph (g:: GNNGraph{<:COO_T} , [i])
478
+
479
+ function subgraph (g:: GNNGraph{<:COO_T} , i:: AbstractVector )
480
+ node_mask = g. graph_indicator .∈ Ref (i)
481
+
482
+ nodes = (1 : g. num_nodes)[node_mask]
483
+ nodemap = Dict (v => vnew for (vnew, v) in enumerate (nodes))
484
+
485
+ graphmap = Dict (i => inew for (inew, i) in enumerate (i))
486
+ graph_indicator = [graphmap[i] for i in g. graph_indicator[node_mask]]
487
+
488
+ s, t, w = g. graph
489
+ edge_mask = s .∈ Ref (nodes)
490
+ s = [nodemap[i] for i in s[edge_mask]]
491
+ t = [nodemap[i] for i in t[edge_mask]]
492
+ w = isnothing (w) ? nothing : w[edge_mask]
493
+ @show size (g. nf) size (node_mask)
494
+ nf = isnothing (g. nf) ? nothing : g. nf[:,node_mask]
495
+ ef = isnothing (g. ef) ? nothing : g. ef[:,edge_mask]
496
+ gf = isnothing (g. gf) ? nothing : g. gf[:,i]
497
+
498
+ num_nodes = length (graph_indicator)
499
+ num_edges = length (s)
500
+ num_graphs = length (i)
501
+
502
+ gnew = GNNGraph ((s,t,w),
503
+ num_nodes, num_edges, num_graphs,
504
+ graph_indicator,
505
+ nf, ef, gf)
506
+ return gnew, nodes
507
+ end
508
+
469
509
@non_differentiable normalized_laplacian (x... )
470
510
@non_differentiable normalized_adjacency (x... )
471
511
@non_differentiable scaled_laplacian (x... )
0 commit comments