Skip to content

Commit ef83dd2

Browse files
faster unbatch (#248)
* faster unbatch * fix all * refactoring * cleanup
1 parent ca17aad commit ef83dd2

File tree

2 files changed

+82
-2
lines changed

2 files changed

+82
-2
lines changed

src/GNNGraphs/transform.jl

Lines changed: 70 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -466,10 +466,78 @@ julia> Flux.unbatch(gbatched)
466466
num_edges = 2
467467
```
468468
"""
469-
function Flux.unbatch(g::GNNGraph)
470-
[getgraph(g, i) for i in 1:g.num_graphs]
469+
function Flux.unbatch(g::GNNGraph{T}) where T<:COO_T
470+
g.num_graphs == 1 && return [g]
471+
472+
nodemasks = _unbatch_nodemasks(g.graph_indicator, g.num_graphs)
473+
num_nodes = length.(nodemasks)
474+
cumnum_nodes = [0; cumsum(num_nodes)]
475+
476+
s, t = edge_index(g)
477+
w = get_edge_weight(g)
478+
479+
edgemasks = _unbatch_edgemasks(s, t, g.num_graphs, cumnum_nodes)
480+
num_edges = length.(edgemasks)
481+
@assert sum(num_edges) == g.num_edges "Error in unbatching, likely the edges are not sorted (first edges belong to the first graphs, then edges in the second graph and so on)"
482+
483+
function build_graph(i)
484+
node_mask = nodemasks[i]
485+
edge_mask = edgemasks[i]
486+
snew = s[edge_mask] .- cumnum_nodes[i]
487+
tnew = t[edge_mask] .- cumnum_nodes[i]
488+
wnew = w === nothing ? nothing : w[edge_mask]
489+
graph = (snew, tnew, wnew)
490+
graph_indicator = nothing
491+
ndata = getobs(g.ndata, node_mask)
492+
edata = getobs(g.edata, edge_mask)
493+
gdata = getobs(g.gdata, i)
494+
495+
nedges = num_edges[i]
496+
nnodes = num_nodes[i]
497+
ngraphs = 1
498+
499+
return GNNGraph(graph,
500+
nnodes, nedges, ngraphs,
501+
graph_indicator,
502+
ndata, edata, gdata)
503+
end
504+
505+
return [build_graph(i) for i in 1:g.num_graphs]
506+
end
507+
508+
function Flux.unbatch(g::GNNGraph)
509+
return [getgraph(g, i) for i in 1:g.num_graphs]
510+
end
511+
512+
function _unbatch_nodemasks(graph_indicator, num_graphs)
513+
@assert issorted(graph_indicator) "The graph_indicator vector must be sorted."
514+
idxslast = [searchsortedlast(graph_indicator, i) for i in 1:num_graphs]
515+
516+
nodemasks = [1:idxslast[1]]
517+
for i in 2:num_graphs
518+
push!(nodemasks, idxslast[i-1]+1:idxslast[i])
519+
end
520+
return nodemasks
521+
end
522+
523+
function _unbatch_edgemasks(s, t, num_graphs, cumnum_nodes)
524+
edgemasks = []
525+
for i in 1:num_graphs-1
526+
lastedgeid = findfirst(s) do x
527+
x > cumnum_nodes[i+1] && x <= cumnum_nodes[i+2]
528+
end
529+
firstedgeid = i == 1 ? 1 : last(edgemasks[i-1]) + 1
530+
# if nothing make empty range
531+
lastedgeid = lastedgeid === nothing ? firstedgeid - 1 : lastedgeid - 1
532+
533+
push!(edgemasks, firstedgeid:lastedgeid)
534+
end
535+
push!(edgemasks, (last(edgemasks[end])+1):length(s))
536+
return edgemasks
471537
end
472538

539+
@non_differentiable _unbatch_nodemasks(::Any...)
540+
@non_differentiable _unbatch_edgemasks(::Any...)
473541

474542
"""
475543
getgraph(g::GNNGraph, i; nmap=false)

test/GNNGraphs/transform.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,18 @@
6161
@test gs[2].num_graphs == 1
6262
end
6363

64+
@testset "batch/unbatch roundtrip" begin
65+
n = 20
66+
c = 3
67+
ngraphs = 10
68+
gs = [rand_graph(n, c*n, ndata=rand(2, n), edata=rand(3, c*n), graph_type=GRAPH_T)
69+
for _ in 1:ngraphs]
70+
gall = Flux.batch(gs)
71+
gs2 = Flux.unbatch(gall)
72+
@test gs2[1] == gs[1]
73+
@test gs2[end] == gs[end]
74+
end
75+
6476
@testset "getgraph" begin
6577
g1 = GNNGraph(random_regular_graph(10,2), ndata=rand(16,10), graph_type=GRAPH_T)
6678
g2 = GNNGraph(random_regular_graph(4,2), ndata=rand(16,4), graph_type=GRAPH_T)

0 commit comments

Comments
 (0)