Skip to content

Commit 1a31606

Browse files
more supprot for getgraph and batch
1 parent 96df778 commit 1a31606

File tree

6 files changed

+92
-46
lines changed

6 files changed

+92
-46
lines changed

docs/src/index.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,8 +78,8 @@ Flux's DataLoader iterates over mini-batches of graphs
7878
(batched together into a `GNNGraph` object).
7979

8080
```julia
81-
gtrain, _ = getgraph(gbatch, 1:800)
82-
gtest, _ = getgraph(gbatch, 801:gbatch.num_graphs)
81+
gtrain = getgraph(gbatch, 1:800)
82+
gtest = getgraph(gbatch, 801:gbatch.num_graphs)
8383
train_loader = Flux.Data.DataLoader(gtrain, batchsize=32, shuffle=true)
8484
test_loader = Flux.Data.DataLoader(gtest, batchsize=32, shuffle=false)
8585

examples/graph_classification_tudataset.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,8 +75,8 @@ function train(; kws...)
7575
@info gfull
7676

7777
perm = randperm(gfull.num_graphs)
78-
gtrain, _ = getgraph(gfull, perm[1:NUM_TRAIN])
79-
gtest, _ = getgraph(gfull, perm[NUM_TRAIN+1:end])
78+
gtrain = getgraph(gfull, perm[1:NUM_TRAIN])
79+
gtest = getgraph(gfull, perm[NUM_TRAIN+1:end])
8080
train_loader = DataLoader(gtrain, batchsize=args.batchsize, shuffle=true)
8181
test_loader = DataLoader(gtest, batchsize=args.batchsize, shuffle=false)
8282

src/GraphNeuralNetworks.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ export
2828
# from LightGraphs
2929
adjacency_matrix,
3030
# from SparseArrays
31-
sprand, sparse,
31+
sprand, sparse, blockdiag,
3232

3333
# msgpass
3434
apply_edges, propagate,

src/gnngraph.jl

Lines changed: 69 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -436,26 +436,44 @@ function remove_self_loops(g::GNNGraph{<:COO_T})
436436
g.ndata, g.edata, g.gdata)
437437
end
438438

439-
function _catgraphs(g1::GNNGraph{<:COO_T}, g2::GNNGraph{<:COO_T})
440-
s1, t1 = edge_index(g1)
441-
s2, t2 = edge_index(g2)
439+
function SparseArrays.blockdiag(g1::GNNGraph, g2::GNNGraph)
442440
nv1, nv2 = g1.num_nodes, g2.num_nodes
443-
s = vcat(s1, nv1 .+ s2)
444-
t = vcat(t1, nv1 .+ t2)
445-
w = cat_features(edge_weight(g1), edge_weight(g2))
446-
447-
ind1 = isnothing(g1.graph_indicator) ? fill!(similar(s1, Int, nv1), 1) : g1.graph_indicator
448-
ind2 = isnothing(g2.graph_indicator) ? fill!(similar(s2, Int, nv2), 1) : g2.graph_indicator
441+
if g1.graph isa COO_T
442+
s1, t1 = edge_index(g1)
443+
s2, t2 = edge_index(g2)
444+
s = vcat(s1, nv1 .+ s2)
445+
t = vcat(t1, nv1 .+ t2)
446+
w = cat_features(edge_weight(g1), edge_weight(g2))
447+
graph = (s, t, w)
448+
ind1 = isnothing(g1.graph_indicator) ? ones_like(s1, Int, nv1) : g1.graph_indicator
449+
ind2 = isnothing(g2.graph_indicator) ? ones_like(s2, Int, nv2) : g2.graph_indicator
450+
elseif g1.graph isa ADJMAT_T
451+
graph = blockdiag(g1.graph, g2.graph)
452+
ind1 = isnothing(g1.graph_indicator) ? ones_like(graph, Int, nv1) : g1.graph_indicator
453+
ind2 = isnothing(g2.graph_indicator) ? ones_like(graph, Int, nv2) : g2.graph_indicator
454+
end
449455
graph_indicator = vcat(ind1, g1.num_graphs .+ ind2)
450456

451-
GNNGraph((s, t, w),
457+
GNNGraph(graph,
452458
nv1 + nv2, g1.num_edges + g2.num_edges, g1.num_graphs + g2.num_graphs,
453459
graph_indicator,
454460
cat_features(g1.ndata, g2.ndata),
455461
cat_features(g1.edata, g2.edata),
456462
cat_features(g1.gdata, g2.gdata))
457463
end
458464

465+
# PIRACY
466+
function SparseArrays.blockdiag(A1::AbstractMatrix, A2::AbstractMatrix)
467+
m1, n1 = size(A1)
468+
@assert m1 == n1
469+
m2, n2 = size(A2)
470+
@assert m2 == n2
471+
O1 = fill!(similar(A1, eltype(A1), (m1, n2)), 0)
472+
O2 = fill!(similar(A1, eltype(A1), (m2, n1)), 0)
473+
return [A1 O1
474+
O2 A2]
475+
end
476+
459477
### Cat public interfaces #############
460478

461479
"""
@@ -466,7 +484,7 @@ Equivalent to [`Flux.batch`](@ref).
466484
function SparseArrays.blockdiag(g1::GNNGraph, gothers::GNNGraph...)
467485
g = g1
468486
for go in gothers
469-
g = _catgraphs(g, go)
487+
g = blockdiag(g, go)
470488
end
471489
return g
472490
end
@@ -475,39 +493,44 @@ end
475493
batch(xs::Vector{<:GNNGraph})
476494
477495
Batch together multiple `GNNGraph`s into a single one
478-
containing the total number of nodes and edges of the original graphs.
496+
containing the total number of original nodes and edges.
479497
480498
Equivalent to [`SparseArrays.blockdiag`](@ref).
481499
"""
482500
Flux.batch(xs::Vector{<:GNNGraph}) = blockdiag(xs...)
483501

484502
### LearnBase compatibility
485503
LearnBase.nobs(g::GNNGraph) = g.num_graphs
486-
LearnBase.getobs(g::GNNGraph, i) = getgraph(g, i)[1]
504+
LearnBase.getobs(g::GNNGraph, i) = getgraph(g, i)
487505

488506
# Flux's Dataloader compatibility. Related PR https://github.com/FluxML/Flux.jl/pull/1683
489507
Flux.Data._nobs(g::GNNGraph) = g.num_graphs
490-
Flux.Data._getobs(g::GNNGraph, i) = getgraph(g, i)[1]
508+
Flux.Data._getobs(g::GNNGraph, i) = getgraph(g, i)
491509

492510
#########################
493511
Base.:(==)(g1::GNNGraph, g2::GNNGraph) = all(k -> getfield(g1,k)==getfield(g2,k), fieldnames(typeof(g1)))
494512

495513
"""
496-
getgraph(g::GNNGraph, i)
514+
getgraph(g::GNNGraph, i; nmap=false)
497515
498-
Return the getgraph of `g` induced by those nodes `v`
499-
for which `g.graph_indicator[v] ∈ i`. In other words, it
500-
extract the component graphs from a batched graph.
516+
Return the subgraph of `g` induced by those nodes `j`
517+
for which `g.graph_indicator[j] == i` or,
518+
if `i` is a collection, `g.graph_indicator[j] ∈ i`.
519+
In other words, it extract the component graphs from a batched graph.
501520
502-
It also returns a vector `nodes` mapping the new nodes to the old ones.
503-
The node `i` in the getgraph corresponds to the node `nodes[i]` in `g`.
521+
If `nmap=true`, return also a vector `v` mapping the new nodes to the old ones.
522+
The node `i` in the subgraph will correspond to the node `v[i]` in `g`.
504523
"""
505-
getgraph(g::GNNGraph, i::Int) = getgraph(g::GNNGraph{<:COO_T}, [i])
524+
getgraph(g::GNNGraph, i::Int; kws...) = getgraph(g, [i]; kws...)
506525

507-
function getgraph(g::GNNGraph{<:COO_T}, i::AbstractVector{Int})
526+
function getgraph(g::GNNGraph, i::AbstractVector{Int}; nmap=false)
508527
if g.graph_indicator === nothing
509528
@assert i == [1]
510-
return g
529+
if nmap
530+
return g, 1:g.num_nodes
531+
else
532+
return g
533+
end
511534
end
512535

513536
node_mask = g.graph_indicator .∈ Ref(i)
@@ -518,25 +541,38 @@ function getgraph(g::GNNGraph{<:COO_T}, i::AbstractVector{Int})
518541
graphmap = Dict(i => inew for (inew, i) in enumerate(i))
519542
graph_indicator = [graphmap[i] for i in g.graph_indicator[node_mask]]
520543

521-
s, t, w = g.graph
522-
edge_mask = s .∈ Ref(nodes)
523-
s = [nodemap[i] for i in s[edge_mask]]
524-
t = [nodemap[i] for i in t[edge_mask]]
525-
w = isnothing(w) ? nothing : w[edge_mask]
526-
544+
if g.graph isa COO_T
545+
s, t = edge_index(g)
546+
w = edge_weight(g)
547+
edge_mask = s .∈ Ref(nodes)
548+
s = [nodemap[i] for i in s[edge_mask]]
549+
t = [nodemap[i] for i in t[edge_mask]]
550+
w = isnothing(w) ? nothing : w[edge_mask]
551+
graph = (s, t, w)
552+
num_edges = length(s)
553+
edata = getobs(g.edata, edge_mask)
554+
elseif g.graph isa ADJMAT_T
555+
graph = g.graph[nodes, nodes]
556+
num_edges = count(>=(0), graph)
557+
@assert g.edata == (;) # TODO
558+
edata = (;)
559+
end
527560
ndata = getobs(g.ndata, node_mask)
528-
edata = getobs(g.edata, edge_mask)
529561
gdata = getobs(g.gdata, i)
530562

531563
num_nodes = length(graph_indicator)
532-
num_edges = length(s)
533564
num_graphs = length(i)
534565

535-
gnew = GNNGraph((s,t,w),
566+
gnew = GNNGraph(graph,
536567
num_nodes, num_edges, num_graphs,
537568
graph_indicator,
538569
ndata, edata, gdata)
539-
return gnew, nodes
570+
571+
if nmap
572+
return gnew, nodes
573+
else
574+
return gnew
575+
end
540576
end
541577

542578
function node_features(g::GNNGraph)

src/utils.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,10 @@ function normalize_graphdata(data::NamedTuple; default_name, n, duplicate_if_nee
6868
return data
6969
end
7070

71+
ones_like(x::AbstractArray, T=eltype(x), sz=size(x)) = fill!(similar(x, T, sz), 1)
72+
ones_like(x::SparseMatrixCSC, T=eltype(x), sz=size(x)) = ones(T, sz)
73+
ones_like(x::CUMAT_T, T=eltype(x), sz=size(x)) = CUDA.ones(T, sz)
74+
7175
ofeltype(x, y) = convert(float(eltype(x)), y)
7276

7377
# TODO move to flux. fix for https://github.com/FluxML/Flux.jl/issues/1720

test/gnngraph.jl

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -141,17 +141,23 @@
141141
end
142142

143143
@testset "getgraph" begin
144-
#TODO add graph_type=GRAPH_T
145-
g1 = GNNGraph(random_regular_graph(10,2), ndata=rand(16,10))
146-
g2 = GNNGraph(random_regular_graph(4,2), ndata=rand(16,4))
147-
g3 = GNNGraph(random_regular_graph(7,2), ndata=rand(16,7))
144+
g1 = GNNGraph(random_regular_graph(10,2), ndata=rand(16,10), graph_type=GRAPH_T)
145+
g2 = GNNGraph(random_regular_graph(4,2), ndata=rand(16,4), graph_type=GRAPH_T)
146+
g3 = GNNGraph(random_regular_graph(7,2), ndata=rand(16,7), graph_type=GRAPH_T)
148147
g = Flux.batch([g1, g2, g3])
149-
g2b, nodemap = getgraph(g, 2)
150148

149+
g2b, nodemap = getgraph(g, 2, nmap=true)
151150
s, t = edge_index(g2b)
152151
@test s == edge_index(g2)[1]
153152
@test t == edge_index(g2)[2]
154153
@test node_features(g2b) node_features(g2)
154+
155+
g2c = getgraph(g, 2)
156+
@test g2c isa GNNGraph{typeof(g.graph)}
157+
158+
g1b, nodemap = getgraph(g1, 1, nmap=true)
159+
@test g1b === g1
160+
@test nodemap == 1:g1.num_nodes
155161
end
156162

157163
@testset "Features" begin
@@ -207,11 +213,11 @@
207213
g = Flux.batch([GNNGraph(erdos_renyi(n, m), ndata=X, edata=E, gdata=U)
208214
for _ in 1:num_graphs])
209215

210-
@test LearnBase.getobs(g, 3) == getgraph(g, 3)[1]
211-
@test LearnBase.getobs(g, 3:5) == getgraph(g, 3:5)[1]
216+
@test LearnBase.getobs(g, 3) == getgraph(g, 3)
217+
@test LearnBase.getobs(g, 3:5) == getgraph(g, 3:5)
212218
@test LearnBase.nobs(g) == g.num_graphs
213219

214220
d = Flux.Data.DataLoader(g, batchsize = 2, shuffle=false)
215-
@test first(d) == getgraph(g, 1:2)[1]
221+
@test first(d) == getgraph(g, 1:2)
216222
end
217223
end

0 commit comments

Comments
 (0)