Skip to content

Commit 661d3e8

Browse files
tclements-usgsCarloLucibello
authored andcommitted
Fix for quadratic batching in #99
1 parent 66b26df commit 661d3e8

File tree

3 files changed

+43
-2
lines changed

3 files changed

+43
-2
lines changed

src/GNNGraphs/transform.jl

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -311,6 +311,7 @@ function SparseArrays.blockdiag(g1::GNNGraph, gothers::GNNGraph...)
311311
end
312312
return g
313313
end
314+
SparseArrays.blockdiag(gs::Vector{GNNGraph}) = SparseArrays.blockdiag(gs...)
314315

315316
"""
316317
batch(gs::Vector{<:GNNGraph})
@@ -358,8 +359,33 @@ julia> g12.ndata.x
358359
1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
359360
```
360361
"""
361-
Flux.batch(gs::Vector{<:GNNGraph}) = blockdiag(gs...)
362-
362+
function Flux.batch(gs::Vector{<:GNNGraph})
363+
nodes = [g.num_nodes for g in gs]
364+
365+
if all(y -> isa(y, COO_T), [g.graph for g in gs] )
366+
edge_indices = [edge_index(g) for g in gs]
367+
nodesum = cumsum([0, nodes...])[1:end-1]
368+
s = reduce(vcat, [ei[1] .+ nodesum[ii] for (ii,ei) in enumerate(edge_indices)])
369+
t = reduce(vcat, [ei[2] .+ nodesum[ii] for (ii,ei) in enumerate(edge_indices)])
370+
w = reduce(vcat, [get_edge_weight(g) for g in gs])
371+
w = w isa Vector{Nothing} ? nothing : w
372+
graph = (s, t, w)
373+
graph_indicator = vcat([ones_like(ei[1],Int,nodes[ii]) .+ (ii - 1) for (ii,ei) in enumerate(edge_indices)]...)
374+
elseif all(y -> isa(y, ADJMAT_T), [g.graph for g in gs] )
375+
graph = blockdiag([g.graph for g in gs]...)
376+
graph_indicator = vcat([ones_like(graph,Int,nodes[ii]) .+ (ii - 1) for ii in 1:length(nodes)]...)
377+
end
378+
379+
GNNGraph(graph,
380+
sum(nodes),
381+
sum([g.num_edges for g in gs]),
382+
sum([g.num_graphs for g in gs]),
383+
graph_indicator,
384+
cat_features([g.ndata for g in gs]),
385+
cat_features([g.edata for g in gs]),
386+
cat_features([g.gdata for g in gs]),
387+
)
388+
end
363389

364390
"""
365391
unbatch(g::GNNGraph)

src/GNNGraphs/utils.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,20 @@ function cat_features(x1::NamedTuple, x2::NamedTuple)
2727
NamedTuple(k => cat_features(getfield(x1,k), getfield(x2,k)) for k in keys(x1))
2828
end
2929

30+
function cat_features(xs::Vector{NamedTuple{T1, T2}}) where {T1, T2}
31+
symbols = [sort(collect(keys(x))) for x in xs]
32+
all(y->y==symbols[1], symbols) || @error "cannot concatenate feature data with different keys"
33+
length(xs) == 1 && return xs[1]
34+
35+
# concatenate
36+
syms = symbols[1]
37+
dims = [max(1, ndims(xs[1][k])) for k in syms] # promote scalar to 1D
38+
methods = [dim == 1 ? vcat : hcat for dim in dims] # use optimized reduce(hcat,xs) or reduce(vcat,xs)
39+
NamedTuple(
40+
k => reduce(methods[ii],[x[k] for x in xs]) for (ii,k) in enumerate(syms)
41+
)
42+
end
43+
3044
# Turns generic type into named tuple
3145
normalize_graphdata(data::Nothing; kws...) = NamedTuple()
3246

test/GNNGraphs/transform.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525

2626
g12 = Flux.batch([g1, g2])
2727
g12b = blockdiag(g1, g2)
28+
@test g12 == g12b
2829

2930
g123 = Flux.batch([g1, g2, g3])
3031
@test g123.graph_indicator == [fill(1, 10); fill(2, 4); fill(3, 7)]

0 commit comments

Comments
 (0)