Skip to content

Commit 066d8f5

Browse files
batch wider eltype (#340)
1 parent 108843a commit 066d8f5

File tree

2 files changed

+101
-78
lines changed

2 files changed

+101
-78
lines changed

src/GNNGraphs/transform.jl

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -533,7 +533,16 @@ julia> g12.ndata.x
533533
1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
534534
```
535535
"""
536-
Flux.batch(gs::AbstractVector{<:GNNGraph}) = blockdiag(gs...)
536+
function Flux.batch(gs::AbstractVector{<:GNNGraph})
537+
Told = eltype(gs)
538+
# try to restrict the eltype
539+
gs = [g for g in gs]
540+
if eltype(gs) != Told
541+
return Flux.batch(gs)
542+
else
543+
return blockdiag(gs...)
544+
end
545+
end
537546

538547
function Flux.batch(gs::AbstractVector{<:GNNGraph{T}}) where {T <: COO_T}
539548
v_num_nodes = [g.num_nodes for g in gs]
@@ -569,7 +578,7 @@ function Flux.batch(g::GNNGraph)
569578
end
570579

571580

572-
function Flux.batch(gs::AbstractVector{<:GNNHeteroGraph{T}}) where {T <: COO_T}
581+
function Flux.batch(gs::AbstractVector{<:GNNHeteroGraph})
573582
@assert length(gs) > 0
574583
ntypes = union([g.ntypes for g in gs]...)
575584
etypes = union([g.etypes for g in gs]...)
@@ -611,15 +620,15 @@ function Flux.batch(gs::AbstractVector{<:GNNHeteroGraph{T}}) where {T <: COO_T}
611620
v_gi = Dict(node_t => [ng .+ gi for (ng, gi) in zip(graphsum, v_gi[node_t])] for node_t in ntypes)
612621
graph_indicator = Dict(node_t => cat_features(v_gi[node_t]) for node_t in ntypes)
613622

614-
GNNHeteroGraph(graph,
615-
num_nodes,
616-
num_edges,
617-
sum(v_num_graphs),
618-
graph_indicator,
619-
cat_features([g.ndata for g in gs]),
620-
cat_features([g.edata for g in gs]),
621-
cat_features([g.gdata for g in gs]),
622-
ntypes, etypes)
623+
return GNNHeteroGraph(graph,
624+
num_nodes,
625+
num_edges,
626+
sum(v_num_graphs),
627+
graph_indicator,
628+
cat_features([g.ndata for g in gs]),
629+
cat_features([g.edata for g in gs]),
630+
cat_features([g.gdata for g in gs]),
631+
ntypes, etypes)
623632
end
624633

625634
"""

test/GNNGraphs/transform.jl

Lines changed: 81 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,11 @@ end
2929
g123 = Flux.batch([g1, g2, g3])
3030
@test g123.graph_indicator == [fill(1, 10); fill(2, 4); fill(3, 7)]
3131

32+
# Allow wider eltype
33+
g123 = Flux.batch(GNNGraph[g1, g2, g3])
34+
@test g123.graph_indicator == [fill(1, 10); fill(2, 4); fill(3, 7)]
35+
36+
3237
s, t = edge_index(g123)
3338
@test s == [edge_index(g1)[1]; 10 .+ edge_index(g2)[1]; 14 .+ edge_index(g3)[1]]
3439
@test t == [edge_index(g1)[2]; 10 .+ edge_index(g2)[2]; 14 .+ edge_index(g3)[2]]
@@ -141,52 +146,6 @@ end
141146
gnew = add_edges(g, (snew, tnew, wnew))
142147
@test get_edge_weight(gnew) == [w; wnew]
143148
end
144-
145-
@testset "heterograph" begin
146-
hg = rand_bipartite_heterograph((2, 2), (4, 0), bidirected=false)
147-
hg = add_edges(hg, (:B,:to,:A), [1, 1], [1,2])
148-
@test hg.num_edges == Dict((:A,:to,:B) => 4, (:B,:to,:A) => 2)
149-
@test has_edge(hg, (:B,:to,:A), 1, 1)
150-
@test has_edge(hg, (:B,:to,:A), 1, 2)
151-
@test !has_edge(hg, (:B,:to,:A), 2, 1)
152-
@test !has_edge(hg, (:B,:to,:A), 2, 2)
153-
154-
@testset "new nodes" begin
155-
hg = rand_bipartite_heterograph((2, 2), 3)
156-
hg = add_edges(hg, (:C,:rel,:B) => ([1, 3], [1,2]))
157-
@test hg.num_nodes == Dict(:A => 2, :B => 2, :C => 3)
158-
@test hg.num_edges == Dict((:A,:to,:B) => 3, (:B,:to,:A) => 3, (:C,:rel,:B) => 2)
159-
s, t = edge_index(hg, (:C,:rel,:B))
160-
@test s == [1, 3]
161-
@test t == [1, 2]
162-
163-
hg = add_edges(hg, (:D,:rel,:F) => ([1, 3], [1,2]))
164-
@test hg.num_nodes == Dict(:A => 2, :B => 2, :C => 3, :D => 3, :F => 2)
165-
@test hg.num_edges == Dict((:A,:to,:B) => 3, (:B,:to,:A) => 3, (:C,:rel,:B) => 2, (:D,:rel,:F) => 2)
166-
s, t = edge_index(hg, (:D,:rel,:F))
167-
@test s == [1, 3]
168-
@test t == [1, 2]
169-
end
170-
171-
@testset "also add weights" begin
172-
hg = GNNHeteroGraph((:user, :rate, :movie) => ([1,1,2,3], [7,13,5,7], [0.1, 0.2, 0.3, 0.4]))
173-
hgnew = add_edges(hg, (:user, :like, :actor) => ([1, 2], [3, 4], [0.5, 0.6]))
174-
@test hgnew.num_nodes[:user] == 3
175-
@test hgnew.num_nodes[:movie] == 13
176-
@test hgnew.num_nodes[:actor] == 4
177-
@test hgnew.num_edges == Dict((:user, :rate, :movie) => 4, (:user, :like, :actor) => 2)
178-
@test get_edge_weight(hgnew, (:user, :rate, :movie)) == [0.1, 0.2, 0.3, 0.4]
179-
@test get_edge_weight(hgnew, (:user, :like, :actor)) == [0.5, 0.6]
180-
181-
hgnew2 = add_edges(hgnew, (:user, :like, :actor) => ([6, 7], [8, 10], [0.7, 0.8]))
182-
@test hgnew2.num_nodes[:user] == 7
183-
@test hgnew2.num_nodes[:movie] == 13
184-
@test hgnew2.num_nodes[:actor] == 10
185-
@test hgnew2.num_edges == Dict((:user, :rate, :movie) => 4, (:user, :like, :actor) => 4)
186-
@test get_edge_weight(hgnew2, (:user, :rate, :movie)) == [0.1, 0.2, 0.3, 0.4]
187-
@test get_edge_weight(hgnew2, (:user, :like, :actor)) == [0.5, 0.6, 0.7, 0.8]
188-
end
189-
end
190149
end
191150
end
192151

@@ -358,26 +317,81 @@ end
358317
0.0 0.0 0.0]
359318
end
360319

361-
@testset "batch heterograph" begin
362-
gs = [rand_bipartite_heterograph((10, 15), 20) for _ in 1:5]
363-
g = Flux.batch(gs)
364-
@test g.num_nodes[:A] == 50
365-
@test g.num_nodes[:B] == 75
366-
@test g.num_edges[(:A,:to,:B)] == 100
367-
@test g.num_edges[(:B,:to,:A)] == 100
368-
@test g.num_graphs == 5
369-
@test g.graph_indicator == Dict(:A => vcat([fill(i, 10) for i in 1:5]...),
370-
:B => vcat([fill(i, 15) for i in 1:5]...))
371-
372-
for gi in gs
373-
gi.ndata[:A].x = ones(2, 10)
374-
gi.ndata[:A].y = zeros(10)
375-
gi.edata[(:A,:to,:B)].e = fill(2, 20)
376-
gi.gdata.u = 7
320+
@testset "HeteroGraphs" begin
321+
@testset "batch" begin
322+
gs = [rand_bipartite_heterograph((10, 15), 20) for _ in 1:5]
323+
g = Flux.batch(gs)
324+
@test g.num_nodes[:A] == 50
325+
@test g.num_nodes[:B] == 75
326+
@test g.num_edges[(:A,:to,:B)] == 100
327+
@test g.num_edges[(:B,:to,:A)] == 100
328+
@test g.num_graphs == 5
329+
@test g.graph_indicator == Dict(:A => vcat([fill(i, 10) for i in 1:5]...),
330+
:B => vcat([fill(i, 15) for i in 1:5]...))
331+
332+
for gi in gs
333+
gi.ndata[:A].x = ones(2, 10)
334+
gi.ndata[:A].y = zeros(10)
335+
gi.edata[(:A,:to,:B)].e = fill(2, 20)
336+
gi.gdata.u = 7
337+
end
338+
g = Flux.batch(gs)
339+
@test g.ndata[:A].x == ones(2, 50)
340+
@test g.ndata[:A].y == zeros(50)
341+
@test g.edata[(:A,:to,:B)].e == fill(2, 100)
342+
@test g.gdata.u == fill(7, 5)
343+
344+
# Allow for wider eltype
345+
g = Flux.batch(GNNHeteroGraph[g for g in gs])
346+
@test g.ndata[:A].x == ones(2, 50)
347+
@test g.ndata[:A].y == zeros(50)
348+
@test g.edata[(:A,:to,:B)].e == fill(2, 100)
349+
@test g.gdata.u == fill(7, 5)
350+
end
351+
352+
@testset "add_edges" begin
353+
hg = rand_bipartite_heterograph((2, 2), (4, 0), bidirected=false)
354+
hg = add_edges(hg, (:B,:to,:A), [1, 1], [1,2])
355+
@test hg.num_edges == Dict((:A,:to,:B) => 4, (:B,:to,:A) => 2)
356+
@test has_edge(hg, (:B,:to,:A), 1, 1)
357+
@test has_edge(hg, (:B,:to,:A), 1, 2)
358+
@test !has_edge(hg, (:B,:to,:A), 2, 1)
359+
@test !has_edge(hg, (:B,:to,:A), 2, 2)
360+
361+
@testset "new nodes" begin
362+
hg = rand_bipartite_heterograph((2, 2), 3)
363+
hg = add_edges(hg, (:C,:rel,:B) => ([1, 3], [1,2]))
364+
@test hg.num_nodes == Dict(:A => 2, :B => 2, :C => 3)
365+
@test hg.num_edges == Dict((:A,:to,:B) => 3, (:B,:to,:A) => 3, (:C,:rel,:B) => 2)
366+
s, t = edge_index(hg, (:C,:rel,:B))
367+
@test s == [1, 3]
368+
@test t == [1, 2]
369+
370+
hg = add_edges(hg, (:D,:rel,:F) => ([1, 3], [1,2]))
371+
@test hg.num_nodes == Dict(:A => 2, :B => 2, :C => 3, :D => 3, :F => 2)
372+
@test hg.num_edges == Dict((:A,:to,:B) => 3, (:B,:to,:A) => 3, (:C,:rel,:B) => 2, (:D,:rel,:F) => 2)
373+
s, t = edge_index(hg, (:D,:rel,:F))
374+
@test s == [1, 3]
375+
@test t == [1, 2]
376+
end
377+
378+
@testset "also add weights" begin
379+
hg = GNNHeteroGraph((:user, :rate, :movie) => ([1,1,2,3], [7,13,5,7], [0.1, 0.2, 0.3, 0.4]))
380+
hgnew = add_edges(hg, (:user, :like, :actor) => ([1, 2], [3, 4], [0.5, 0.6]))
381+
@test hgnew.num_nodes[:user] == 3
382+
@test hgnew.num_nodes[:movie] == 13
383+
@test hgnew.num_nodes[:actor] == 4
384+
@test hgnew.num_edges == Dict((:user, :rate, :movie) => 4, (:user, :like, :actor) => 2)
385+
@test get_edge_weight(hgnew, (:user, :rate, :movie)) == [0.1, 0.2, 0.3, 0.4]
386+
@test get_edge_weight(hgnew, (:user, :like, :actor)) == [0.5, 0.6]
387+
388+
hgnew2 = add_edges(hgnew, (:user, :like, :actor) => ([6, 7], [8, 10], [0.7, 0.8]))
389+
@test hgnew2.num_nodes[:user] == 7
390+
@test hgnew2.num_nodes[:movie] == 13
391+
@test hgnew2.num_nodes[:actor] == 10
392+
@test hgnew2.num_edges == Dict((:user, :rate, :movie) => 4, (:user, :like, :actor) => 4)
393+
@test get_edge_weight(hgnew2, (:user, :rate, :movie)) == [0.1, 0.2, 0.3, 0.4]
394+
@test get_edge_weight(hgnew2, (:user, :like, :actor)) == [0.5, 0.6, 0.7, 0.8]
395+
end
377396
end
378-
g = Flux.batch(gs)
379-
@test g.ndata[:A].x == ones(2, 50)
380-
@test g.ndata[:A].y == zeros(50)
381-
@test g.edata[(:A,:to,:B)].e == fill(2, 100)
382-
@test g.gdata.u == fill(7, 5)
383397
end

0 commit comments

Comments
 (0)