Skip to content

Commit be95629

Browse files
Merge pull request #122 from CarloLucibello/tclements/master
faster batching
2 parents 66b26df + f477827 commit be95629

File tree

3 files changed

+72
-17
lines changed

3 files changed

+72
-17
lines changed

src/GNNGraphs/transform.jl

Lines changed: 34 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -270,12 +270,12 @@ function SparseArrays.blockdiag(g1::GNNGraph, g2::GNNGraph)
270270
t = vcat(t1, nv1 .+ t2)
271271
w = cat_features(get_edge_weight(g1), get_edge_weight(g2))
272272
graph = (s, t, w)
273-
ind1 = isnothing(g1.graph_indicator) ? ones_like(s1, Int, nv1) : g1.graph_indicator
274-
ind2 = isnothing(g2.graph_indicator) ? ones_like(s2, Int, nv2) : g2.graph_indicator
273+
ind1 = isnothing(g1.graph_indicator) ? ones_like(s1, nv1) : g1.graph_indicator
274+
ind2 = isnothing(g2.graph_indicator) ? ones_like(s2, nv2) : g2.graph_indicator
275275
elseif g1.graph isa ADJMAT_T
276276
graph = blockdiag(g1.graph, g2.graph)
277-
ind1 = isnothing(g1.graph_indicator) ? ones_like(graph, Int, nv1) : g1.graph_indicator
278-
ind2 = isnothing(g2.graph_indicator) ? ones_like(graph, Int, nv2) : g2.graph_indicator
277+
ind1 = isnothing(g1.graph_indicator) ? ones_like(graph, nv1) : g1.graph_indicator
278+
ind2 = isnothing(g2.graph_indicator) ? ones_like(graph, nv2) : g2.graph_indicator
279279
end
280280
graph_indicator = vcat(ind1, g1.num_graphs .+ ind2)
281281

