Skip to content

Commit 532caa6

Browse files
[heterograph] add has_edge, get_index, add_edges (#298)
* [heterograph] has_edge and get_index * add_edges * cleanup
1 parent 2217bf0 commit 532caa6

File tree

10 files changed

+210
-40
lines changed

10 files changed

+210
-40
lines changed

docs/make.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ makedocs(;
2424
"Message Passing" => "messagepassing.md",
2525
"Model Building" => "models.md",
2626
"Datasets" => "datasets.md",
27-
# "Tutorials" => tutorials,
27+
"Tutorials" => tutorials,
2828
"API Reference" => [
2929
"GNNGraph" => "api/gnngraph.md",
3030
"Basic Layers" => "api/basic.md",

src/GNNGraphs/GNNGraphs.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ export add_nodes,
8181
include("generate.jl")
8282
export rand_graph,
8383
rand_heterograph,
84+
rand_bipartite_heterograph,
8485
knn_graph,
8586
radius_graph
8687

src/GNNGraphs/convert.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
### CONVERT_TO_COO REPRESENTATION ########
22

33
function to_coo(data::EDict; num_nodes = nothing, kws...)
4-
graph = EDict{Any}()
4+
graph = EDict{COO_T}()
55
_num_nodes = NDict{Int}()
66
num_edges = EDict{Int}()
77
for k in keys(data)
@@ -23,6 +23,7 @@ function to_coo(data::EDict; num_nodes = nothing, kws...)
2323
_num_nodes[k[1]] = max(get(_num_nodes, k[1], 0), nnodes[1])
2424
_num_nodes[k[3]] = max(get(_num_nodes, k[3], 0), nnodes[2])
2525
end
26+
graph = Dict(k => v for (k, v) in pairs(graph)) # try to restrict the key/value types
2627
return graph, _num_nodes, num_edges
2728
end
2829

src/GNNGraphs/generate.jl

Lines changed: 61 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -48,21 +48,23 @@ function rand_graph(n::Integer, m::Integer; bidirected = true, seed = -1, edge_w
4848
end
4949

5050
"""
51-
rand_heterograph(n, m; seed=-1, kws...)
51+
rand_heterograph(n, m; seed=-1, bidirected=false, kws...)
5252
5353
Construct an [`GNNHeteroGraph`](@ref) with number of nodes and edges
54-
specified by `n` and `m` respectively.
55-
`n` and `m` can be any iterable of pairs.
54+
specified by `n` and `m` respectively. `n` and `m` can be any iterable of pairs
55+
specifing node/edge types and their numbers.
5656
5757
Use a `seed > 0` for reproducibility.
5858
59+
Setting `bidirected=true` will generate a bidirected graph, i.e. each edge will have a reverse edge.
60+
Therefore, for each edge type `(:A, :rel, :B)` a corresponding reverse edge type `(:B, :rel, :A)`
61+
will be generated.
62+
5963
Additional keyword arguments will be passed to the [`GNNHeteroGraph`](@ref) constructor.
6064
6165
# Examples
6266
6367
```julia-repl
64-
65-
6668
julia> g = rand_heterograph((:user => 10, :movie => 20),
6769
(:user, :rate, :movie) => 30)
6870
GNNHeteroGraph:
@@ -76,19 +78,72 @@ function rand_heterograph end
7678
rand_heterograph(n, m; kws...) = rand_heterograph(Dict(n), Dict(m); kws...)
7779

7880
function rand_heterograph(n::NDict, m::EDict; bidirected = false, seed = -1, kws...)
79-
@assert !bidirected "Bidirected graphs not supported yet."
8081
rng = seed > 0 ? MersenneTwister(seed) : Random.GLOBAL_RNG
82+
if bidirected
83+
return _rand_bidirected_heterograph(rng, n, m; kws...)
84+
end
8185
graphs = Dict(k => _rand_edges(rng, (n[k[1]], n[k[3]]), m[k]) for k in keys(m))
8286
return GNNHeteroGraph(graphs; num_nodes = n, kws...)
8387
end
8488

89+
function _rand_bidirected_heterograph(rng, n::NDict, m::EDict; kws...)
90+
for k in keys(m)
91+
if reverse(k) keys(m)
92+
@assert m[k] == m[reverse(k)] "Number of edges must be the same in reverse edge types for bidirected graphs."
93+
else
94+
m[reverse(k)] = m[k]
95+
end
96+
end
97+
graphs = Dict{EType, Tuple{Vector{Int}, Vector{Int}, Nothing}}()
98+
for k in keys(m)
99+
reverse(k) keys(graphs) && continue
100+
s, t, val = _rand_edges(rng, (n[k[1]], n[k[3]]), m[k])
101+
graphs[k] = s, t, val
102+
graphs[reverse(k)] = t, s, val
103+
end
104+
return GNNHeteroGraph(graphs; num_nodes = n, kws...)
105+
end
106+
85107
function _rand_edges(rng, (n1, n2), m)
86108
idx = StatsBase.sample(rng, 1:(n1 * n2), m, replace = false)
87109
s, t = edge_decoding(idx, n1, n2)
88110
val = nothing
89111
return s, t, val
90112
end
91113

114+
"""
115+
rand_bipartite_heterograph(n1, n2, m; [bidirected, seed, node_t, edge_t, kws...])
116+
rand_bipartite_heterograph((n1, n2), m; ...)
117+
rand_bipartite_heterograph((n1, n2), (m1, m2); ...)
118+
119+
Construct an [`GNNHeteroGraph`](@ref) with number of nodes and edges
120+
specified by `n1`, `n2` and `m1` and `m2` respectively.
121+
122+
See [`rand_heterograph`](@ref) for a more general version.
123+
124+
# Keyword arguments
125+
126+
- `bidirected`: whether to generate a bidirected graph. Default is `true`.
127+
- `seed`: random seed. Default is `-1` (no seed).
128+
- `node_t`: node types. If `bipartite=true`, this should be a tuple of two node types, otherwise it should be a single node type.
129+
- `edge_t`: edge types. If `bipartite=true`, this should be a tuple of two edge types, otherwise it should be a single edge type.
130+
"""
131+
function rand_bipartite_heterograph end
132+
133+
rand_bipartite_heterograph(n1::Int, n2::Int, m::Int; kws...) = rand_bipartite_heterograph((n1, n2), (m, m); kws...)
134+
135+
rand_bipartite_heterograph((n1, n2)::NTuple{2,Int}, m::Int; kws...) = rand_bipartite_heterograph((n1, n2), (m, m); kws...)
136+
137+
function rand_bipartite_heterograph((n1, n2)::NTuple{2,Int}, (m1, m2)::NTuple{2,Int}; bidirected=true,
138+
node_t = (:A, :B), edge_t = :to, kws...)
139+
if edge_t isa Symbol
140+
edge_t = (edge_t, edge_t)
141+
end
142+
return rand_heterograph(Dict(node_t[1] => n1, node_t[2] => n2),
143+
Dict((node_t[1], edge_t[1], node_t[2]) => m1, (node_t[2], edge_t[2], node_t[1]) => m2);
144+
bidirected, kws...)
145+
end
146+
92147
"""
93148
knn_graph(points::AbstractMatrix,
94149
k::Int;

src/GNNGraphs/gnnheterograph.jl

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11

2-
const EDict{T} = Dict{Tuple{Symbol, Symbol, Symbol}, T}
3-
const NDict{T} = Dict{Symbol, T}
2+
const EType = Tuple{Symbol, Symbol, Symbol}
3+
const NType = Symbol
4+
const EDict{T} = Dict{EType, T}
5+
const NDict{T} = Dict{NType, T}
46

57
"""
68
GNNHeteroGraph(data; [ndata, edata, gdata, num_nodes])
@@ -78,30 +80,35 @@ julia> hg.ndata[:A].x
7880
7981
See also [`GNNGraph`](@ref) for a homogeneous graph type and [`rand_heterograph`](@ref) for a function to generate random heterographs.
8082
"""
81-
struct GNNHeteroGraph
82-
graph::EDict
83+
struct GNNHeteroGraph{T <: Union{COO_T, ADJMAT_T}}
84+
graph::EDict{T}
8385
num_nodes::NDict{Int}
8486
num_edges::EDict{Int}
8587
num_graphs::Int
8688
graph_indicator::Union{Nothing, NDict}
8789
ndata::NDict{DataStore}
8890
edata::EDict{DataStore}
8991
gdata::DataStore
90-
ntypes::Vector{Symbol}
91-
etypes::Vector{Symbol}
92+
ntypes::Vector{NType}
93+
etypes::Vector{EType}
9294
end
9395

9496
@functor GNNHeteroGraph
9597

9698
GNNHeteroGraph(data; kws...) = GNNHeteroGraph(Dict(data); kws...)
9799

100+
function GNNHeteroGraph(data::Dict; kws...)
101+
all(k -> k isa EType, keys(data)) || throw(ArgumentError("Keys of data must be tuples of the form (source_type, edge_type, target_type)"))
102+
return GNNHeteroGraph(Dict(k => v for (k, v) in pairs(data)); kws...)
103+
end
104+
98105
function GNNHeteroGraph(data::EDict;
99106
num_nodes = nothing,
100107
graph_indicator = nothing,
101108
graph_type = :coo,
102109
dir = :out,
103-
ndata = NDict{NamedTuple}(),
104-
edata = EDict{NamedTuple}(),
110+
ndata = NDict{DataStore}(),
111+
edata = EDict{DataStore}(),
105112
gdata = (;))
106113
@assert graph_type [:coo, :dense, :sparse] "Invalid graph_type $graph_type requested"
107114
@assert dir [:in, :out]
@@ -112,7 +119,7 @@ function GNNHeteroGraph(data::EDict;
112119
end
113120

114121
ntypes = union([[k[1] for k in keys(data)]; [k[3] for k in keys(data)]])
115-
etypes = [k[2] for k in keys(data)]
122+
etypes = collect(keys(data))
116123

117124
if graph_type == :coo
118125
graph, num_nodes, num_edges = to_coo(data; num_nodes, dir)

src/GNNGraphs/query.jl

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,19 @@ edge_index(g::GNNGraph{<:COO_T}) = g.graph[1:2]
1313

1414
edge_index(g::GNNGraph{<:ADJMAT_T}) = to_coo(g.graph, num_nodes = g.num_nodes)[1][1:2]
1515

16+
""""
17+
edge_index(g::GNNHeteroGraph, edge_t)
18+
19+
Return a tuple containing two vectors, respectively storing the source and target nodes for each edges in `g` of type `edge_t = (:node1_t, :rel, :node2_t)`.
20+
"""
21+
edge_index(g::GNNHeteroGraph{<:COO_T}, edge_t::EType) = g.graph[edge_t][1:2]
22+
1623
get_edge_weight(g::GNNGraph{<:COO_T}) = g.graph[3]
1724

1825
get_edge_weight(g::GNNGraph{<:ADJMAT_T}) = to_coo(g.graph, num_nodes = g.num_nodes)[1][3]
1926

27+
get_edge_weight(g::GNNHeteroGraph{<:COO_T}, edge_t::EType) = g.graph[edge_t][3]
28+
2029
Graphs.edges(g::GNNGraph) = Graphs.Edge.(edge_index(g)...)
2130

2231
Graphs.edgetype(g::GNNGraph) = Graphs.Edge{eltype(g)}
@@ -42,6 +51,31 @@ end
4251

4352
Graphs.has_edge(g::GNNGraph{<:ADJMAT_T}, i::Integer, j::Integer) = g.graph[i, j] != 0
4453

54+
"""
55+
has_edge(g::GNNHeteroGraph, edge_t, i, j)
56+
57+
Return `true` if there is an edge of type `edge_t` from node `i` to node `j` in `g`.
58+
59+
# Examples
60+
61+
```julia-repl
62+
julia> g = rand_bipartite_heterograph((2, 2), (4, 0), bidirected=false)
63+
GNNHeteroGraph:
64+
num_nodes: (:A => 2, :B => 2)
65+
num_edges: ((:A, :to, :B) => 4, (:B, :to, :A) => 0)
66+
67+
julia> has_edge(g, (:A,:to,:B), 1, 1)
68+
true
69+
70+
julia> has_edge(g, (:B,:to,:A), 1, 1)
71+
false
72+
```
73+
"""
74+
function Graphs.has_edge(g::GNNHeteroGraph, edge_t::EType, i::Integer, j::Integer)
75+
s, t = edge_index(g, edge_t)
76+
return any((s .== i) .& (t .== j))
77+
end
78+
4579
graph_type_symbol(::GNNGraph{<:COO_T}) = :coo
4680
graph_type_symbol(::GNNGraph{<:SPARSE_T}) = :sparse
4781
graph_type_symbol(::GNNGraph{<:ADJMAT_T}) = :dense

src/GNNGraphs/transform.jl

Lines changed: 48 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ function add_self_loops(g::GNNGraph{<:COO_T})
2121
ew = [ew; fill!(similar(ew, n), 1)]
2222
end
2323

24-
GNNGraph((s, t, ew),
24+
return GNNGraph((s, t, ew),
2525
g.num_nodes, length(s), g.num_graphs,
2626
g.graph_indicator,
2727
g.ndata, g.edata, g.gdata)
@@ -32,7 +32,7 @@ function add_self_loops(g::GNNGraph{<:ADJMAT_T})
3232
@assert isempty(g.edata)
3333
num_edges = g.num_edges + g.num_nodes
3434
A = A + I
35-
GNNGraph(A,
35+
return GNNGraph(A,
3636
g.num_nodes, num_edges, g.num_graphs,
3737
g.graph_indicator,
3838
g.ndata, g.edata, g.gdata)
@@ -71,7 +71,7 @@ function remove_self_loops(g::GNNGraph{<:ADJMAT_T})
7171
dropzeros!(A)
7272
end
7373
num_edges = numnonzeros(A)
74-
GNNGraph(A,
74+
return GNNGraph(A,
7575
g.num_nodes, num_edges, g.num_graphs,
7676
g.graph_indicator,
7777
g.ndata, g.edata, g.gdata)
@@ -110,7 +110,7 @@ function remove_multi_edges(g::GNNGraph{<:COO_T}; aggr = +)
110110
edata = _scatter(aggr, edata, idxs, num_edges)
111111
end
112112

113-
GNNGraph((s, t, w),
113+
return GNNGraph((s, t, w),
114114
g.num_nodes, num_edges, g.num_graphs,
115115
g.graph_indicator,
116116
g.ndata, edata, g.gdata)
@@ -137,13 +137,55 @@ function add_edges(g::GNNGraph{<:COO_T},
137137
s = [s; snew]
138138
t = [t; tnew]
139139

140-
GNNGraph((s, t, nothing),
140+
return GNNGraph((s, t, nothing),
141141
g.num_nodes, length(s), g.num_graphs,
142142
g.graph_indicator,
143143
g.ndata, edata, g.gdata)
144144
end
145145

146-
### TODO Cannot implement this since GNNGraph is immutable (cannot change num_edges)
146+
function add_edges(g::GNNHeteroGraph{<:COO_T},
147+
edge_t::EType,
148+
snew::AbstractVector{<:Integer},
149+
tnew::AbstractVector{<:Integer};
150+
edata = nothing)
151+
@assert length(snew) == length(tnew)
152+
# TODO remove this constraint
153+
@assert get_edge_weight(g, edge_t) === nothing
154+
155+
edata = normalize_graphdata(edata, default_name = :e, n = length(snew))
156+
g_edata = g.edata |> copy
157+
if !isempty(g.edata)
158+
if haskey(g_edata, edge_t)
159+
g_edata[edge_t] = cat_features(g.edata[edge_t], edata)
160+
else
161+
g_edata[edge_t] = edata
162+
end
163+
end
164+
165+
graph = g.graph |> copy
166+
etypes = g.etypes |> copy
167+
if !haskey(graph, edge_t)
168+
@assert edge_t[1] g.ntypes && edge_t[3] g.ntypes
169+
push!(g.etypes, edge_t)
170+
else
171+
s, t = edge_index(g, edge_t)
172+
snew = [s; snew]
173+
tnew = [t; tnew]
174+
end
175+
graph[edge_t] = (snew, tnew, nothing)
176+
num_edges = g.num_edges |> copy
177+
num_edges[edge_t] = length(graph[edge_t][1])
178+
179+
return GNNHeteroGraph(graph,
180+
g.num_nodes, num_edges, g.num_graphs,
181+
g.graph_indicator,
182+
g.ndata, g_edata, g.gdata,
183+
g.ntypes, etypes)
184+
end
185+
186+
187+
188+
### TODO Cannot implement this since GNNGraph is immutable (cannot change num_edges). make it mutable
147189
# function Graphs.add_edge!(g::GNNGraph{<:COO_T}, snew::T, tnew::T; edata=nothing) where T<:Union{Integer, AbstractVector}
148190
# s, t = edge_index(g)
149191
# @assert length(snew) == length(tnew)

test/GNNGraphs/generate.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,3 +75,20 @@ end
7575
s, t = edge_index(g)
7676
@test (s .> 5) == (t .> 5)
7777
end
78+
79+
@testset "rand_bipartite_heterograph" begin
80+
g = rand_bipartite_heterograph(10, 15, 20)
81+
@test g.num_nodes == Dict(:A => 10, :B => 15)
82+
@test g.num_edges == Dict((:A, :to, :B) => 20, (:B, :to, :A) => 20)
83+
sA, tB = edge_index(g, (:A, :to, :B))
84+
for (s, t) in zip(sA, tB)
85+
@test 1 <= s <= 10
86+
@test 1 <= t <= 15
87+
@test has_edge(g, (:A,:to,:B), s, t)
88+
@test has_edge(g, (:B,:to,:A), t, s)
89+
end
90+
91+
g = rand_bipartite_heterograph((2, 2), (4, 0), bidirected=false)
92+
@test has_edge(g, (:A,:to,:B), 1, 1)
93+
@test !has_edge(g, (:B,:to,:A), 1, 1)
94+
end

test/GNNGraphs/gnnheterograph.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
@test hg.edata == Dict()
1919
@test isempty(hg.gdata)
2020
@test sort(hg.ntypes) == [:A, :B]
21-
@test sort(hg.etypes) == [:rel1, :rel2]
21+
@test sort(hg.etypes) == [(:A, :rel1, :B), (:B, :rel2, :A)]
2222
end
2323

2424
@testset "features" begin

0 commit comments

Comments
 (0)