Skip to content

Commit 0c736ce

Browse files
faster cat_features
1 parent 6778384 commit 0c736ce

File tree

3 files changed

+51
-35
lines changed

3 files changed

+51
-35
lines changed

src/GNNGraphs/transform.jl

Lines changed: 25 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -270,12 +270,12 @@ function SparseArrays.blockdiag(g1::GNNGraph, g2::GNNGraph)
270270
t = vcat(t1, nv1 .+ t2)
271271
w = cat_features(get_edge_weight(g1), get_edge_weight(g2))
272272
graph = (s, t, w)
273-
ind1 = isnothing(g1.graph_indicator) ? ones_like(s1, Int, nv1) : g1.graph_indicator
274-
ind2 = isnothing(g2.graph_indicator) ? ones_like(s2, Int, nv2) : g2.graph_indicator
273+
ind1 = isnothing(g1.graph_indicator) ? ones_like(s1, nv1) : g1.graph_indicator
274+
ind2 = isnothing(g2.graph_indicator) ? ones_like(s2, nv2) : g2.graph_indicator
275275
elseif g1.graph isa ADJMAT_T
276276
graph = blockdiag(g1.graph, g2.graph)
277-
ind1 = isnothing(g1.graph_indicator) ? ones_like(graph, Int, nv1) : g1.graph_indicator
278-
ind2 = isnothing(g2.graph_indicator) ? ones_like(graph, Int, nv2) : g2.graph_indicator
277+
ind1 = isnothing(g1.graph_indicator) ? ones_like(graph, nv1) : g1.graph_indicator
278+
ind2 = isnothing(g2.graph_indicator) ? ones_like(graph, nv2) : g2.graph_indicator
279279
end
280280
graph_indicator = vcat(ind1, g1.num_graphs .+ ind2)
281281

