@@ -466,10 +466,78 @@ julia> Flux.unbatch(gbatched)
466
466
num_edges = 2
467
467
```
468
468
"""
469
- function Flux. unbatch (g:: GNNGraph )
470
- [getgraph (g, i) for i in 1 : g. num_graphs]
469
+ function Flux. unbatch (g:: GNNGraph{T} ) where T<: COO_T
470
+ g. num_graphs == 1 && return [g]
471
+
472
+ nodemasks = _unbatch_nodemasks (g. graph_indicator, g. num_graphs)
473
+ num_nodes = length .(nodemasks)
474
+ cumnum_nodes = [0 ; cumsum (num_nodes)]
475
+
476
+ s, t = edge_index (g)
477
+ w = get_edge_weight (g)
478
+
479
+ edgemasks = _unbatch_edgemasks (s, t, g. num_graphs, cumnum_nodes)
480
+ num_edges = length .(edgemasks)
481
+ @assert sum (num_edges) == g. num_edges " Error in unbatching, likely the edges are not sorted (first edges belong to the first graphs, then edges in the second graph and so on)"
482
+
483
+ function build_graph (i)
484
+ node_mask = nodemasks[i]
485
+ edge_mask = edgemasks[i]
486
+ snew = s[edge_mask] .- cumnum_nodes[i]
487
+ tnew = t[edge_mask] .- cumnum_nodes[i]
488
+ wnew = w === nothing ? nothing : w[edge_mask]
489
+ graph = (snew, tnew, wnew)
490
+ graph_indicator = nothing
491
+ ndata = getobs (g. ndata, node_mask)
492
+ edata = getobs (g. edata, edge_mask)
493
+ gdata = getobs (g. gdata, i)
494
+
495
+ nedges = num_edges[i]
496
+ nnodes = num_nodes[i]
497
+ ngraphs = 1
498
+
499
+ return GNNGraph (graph,
500
+ nnodes, nedges, ngraphs,
501
+ graph_indicator,
502
+ ndata, edata, gdata)
503
+ end
504
+
505
+ return [build_graph (i) for i in 1 : g. num_graphs]
506
+ end
507
+
508
+ function Flux. unbatch (g:: GNNGraph )
509
+ return [getgraph (g, i) for i in 1 : g. num_graphs]
510
+ end
511
+
512
+ function _unbatch_nodemasks (graph_indicator, num_graphs)
513
+ @assert issorted (graph_indicator) " The graph_indicator vector must be sorted."
514
+ idxslast = [searchsortedlast (graph_indicator, i) for i in 1 : num_graphs]
515
+
516
+ nodemasks = [1 : idxslast[1 ]]
517
+ for i in 2 : num_graphs
518
+ push! (nodemasks, idxslast[i- 1 ]+ 1 : idxslast[i])
519
+ end
520
+ return nodemasks
521
+ end
522
+
523
+ function _unbatch_edgemasks (s, t, num_graphs, cumnum_nodes)
524
+ edgemasks = []
525
+ for i in 1 : num_graphs- 1
526
+ lastedgeid = findfirst (s) do x
527
+ x > cumnum_nodes[i+ 1 ] && x <= cumnum_nodes[i+ 2 ]
528
+ end
529
+ firstedgeid = i == 1 ? 1 : last (edgemasks[i- 1 ]) + 1
530
+ # if nothing make empty range
531
+ lastedgeid = lastedgeid === nothing ? firstedgeid - 1 : lastedgeid - 1
532
+
533
+ push! (edgemasks, firstedgeid: lastedgeid)
534
+ end
535
+ push! (edgemasks, (last (edgemasks[end ])+ 1 ): length (s))
536
+ return edgemasks
471
537
end
472
538
539
+ @non_differentiable _unbatch_nodemasks (:: Any... )
540
+ @non_differentiable _unbatch_edgemasks (:: Any... )
473
541
474
542
"""
475
543
getgraph(g::GNNGraph, i; nmap=false)
0 commit comments