Skip to content

Commit 4c225c2

Browse files
authored
Remove type constraints on Flux.batch for GNNHeteroGraph (#342)
* Remove type constraints on Flux.batch for GNNHeteroGraph * Clean up patch * Clean up patch * Finalize revisions to Flux.batch patch * Implement sugguestions * Finalize new Flux.batch test case * Update src/GNNGraphs/datastore.jl * Create setindex! test for DataStore * Update test/GNNGraphs/datastore.jl --------- Co-authored-by: AarSeBail <>
1 parent 066d8f5 commit 4c225c2

File tree

5 files changed

+107
-20
lines changed

5 files changed

+107
-20
lines changed

src/GNNGraphs/datastore.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ function Base.setproperty!(ds::DataStore, s::Symbol, x)
125125
end
126126

127127
Base.getindex(ds::DataStore, s::Symbol) = getproperty(ds, s)
128-
Base.setindex!(ds::DataStore, s::Symbol, x) = setproperty!(ds, s, x)
128+
Base.setindex!(ds::DataStore, x, s::Symbol) = setproperty!(ds, s, x)
129129

130130
function Base.show(io::IO, ds::DataStore)
131131
len = length(ds)

src/GNNGraphs/transform.jl

Lines changed: 28 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -579,17 +579,26 @@ end
579579

580580

581581
function Flux.batch(gs::AbstractVector{<:GNNHeteroGraph})
582+
function edge_index_nullable(g::GNNHeteroGraph{<:COO_T}, edge_t::EType)
583+
if haskey(g.graph, edge_t)
584+
g.graph[edge_t][1:2]
585+
else
586+
nothing
587+
end
588+
end
589+
590+
function get_edge_weight_nullable(g::GNNHeteroGraph{<:COO_T}, edge_t::EType)
591+
get(g.graph, edge_t, (nothing, nothing, nothing))[3]
592+
end
593+
582594
@assert length(gs) > 0
583595
ntypes = union([g.ntypes for g in gs]...)
584596
etypes = union([g.etypes for g in gs]...)
585-
# TODO remove these constraints
586-
@assert ntypes == gs[1].ntypes
587-
@assert etypes == gs[1].etypes
588597

589-
v_num_nodes = Dict(node_t => [get(g.num_nodes,node_t,0) for g in gs] for node_t in ntypes)
598+
v_num_nodes = Dict(node_t => [get(g.num_nodes, node_t, 0) for g in gs] for node_t in ntypes)
590599
num_nodes = Dict(node_t => sum(v_num_nodes[node_t]) for node_t in ntypes)
591-
num_edges = Dict(edge_t => sum(g.num_edges[edge_t] for g in gs) for edge_t in etypes)
592-
edge_indices = Dict(edge_t => [edge_index(g, edge_t) for g in gs] for edge_t in etypes)
600+
num_edges = Dict(edge_t => sum(get(g.num_edges, edge_t, 0) for g in gs) for edge_t in etypes)
601+
edge_indices = edge_indices = Dict(edge_t => [edge_index_nullable(g, edge_t) for g in gs] for edge_t in etypes)
593602
nodesum = Dict(node_t => cumsum([0; v_num_nodes[node_t]])[1:(end - 1)] for node_t in ntypes)
594603
graphs = []
595604
for edge_t in etypes
@@ -599,34 +608,39 @@ function Flux.batch(gs::AbstractVector{<:GNNHeteroGraph})
599608
# @show ei[1]
600609
# end
601610
# # [ei[1] for (ii, ei) in enumerate(edge_indices[edge_t])]
602-
s = cat_features([ei[1] .+ nodesum[src_t][ii] for (ii, ei) in enumerate(edge_indices[edge_t])])
603-
t = cat_features([ei[2] .+ nodesum[dst_t][ii] for (ii, ei) in enumerate(edge_indices[edge_t])])
604-
w = cat_features([get_edge_weight(g, edge_t) for g in gs])
611+
s = cat_features([ei[1] .+ nodesum[src_t][ii] for (ii, ei) in enumerate(edge_indices[edge_t]) if ei !== nothing])
612+
t = cat_features([ei[2] .+ nodesum[dst_t][ii] for (ii, ei) in enumerate(edge_indices[edge_t]) if ei !== nothing])
613+
w = cat_features(filter(x -> x !== nothing, [get_edge_weight_nullable(g, edge_t) for g in gs]))
605614
push!(graphs, edge_t => (s, t, w))
606615
end
607616
graph = Dict(graphs...)
608617

609618
#TODO relax this restriction
610619
@assert all(g -> g.num_graphs == 1, gs)
611620

612-
s = edge_index(gs[1], etypes[1])[1] # grab any source vector
621+
s = edge_index(gs[1], gs[1].etypes[1])[1] # grab any source vector
613622

614-
function materialize_graph_indicator(g, node_t)
615-
ones_like(s, g.num_nodes[node_t])
623+
function materialize_graph_indicator(g, node_t)
624+
n = get(g.num_nodes, node_t, 0)
625+
return ones_like(s, n)
616626
end
617627
v_gi = Dict(node_t => [materialize_graph_indicator(g, node_t) for g in gs] for node_t in ntypes)
618628
v_num_graphs = [g.num_graphs for g in gs]
619629
graphsum = cumsum([0; v_num_graphs])[1:(end - 1)]
620630
v_gi = Dict(node_t => [ng .+ gi for (ng, gi) in zip(graphsum, v_gi[node_t])] for node_t in ntypes)
621631
graph_indicator = Dict(node_t => cat_features(v_gi[node_t]) for node_t in ntypes)
622632

633+
function data_or_else(data, types)
634+
Dict(type => get(data, type, DataStore(0)) for type in types)
635+
end
636+
623637
return GNNHeteroGraph(graph,
624638
num_nodes,
625639
num_edges,
626640
sum(v_num_graphs),
627641
graph_indicator,
628-
cat_features([g.ndata for g in gs]),
629-
cat_features([g.edata for g in gs]),
642+
cat_features([data_or_else(g.ndata, ntypes) for g in gs]),
643+
cat_features([data_or_else(g.edata, etypes) for g in gs]),
630644
cat_features([g.gdata for g in gs]),
631645
ntypes, etypes)
632646
end

src/GNNGraphs/utils.jl

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -119,13 +119,11 @@ end
119119

120120
function cat_features(xs::AbstractVector{<:Dict})
121121
_allkeys = [sort(collect(keys(x))) for x in xs]
122-
_keys = _allkeys[1]
123-
all(y -> y == _keys, _allkeys) ||
124-
@error "cannot concatenate feature data with different keys"
122+
_keys = union(_allkeys...)
125123
length(xs) == 1 && return xs[1]
126124

127125
# concatenate
128-
return Dict([k => cat_features([x[k] for x in xs]) for k in _keys]...)
126+
return Dict([k => cat_features([x[k] for x in xs if haskey(x, k)]) for k in _keys]...)
129127
end
130128

131129

test/GNNGraphs/datastore.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,13 @@ end
2727
@test vec._data == [Dict(:x => x), Dict(:x => x, :y => vec[2].y)]
2828
end
2929

30+
@testset "setindex!" begin
31+
ds = DataStore(10)
32+
x = rand(10)
33+
@test (ds[:x] = x) == x # Tests setindex!
34+
@test ds.x == ds[:x] == x
35+
end
36+
3037
@testset "map" begin
3138
ds = DataStore(10, (:x => rand(10), :y => rand(2, 10)))
3239
ds2 = map(x -> x .+ 1, ds)

test/GNNGraphs/transform.jl

Lines changed: 69 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -346,7 +346,75 @@ end
346346
@test g.ndata[:A].x == ones(2, 50)
347347
@test g.ndata[:A].y == zeros(50)
348348
@test g.edata[(:A,:to,:B)].e == fill(2, 100)
349-
@test g.gdata.u == fill(7, 5)
349+
@test g.gdata.u == fill(7, 5)
350+
end
351+
352+
@testset "batch non-similar edge types" begin
353+
gs = [rand_heterograph((:A =>10, :B => 14), ((:A, :to1, :A) => 5, (:A, :to1, :B) => 20)),
354+
rand_heterograph((:A => 10, :B => 15), ((:A, :to1, :B) => 5, (:B, :to2, :B) => 16)),
355+
rand_heterograph((:B => 15, :C => 5), ((:C, :to1, :B) => 5, (:B, :to2, :C) => 21)),
356+
rand_heterograph((:A => 10, :B => 10, :C => 10), ((:A, :to1, :C) => 5, (:A, :to1, :B) => 5)),
357+
rand_heterograph((:C => 20), ((:C, :to3, :C) => 10))
358+
]
359+
g = Flux.batch(gs)
360+
361+
@test g.num_nodes[:A] == 10 + 10 + 10
362+
@test g.num_nodes[:B] == 14 + 15 + 15 + 10
363+
@test g.num_nodes[:C] == 5 + 10 + 20
364+
@test g.num_edges[(:A,:to1,:A)] == 5
365+
@test g.num_edges[(:A,:to1,:B)] == 20 + 5 + 5
366+
@test g.num_edges[(:A,:to1,:C)] == 5
367+
368+
@test g.num_edges[(:B,:to2,:B)] == 16
369+
@test g.num_edges[(:B,:to2,:C)] == 21
370+
371+
@test g.num_edges[(:C,:to1,:B)] == 5
372+
@test g.num_edges[(:C,:to3,:C)] == 10
373+
@test length(keys(g.num_edges)) == 7
374+
@test g.num_graphs == 5
375+
376+
function ndata_if_key(g, key, subkey, value)
377+
if haskey(g.ndata, key)
378+
g.ndata[key][subkey] = reduce(hcat, fill(value, g.num_nodes[key]))
379+
end
380+
end
381+
382+
function edata_if_key(g, key, subkey, value)
383+
if haskey(g.edata, key)
384+
g.edata[key][subkey] = reduce(hcat, fill(value, g.num_edges[key]))
385+
end
386+
end
387+
388+
for gi in gs
389+
ndata_if_key(gi, :A, :x, [0])
390+
ndata_if_key(gi, :A, :y, ones(2))
391+
ndata_if_key(gi, :B, :x, ones(3))
392+
ndata_if_key(gi, :C, :y, zeros(4))
393+
edata_if_key(gi, (:A,:to1,:B), :x, [0])
394+
gi.gdata.u = 7
395+
end
396+
397+
g = Flux.batch(gs)
398+
399+
@test g.ndata[:A].x == reduce(hcat, fill(0, 10 + 10 + 10))
400+
@test g.ndata[:A].y == ones(2, 10 + 10 + 10)
401+
@test g.ndata[:B].x == ones(3, 14 + 15 + 15 + 10)
402+
@test g.ndata[:C].y == zeros(4, 5 + 10 + 20)
403+
404+
@test g.edata[(:A,:to1,:B)].x == reduce(hcat, fill(0, 20 + 5 + 5))
405+
406+
@test g.gdata.u == fill(7, 5)
407+
408+
# Allow for wider eltype
409+
g = Flux.batch(GNNHeteroGraph[g for g in gs])
410+
@test g.ndata[:A].x == reduce(hcat, fill(0, 10 + 10 + 10))
411+
@test g.ndata[:A].y == ones(2, 10 + 10 + 10)
412+
@test g.ndata[:B].x == ones(3, 14 + 15 + 15 + 10)
413+
@test g.ndata[:C].y == zeros(4, 5 + 10 + 20)
414+
415+
@test g.edata[(:A,:to1,:B)].x == reduce(hcat, fill(0, 20 + 5 + 5))
416+
417+
@test g.gdata.u == fill(7, 5)
350418
end
351419

352420
@testset "add_edges" begin

0 commit comments

Comments
 (0)