@@ -358,8 +358,37 @@ julia> g12.ndata.x
358358
1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
359359
```
360360
"""
361-
Flux.batch(gs::Vector{<:GNNGraph}) = blockdiag(gs...)
361+
Flux.batch(gs::AbstractVector{<:GNNGraph}) = blockdiag(gs...)
362+
363+
function Flux.batch(gs::AbstractVector{<:GNNGraph{T}}) where T<:COO_T
364+
v_num_nodes = [g.num_nodes for g in gs]
365+
edge_indices = [edge_index(g) for g in gs]
366+
nodesum = cumsum([0; v_num_nodes])[1:end-1]
367+
s = cat_features([ei[1] .+ nodesum[ii] for (ii, ei) in enumerate(edge_indices)])
368+
t = cat_features([ei[2] .+ nodesum[ii] for (ii, ei) in enumerate(edge_indices)])
369+
w = cat_features([get_edge_weight(g) for g in gs])
370+
graph = (s, t, w)
371+
372+
function materialize_graph_indicator(g)
373+
g.graph_indicator === nothing ? ones_like(s, g.num_nodes) : g.graph_indicator
374+
end
362375

376+
v_gi = materialize_graph_indicator.(gs)
377+
v_num_graphs = [g.num_graphs for g in gs]
378+
graphsum = cumsum([0; v_num_graphs])[1:end-1]
379+
v_gi = [ng .+ gi for (ng, gi) in zip(graphsum, v_gi)]
380+
graph_indicator = cat_features(v_gi)
381+
382+
GNNGraph(graph,
383+
sum(v_num_nodes),
384+
sum([g.num_edges for g in gs]),
385+
sum(v_num_graphs),
386+
graph_indicator,
387+
cat_features([g.ndata for g in gs]),
388+
cat_features([g.edata for g in gs]),
389+
cat_features([g.gdata for g in gs]),
390+
)
391+
end
363392

364393
"""
365394
unbatch(g::GNNGraph)

src/GNNGraphs/utils.jl

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,33 @@ cat_features(x1::Union{Number, AbstractVector}, x2::Union{Number, AbstractVector
2020

2121
# workaround for issue #98 #104
2222
cat_features(x1::NamedTuple{(), Tuple{}}, x2::NamedTuple{(), Tuple{}}) = (;)
23+
cat_features(xs::AbstractVector{NamedTuple{(), Tuple{}}}) = (;)
2324

2425
function cat_features(x1::NamedTuple, x2::NamedTuple)
2526
sort(collect(keys(x1))) == sort(collect(keys(x2))) || @error "cannot concatenate feature data with different keys"
2627

2728
NamedTuple(k => cat_features(getfield(x1,k), getfield(x2,k)) for k in keys(x1))
2829
end
2930

31+
function cat_features(xs::AbstractVector{<:AbstractArray{T,N}}) where {T<:Number, N}
32+
cat(xs...; dims=N)
33+
end
34+
35+
cat_features(xs::AbstractVector{Nothing}) = nothing
36+
cat_features(xs::AbstractVector{<:Number}) = xs
37+
38+
function cat_features(xs::AbstractVector{<:NamedTuple})
39+
symbols = [sort(collect(keys(x))) for x in xs]
40+
all(y -> y==symbols[1], symbols) || @error "cannot concatenate feature data with different keys"
41+
length(xs) == 1 && return xs[1]
42+
43+
# concatenate
44+
syms = symbols[1]
45+
NamedTuple(
46+
k => cat_features([x[k] for x in xs]) for (ii,k) in enumerate(syms)
47+
)
48+
end
49+
3050
# Turns generic type into named tuple
3151
normalize_graphdata(data::Nothing; kws...) = NamedTuple()
3252

@@ -70,9 +90,10 @@ function normalize_graphdata(data::NamedTuple; default_name, n, duplicate_if_nee
7090
return data
7191
end
7292

73-
ones_like(x::AbstractArray, T=eltype(x), sz=size(x)) = fill!(similar(x, T, sz), 1)
74-
ones_like(x::SparseMatrixCSC, T=eltype(x), sz=size(x)) = ones(T, sz)
75-
ones_like(x::CUMAT_T, T=eltype(x), sz=size(x)) = CUDA.ones(T, sz)
93+
ones_like(x::AbstractArray, T::Type, sz=size(x)) = fill!(similar(x, T, sz), 1)
94+
ones_like(x::SparseMatrixCSC, T::Type, sz=size(x)) = ones(T, sz)
95+
ones_like(x::CUMAT_T, T::Type, sz=size(x)) = CUDA.ones(T, sz)
96+
ones_like(x, sz=size(x)) = ones_like(x, eltype(x), sz)
7697

7798
numnonzeros(a::AbstractSparseMatrix) = nnz(a)
7899
numnonzeros(a::AbstractMatrix) = count(!=(0), a)

test/GNNGraphs/transform.jl

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,13 @@
1818
end
1919

2020
@testset "batch" begin
21-
#TODO add graph_type=GRAPH_T
22-
g1 = GNNGraph(random_regular_graph(10,2), ndata=rand(16,10))
23-
g2 = GNNGraph(random_regular_graph(4,2), ndata=rand(16,4))
24-
g3 = GNNGraph(random_regular_graph(7,2), ndata=rand(16,7))
21+
g1 = GNNGraph(random_regular_graph(10,2), ndata=rand(16,10), graph_type=GRAPH_T)
22+
g2 = GNNGraph(random_regular_graph(4,2), ndata=rand(16,4), graph_type=GRAPH_T)
23+
g3 = GNNGraph(random_regular_graph(7,2), ndata=rand(16,7), graph_type=GRAPH_T)
2524

2625
g12 = Flux.batch([g1, g2])
2726
g12b = blockdiag(g1, g2)
27+
@test g12 == g12b
2828

2929
g123 = Flux.batch([g1, g2, g3])
3030
@test g123.graph_indicator == [fill(1, 10); fill(2, 4); fill(3, 7)]
@@ -35,16 +35,21 @@
3535
@test node_features(g123)[:,11:14] node_features(g2)
3636

3737
# scalar graph features
38-
g1 = GNNGraph(random_regular_graph(10,2), gdata=rand())
39-
g2 = GNNGraph(random_regular_graph(4,2), gdata=rand())
40-
g3 = GNNGraph(random_regular_graph(4,2), gdata=rand())
38+
g1 = GNNGraph(g1, gdata=rand())
39+
g2 = GNNGraph(g2, gdata=rand())
40+
g3 = GNNGraph(g3, gdata=rand())
4141
g123 = Flux.batch([g1, g2, g3])
4242
@test g123.gdata.u == [g1.gdata.u, g2.gdata.u, g3.gdata.u]
43+
44+
# Batch of batches
45+
g123123 = Flux.batch([g123, g123])
46+
@test g123123.graph_indicator == [fill(1, 10); fill(2, 4); fill(3, 7); fill(4, 10); fill(5, 4); fill(6, 7)]
47+
@test g123123.num_graphs == 6
4348
end
4449

4550
@testset "unbatch" begin
46-
g1 = rand_graph(10, 20)
47-
g2 = rand_graph(5, 10)
51+
g1 = rand_graph(10, 20, graph_type=GRAPH_T)
52+
g2 = rand_graph(5, 10, graph_type=GRAPH_T)
4853
g12 = Flux.batch([g1, g2])
4954
gs = Flux.unbatch([g1,g2])
5055
@test length(gs) == 2

0 commit comments

Comments
 (0)