@@ -436,26 +436,44 @@ function remove_self_loops(g::GNNGraph{<:COO_T})
436
436
g. ndata, g. edata, g. gdata)
437
437
end
438
438
439
- function _catgraphs (g1:: GNNGraph{<:COO_T} , g2:: GNNGraph{<:COO_T} )
440
- s1, t1 = edge_index (g1)
441
- s2, t2 = edge_index (g2)
439
+ function SparseArrays. blockdiag (g1:: GNNGraph , g2:: GNNGraph )
442
440
nv1, nv2 = g1. num_nodes, g2. num_nodes
443
- s = vcat (s1, nv1 .+ s2)
444
- t = vcat (t1, nv1 .+ t2)
445
- w = cat_features (edge_weight (g1), edge_weight (g2))
446
-
447
- ind1 = isnothing (g1. graph_indicator) ? fill! (similar (s1, Int, nv1), 1 ) : g1. graph_indicator
448
- ind2 = isnothing (g2. graph_indicator) ? fill! (similar (s2, Int, nv2), 1 ) : g2. graph_indicator
441
+ if g1. graph isa COO_T
442
+ s1, t1 = edge_index (g1)
443
+ s2, t2 = edge_index (g2)
444
+ s = vcat (s1, nv1 .+ s2)
445
+ t = vcat (t1, nv1 .+ t2)
446
+ w = cat_features (edge_weight (g1), edge_weight (g2))
447
+ graph = (s, t, w)
448
+ ind1 = isnothing (g1. graph_indicator) ? ones_like (s1, Int, nv1) : g1. graph_indicator
449
+ ind2 = isnothing (g2. graph_indicator) ? ones_like (s2, Int, nv2) : g2. graph_indicator
450
+ elseif g1. graph isa ADJMAT_T
451
+ graph = blockdiag (g1. graph, g2. graph)
452
+ ind1 = isnothing (g1. graph_indicator) ? ones_like (graph, Int, nv1) : g1. graph_indicator
453
+ ind2 = isnothing (g2. graph_indicator) ? ones_like (graph, Int, nv2) : g2. graph_indicator
454
+ end
449
455
graph_indicator = vcat (ind1, g1. num_graphs .+ ind2)
450
456
451
- GNNGraph ((s, t, w) ,
457
+ GNNGraph (graph ,
452
458
nv1 + nv2, g1. num_edges + g2. num_edges, g1. num_graphs + g2. num_graphs,
453
459
graph_indicator,
454
460
cat_features (g1. ndata, g2. ndata),
455
461
cat_features (g1. edata, g2. edata),
456
462
cat_features (g1. gdata, g2. gdata))
457
463
end
458
464
465
+ # PIRACY
466
+ function SparseArrays. blockdiag (A1:: AbstractMatrix , A2:: AbstractMatrix )
467
+ m1, n1 = size (A1)
468
+ @assert m1 == n1
469
+ m2, n2 = size (A2)
470
+ @assert m2 == n2
471
+ O1 = fill! (similar (A1, eltype (A1), (m1, n2)), 0 )
472
+ O2 = fill! (similar (A1, eltype (A1), (m2, n1)), 0 )
473
+ return [A1 O1
474
+ O2 A2]
475
+ end
476
+
459
477
# ## Cat public interfaces #############
460
478
461
479
"""
@@ -466,7 +484,7 @@ Equivalent to [`Flux.batch`](@ref).
466
484
function SparseArrays. blockdiag (g1:: GNNGraph , gothers:: GNNGraph... )
467
485
g = g1
468
486
for go in gothers
469
- g = _catgraphs (g, go)
487
+ g = blockdiag (g, go)
470
488
end
471
489
return g
472
490
end
@@ -475,39 +493,44 @@ end
475
493
batch(xs::Vector{<:GNNGraph})
476
494
477
495
Batch together multiple `GNNGraph`s into a single one
478
- containing the total number of nodes and edges of the original graphs .
496
+ containing the total number of original nodes and edges.
479
497
480
498
Equivalent to [`SparseArrays.blockdiag`](@ref).
481
499
"""
482
500
Flux. batch (xs:: Vector{<:GNNGraph} ) = blockdiag (xs... )
483
501
484
502
# ## LearnBase compatibility
485
503
LearnBase. nobs (g:: GNNGraph ) = g. num_graphs
486
- LearnBase. getobs (g:: GNNGraph , i) = getgraph (g, i)[ 1 ]
504
+ LearnBase. getobs (g:: GNNGraph , i) = getgraph (g, i)
487
505
488
506
# Flux's Dataloader compatibility. Related PR https://github.com/FluxML/Flux.jl/pull/1683
489
507
Flux. Data. _nobs (g:: GNNGraph ) = g. num_graphs
490
- Flux. Data. _getobs (g:: GNNGraph , i) = getgraph (g, i)[ 1 ]
508
+ Flux. Data. _getobs (g:: GNNGraph , i) = getgraph (g, i)
491
509
492
510
# ########################
493
511
Base.:(== )(g1:: GNNGraph , g2:: GNNGraph ) = all (k -> getfield (g1,k)== getfield (g2,k), fieldnames (typeof (g1)))
494
512
495
513
"""
496
- getgraph(g::GNNGraph, i)
514
+ getgraph(g::GNNGraph, i; nmap=false )
497
515
498
- Return the getgraph of `g` induced by those nodes `v`
499
- for which `g.graph_indicator[v] ∈ i`. In other words, it
500
- extract the component graphs from a batched graph.
516
+ Return the subgraph of `g` induced by those nodes `j`
517
+ for which `g.graph_indicator[j] == i` or,
518
+ if `i` is a collection, `g.graph_indicator[j] ∈ i`.
519
+ In other words, it extract the component graphs from a batched graph.
501
520
502
- It also returns a vector `nodes ` mapping the new nodes to the old ones.
503
- The node `i` in the getgraph corresponds to the node `nodes [i]` in `g`.
521
+ If `nmap=true`, return also a vector `v ` mapping the new nodes to the old ones.
522
+ The node `i` in the subgraph will correspond to the node `v [i]` in `g`.
504
523
"""
505
- getgraph (g:: GNNGraph , i:: Int ) = getgraph (g:: GNNGraph{<:COO_T} , [i])
524
+ getgraph (g:: GNNGraph , i:: Int ; kws ... ) = getgraph (g, [i]; kws ... )
506
525
507
- function getgraph (g:: GNNGraph{<:COO_T} , i:: AbstractVector{Int} )
526
+ function getgraph (g:: GNNGraph , i:: AbstractVector{Int} ; nmap = false )
508
527
if g. graph_indicator === nothing
509
528
@assert i == [1 ]
510
- return g
529
+ if nmap
530
+ return g, 1 : g. num_nodes
531
+ else
532
+ return g
533
+ end
511
534
end
512
535
513
536
node_mask = g. graph_indicator .∈ Ref (i)
@@ -518,25 +541,38 @@ function getgraph(g::GNNGraph{<:COO_T}, i::AbstractVector{Int})
518
541
graphmap = Dict (i => inew for (inew, i) in enumerate (i))
519
542
graph_indicator = [graphmap[i] for i in g. graph_indicator[node_mask]]
520
543
521
- s, t, w = g. graph
522
- edge_mask = s .∈ Ref (nodes)
523
- s = [nodemap[i] for i in s[edge_mask]]
524
- t = [nodemap[i] for i in t[edge_mask]]
525
- w = isnothing (w) ? nothing : w[edge_mask]
526
-
544
+ if g. graph isa COO_T
545
+ s, t = edge_index (g)
546
+ w = edge_weight (g)
547
+ edge_mask = s .∈ Ref (nodes)
548
+ s = [nodemap[i] for i in s[edge_mask]]
549
+ t = [nodemap[i] for i in t[edge_mask]]
550
+ w = isnothing (w) ? nothing : w[edge_mask]
551
+ graph = (s, t, w)
552
+ num_edges = length (s)
553
+ edata = getobs (g. edata, edge_mask)
554
+ elseif g. graph isa ADJMAT_T
555
+ graph = g. graph[nodes, nodes]
556
+ num_edges = count (>= (0 ), graph)
557
+ @assert g. edata == (;) # TODO
558
+ edata = (;)
559
+ end
527
560
ndata = getobs (g. ndata, node_mask)
528
- edata = getobs (g. edata, edge_mask)
529
561
gdata = getobs (g. gdata, i)
530
562
531
563
num_nodes = length (graph_indicator)
532
- num_edges = length (s)
533
564
num_graphs = length (i)
534
565
535
- gnew = GNNGraph ((s,t,w) ,
566
+ gnew = GNNGraph (graph ,
536
567
num_nodes, num_edges, num_graphs,
537
568
graph_indicator,
538
569
ndata, edata, gdata)
539
- return gnew, nodes
570
+
571
+ if nmap
572
+ return gnew, nodes
573
+ else
574
+ return gnew
575
+ end
540
576
end
541
577
542
578
function node_features (g:: GNNGraph )
0 commit comments