@@ -358,27 +358,31 @@ julia> g12.ndata.x
358358
1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
359359
```
360360
"""
361-
function Flux.batch(gs::Vector{<:GNNGraph})
362-
nodes = [g.num_nodes for g in gs]
363-
364-
if all(y -> isa(y, COO_T), [g.graph for g in gs] )
365-
edge_indices = [edge_index(g) for g in gs]
366-
nodesum = cumsum([0, nodes...])[1:end-1]
367-
s = cat_features([ei[1] .+ nodesum[ii] for (ii, ei) in enumerate(edge_indices)])
368-
t = cat_features([ei[2] .+ nodesum[ii] for (ii, ei) in enumerate(edge_indices)])
369-
w = reduce(vcat, [get_edge_weight(g) for g in gs])
370-
w = w isa Vector{Nothing} ? nothing : w
371-
graph = (s, t, w)
372-
graph_indicator = vcat([ones_like(ei[1],Int,nodes[ii]) .+ (ii - 1) for (ii,ei) in enumerate(edge_indices)]...)
373-
elseif all(y -> isa(y, ADJMAT_T), [g.graph for g in gs] )
374-
graph = blockdiag([g.graph for g in gs]...)
375-
graph_indicator = vcat([ones_like(graph,Int,nodes[ii]) .+ (ii - 1) for ii in 1:length(nodes)]...)
361+
Flux.batch(gs::AbstractVector{<:GNNGraph}) = blockdiag(gs...)
362+
363+
function Flux.batch(gs::AbstractVector{<:GNNGraph{T}}) where T<:COO_T
364+
v_num_nodes = [g.num_nodes for g in gs]
365+
edge_indices = [edge_index(g) for g in gs]
366+
nodesum = cumsum([0; v_num_nodes])[1:end-1]
367+
s = cat_features([ei[1] .+ nodesum[ii] for (ii, ei) in enumerate(edge_indices)])
368+
t = cat_features([ei[2] .+ nodesum[ii] for (ii, ei) in enumerate(edge_indices)])
369+
w = cat_features([get_edge_weight(g) for g in gs])
370+
graph = (s, t, w)
371+
372+
function materialize_graph_indicator(g)
373+
g.graph_indicator === nothing ? ones_like(s, g.num_nodes) : g.graph_indicator
376374
end
375+
376+
v_gi = materialize_graph_indicator.(gs)
377+
v_num_graphs = [g.num_graphs for g in gs]
378+
graphsum = cumsum([0; v_num_graphs])[1:end-1]
379+
v_gi = [ng .+ gi for (ng, gi) in zip(graphsum, v_gi)]
380+
graph_indicator = cat_features(v_gi)
377381

378382
GNNGraph(graph,
379-
sum(nodes),
383+
sum(v_num_nodes),
380384
sum([g.num_edges for g in gs]),
381-
sum([g.num_graphs for g in gs]),
385+
sum(v_num_graphs),
382386
graph_indicator,
383387
cat_features([g.ndata for g in gs]),
384388
cat_features([g.edata for g in gs]),

src/GNNGraphs/utils.jl

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,16 @@ function cat_features(x1::NamedTuple, x2::NamedTuple)
2727
NamedTuple(k => cat_features(getfield(x1,k), getfield(x2,k)) for k in keys(x1))
2828
end
2929

30-
function cat_features(xs::Vector{<:NamedTuple})
30+
function cat_features(xs::AbstractVector{<:AbstractArray{T,N}}) where {T<:Number, N}
31+
cat(xs...; dims=N)
32+
end
33+
34+
cat_features(xs::AbstractVector{Nothing}) = nothing
35+
cat_features(xs::AbstractVector{<:Number}) = xs
36+
37+
function cat_features(xs::AbstractVector{<:NamedTuple})
3138
symbols = [sort(collect(keys(x))) for x in xs]
32-
all(y->y==symbols[1], symbols) || @error "cannot concatenate feature data with different keys"
39+
all(y -> y==symbols[1], symbols) || @error "cannot concatenate feature data with different keys"
3340
length(xs) == 1 && return xs[1]
3441

3542
# concatenate
@@ -82,9 +89,10 @@ function normalize_graphdata(data::NamedTuple; default_name, n, duplicate_if_nee
8289
return data
8390
end
8491

85-
ones_like(x::AbstractArray, T=eltype(x), sz=size(x)) = fill!(similar(x, T, sz), 1)
86-
ones_like(x::SparseMatrixCSC, T=eltype(x), sz=size(x)) = ones(T, sz)
87-
ones_like(x::CUMAT_T, T=eltype(x), sz=size(x)) = CUDA.ones(T, sz)
92+
ones_like(x::AbstractArray, T::Type, sz=size(x)) = fill!(similar(x, T, sz), 1)
93+
ones_like(x::SparseMatrixCSC, T::Type, sz=size(x)) = ones(T, sz)
94+
ones_like(x::CUMAT_T, T::Type, sz=size(x)) = CUDA.ones(T, sz)
95+
ones_like(x, sz=size(x)) = ones_like(x, eltype(x), sz)
8896

8997
numnonzeros(a::AbstractSparseMatrix) = nnz(a)
9098
numnonzeros(a::AbstractMatrix) = count(!=(0), a)

test/GNNGraphs/transform.jl

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,9 @@
1818
end
1919

2020
@testset "batch" begin
21-
#TODO add graph_type=GRAPH_T
22-
g1 = GNNGraph(random_regular_graph(10,2), ndata=rand(16,10))
23-
g2 = GNNGraph(random_regular_graph(4,2), ndata=rand(16,4))
24-
g3 = GNNGraph(random_regular_graph(7,2), ndata=rand(16,7))
21+
g1 = GNNGraph(random_regular_graph(10,2), ndata=rand(16,10), graph_type=GRAPH_T)
22+
g2 = GNNGraph(random_regular_graph(4,2), ndata=rand(16,4), graph_type=GRAPH_T)
23+
g3 = GNNGraph(random_regular_graph(7,2), ndata=rand(16,7), graph_type=GRAPH_T)
2524

2625
g12 = Flux.batch([g1, g2])
2726
g12b = blockdiag(g1, g2)
@@ -36,16 +35,21 @@
3635
@test node_features(g123)[:,11:14] node_features(g2)
3736

3837
# scalar graph features
39-
g1 = GNNGraph(random_regular_graph(10,2), gdata=rand())
40-
g2 = GNNGraph(random_regular_graph(4,2), gdata=rand())
41-
g3 = GNNGraph(random_regular_graph(4,2), gdata=rand())
38+
g1 = GNNGraph(g1, gdata=rand())
39+
g2 = GNNGraph(g2, gdata=rand())
40+
g3 = GNNGraph(g3, gdata=rand())
4241
g123 = Flux.batch([g1, g2, g3])
4342
@test g123.gdata.u == [g1.gdata.u, g2.gdata.u, g3.gdata.u]
43+
44+
# Batch of batches
45+
g123123 = Flux.batch([g123, g123])
46+
@test g123123.graph_indicator == [fill(1, 10); fill(2, 4); fill(3, 7); fill(4, 10); fill(5, 4); fill(6, 7)]
47+
@test g123123.num_graphs == 6
4448
end
4549

4650
@testset "unbatch" begin
47-
g1 = rand_graph(10, 20)
48-
g2 = rand_graph(5, 10)
51+
g1 = rand_graph(10, 20, graph_type=GRAPH_T)
52+
g2 = rand_graph(5, 10, graph_type=GRAPH_T)
4953
g12 = Flux.batch([g1, g2])
5054
gs = Flux.unbatch([g1,g2])
5155
@test length(gs) == 2

0 commit comments

Comments
 (0)