Skip to content

Commit ee8be00

Browse files
Flux.batch for GNNHeteroGraph (#309)
* Add `Flux.batch` support for heterogeneous graphs (`GNNHeteroGraph`)
* Fix related `cat_features` methods and add tests
1 parent 333d74a commit ee8be00

File tree

3 files changed

+102
-6
lines changed

3 files changed

+102
-6
lines changed

src/GNNGraphs/transform.jl

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -472,6 +472,60 @@ function Flux.batch(g::GNNGraph)
472472
throw(ArgumentError("Cannot batch a `GNNGraph` (containing $(g.num_graphs) graphs). Pass a vector of `GNNGraph`s instead."))
473473
end
474474

475+
476+
"""
    Flux.batch(gs::AbstractVector{<:GNNHeteroGraph})

Batch a vector of heterogeneous graphs into a single `GNNHeteroGraph`
holding all of their nodes and edges, with node/edge features concatenated
type-by-type. Node ids of each graph are shifted so they do not collide.

All graphs must share the same node and edge types, and each input graph
must itself contain a single graph (`g.num_graphs == 1`).
"""
function Flux.batch(gs::AbstractVector{<:GNNHeteroGraph{T}}) where {T <: COO_T}
    isempty(gs) && throw(ArgumentError("Cannot batch an empty vector of graphs."))
    ntypes = union([g.ntypes for g in gs]...)
    etypes = union([g.etypes for g in gs]...)
    # TODO remove these constraints: currently every graph must carry the
    # exact same node and edge types as the first one.
    @assert ntypes == gs[1].ntypes
    @assert etypes == gs[1].etypes

    # Per-type node counts for each graph (0 when a graph lacks that type).
    v_num_nodes = Dict(node_t => [get(g.num_nodes, node_t, 0) for g in gs] for node_t in ntypes)
    num_nodes = Dict(node_t => sum(v_num_nodes[node_t]) for node_t in ntypes)
    num_edges = Dict(edge_t => sum(g.num_edges[edge_t] for g in gs) for edge_t in etypes)
    edge_indices = Dict(edge_t => [edge_index(g, edge_t) for g in gs] for edge_t in etypes)
    # Cumulative node-count offsets: the shift to apply to graph i's node ids.
    nodesum = Dict(node_t => cumsum([0; v_num_nodes[node_t]])[1:(end - 1)] for node_t in ntypes)

    graphs = []
    for edge_t in etypes
        src_t, _, dst_t = edge_t
        # Shift source/target indices by each graph's node offset, then concatenate.
        s = cat_features([ei[1] .+ nodesum[src_t][ii] for (ii, ei) in enumerate(edge_indices[edge_t])])
        t = cat_features([ei[2] .+ nodesum[dst_t][ii] for (ii, ei) in enumerate(edge_indices[edge_t])])
        w = cat_features([get_edge_weight(g, edge_t) for g in gs])
        push!(graphs, edge_t => (s, t, w))
    end
    graph = Dict(graphs...)

    #TODO relax this restriction: only unbatched graphs may be batched for now
    @assert all(g -> g.num_graphs == 1, gs)

    # Grab any source vector; presumably used so `ones_like` matches its
    # array type/device — TODO confirm.
    s = edge_index(gs[1], etypes[1])[1]

    # One indicator entry per node of `node_t` in graph `g`.
    function materialize_graph_indicator(g, node_t)
        ones_like(s, g.num_nodes[node_t])
    end
    v_gi = Dict(node_t => [materialize_graph_indicator(g, node_t) for g in gs] for node_t in ntypes)
    v_num_graphs = [g.num_graphs for g in gs]
    # Offset each graph's indicator by the number of graphs that precede it.
    graphsum = cumsum([0; v_num_graphs])[1:(end - 1)]
    v_gi = Dict(node_t => [ng .+ gi for (ng, gi) in zip(graphsum, v_gi[node_t])] for node_t in ntypes)
    graph_indicator = Dict(node_t => cat_features(v_gi[node_t]) for node_t in ntypes)

    return GNNHeteroGraph(graph,
                          num_nodes,
                          num_edges,
                          sum(v_num_graphs),
                          graph_indicator,
                          cat_features([g.ndata for g in gs]),
                          cat_features([g.edata for g in gs]),
                          cat_features([g.gdata for g in gs]),
                          ntypes, etypes)
end
528+
475529
"""
476530
unbatch(g::GNNGraph)
477531

src/GNNGraphs/utils.jl

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
"""
    cat_features(x1::Dict{Symbol, T}, x2::Dict{Symbol, T}) where {T}

Concatenate two feature dictionaries key-by-key.
Both dictionaries must contain exactly the same keys.
"""
function cat_features(x1::Dict{Symbol, T}, x2::Dict{Symbol, T}) where {T}
    sort(collect(keys(x1))) == sort(collect(keys(x2))) ||
        @error "cannot concatenate feature data with different keys"
    merged = [k => cat_features(x1[k], x2[k]) for k in keys(x1)]
    return Dict{Symbol, T}(merged...)
end
8888

89+
"""
    cat_features(x::Dict)

Apply `cat_features` to every value of `x`, returning a new dictionary
with the same keys.
"""
function cat_features(x::Dict)
    # Build directly from a generator instead of splatting a temporary
    # vector of pairs into the Dict constructor.
    return Dict(k => cat_features(v) for (k, v) in pairs(x))
end
92+
93+
8994
"""
    cat_features(xs::AbstractVector{<:AbstractArray{T, N}}) where {T <: Number, N}

Concatenate numeric arrays along their last dimension.
"""
function cat_features(xs::AbstractVector{<:AbstractArray{T, N}}) where {T <: Number, N}
    # The last dimension is the per-sample axis, so stacking happens there.
    return cat(xs...; dims = Val(N))
end
@@ -104,17 +109,30 @@ function cat_features(xs::AbstractVector{<:NamedTuple})
104109
NamedTuple(k => cat_features([x[k] for x in xs]) for k in syms)
105110
end
106111

107-
function cat_features(xs::AbstractVector{Dict{Symbol, T}}) where {T}
108-
symbols = [sort(collect(keys(x))) for x in xs]
109-
all(y -> y == symbols[1], symbols) ||
112+
# function cat_features(xs::AbstractVector{Dict{Symbol, T}}) where {T}
113+
# symbols = [sort(collect(keys(x))) for x in xs]
114+
# all(y -> y == symbols[1], symbols) ||
115+
# @error "cannot concatenate feature data with different keys"
116+
# length(xs) == 1 && return xs[1]
117+
118+
# # concatenate
119+
# syms = symbols[1]
120+
# return Dict{Symbol, T}([k => cat_features([x[k] for x in xs]) for k in syms]...)
121+
# end
122+
123+
"""
    cat_features(xs::AbstractVector{<:Dict})

Concatenate a non-empty vector of feature dictionaries key-by-key.
All dictionaries must share the same set of keys.

Throws an `ArgumentError` when `xs` is empty.
"""
function cat_features(xs::AbstractVector{<:Dict})
    # Guard: the original indexed `xs[1]` unconditionally, raising an
    # opaque BoundsError on empty input.
    isempty(xs) && throw(ArgumentError("Cannot concatenate an empty vector of feature dictionaries."))
    _allkeys = [sort(collect(keys(x))) for x in xs]
    _keys = _allkeys[1]
    all(y -> y == _keys, _allkeys) ||
        @error "cannot concatenate feature data with different keys"
    length(xs) == 1 && return xs[1]

    # Concatenate each key's features across all dictionaries; a generator
    # avoids splatting a temporary pair vector.
    return Dict(k => cat_features([x[k] for x in xs]) for k in _keys)
end
117133

134+
135+
118136
# Turns generic type into named tuple
119137
normalize_graphdata(data::Nothing; n, kws...) = DataStore(n)
120138

test/GNNGraphs/transform.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -292,3 +292,27 @@ end
292292
0.5 1.0 0.5
293293
0.0 0.0 0.0]
294294
end
295+
296+
@testset "batch heterograph" begin
    # Five bipartite heterographs, each with 10 :A nodes, 15 :B nodes,
    # and 20 edges in each direction.
    gs = [rand_bipartite_heterograph((10, 15), 20) for _ in 1:5]
    g = Flux.batch(gs)
    # Node and edge counts are summed per type across the 5 graphs.
    @test g.num_nodes[:A] == 50
    @test g.num_nodes[:B] == 75
    @test g.num_edges[(:A,:to,:B)] == 100
    @test g.num_edges[(:B,:to,:A)] == 100
    @test g.num_graphs == 5
    # Every node of each type is mapped back to its originating graph.
    @test g.graph_indicator == Dict(:A => vcat([fill(i, 10) for i in 1:5]...),
                                    :B => vcat([fill(i, 15) for i in 1:5]...))

    # Attach node-, edge-, and graph-level features to each input graph,
    # then re-batch and check the features are concatenated per type.
    for gi in gs
        gi.ndata[:A].x = ones(2, 10)
        gi.ndata[:A].y = zeros(10)
        gi.edata[(:A,:to,:B)].e = fill(2, 20)
        gi.gdata.u = 7
    end
    g = Flux.batch(gs)
    @test g.ndata[:A].x == ones(2, 50)
    @test g.ndata[:A].y == zeros(50)
    @test g.edata[(:A,:to,:B)].e == fill(2, 100)
    @test g.gdata.u == fill(7, 5)
end

0 commit comments

Comments
 (0)