Skip to content

Commit d952fc5

Browse files
heterograph guide (#310)
1 parent ee8be00 commit d952fc5

File tree

5 files changed

+172
-25
lines changed

5 files changed

+172
-25
lines changed

docs/src/heterograph.md

Lines changed: 87 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,95 @@
33
Heterogeneus graphs (also called heterographs), are graphs where each node has a type,
44
that we denote with symbols such as `:user` and `:movie`,
55
and edges also represent different relations identified
6-
by a triple of symbols, `(source_nodes, edge_type, target_nodes)`, as in `(:user, :rate, :movie)`.
6+
by a triplet of symbols, `(source_node_type, edge_type, target_node_type)`, as in `(:user, :rate, :movie)`.
77

88
Different node/edge types can store different group of features
99
and this makes heterographs a very flexible modeling tools
10-
and data containers.
11-
12-
In GraphNeuralNetworks.jl heterographs are implemented in
10+
and data containers. In GraphNeuralNetworks.jl heterographs are implemented in
1311
the type [`GNNHeteroGraph`](@ref).
1412

13+
14+
## Creating a Heterograph
15+
16+
A heterograph can be created by passing pairs of edge types and data to the constructor.
17+
```julia-repl
18+
julia> g = GNNHeteroGraph((:user, :rate, :movie) => ([1,1,2,3], [7,13,5,7]))
19+
GNNHeteroGraph:
20+
num_nodes: Dict(:movie => 13, :user => 3)
21+
num_edges: Dict((:user, :rate, :movie) => 4)
22+
```
23+
New relations, possibly with new node types, can be added with the function [`add_edges`](@ref).
24+
```julia-repl
25+
julia> g = add_edges(g, (:user, :like, :actor) => ([1,2,3,3,3], [3,5,1,9,4]))
26+
GNNHeteroGraph:
27+
num_nodes: Dict(:actor => 9, :movie => 13, :user => 3)
28+
num_edges: Dict((:user, :like, :actor) => 5, (:user, :rate, :movie) => 4)
29+
```
30+
See [`rand_heterograph`](@ref), [`rand_bipartite_heterograph`](@ref)
31+
for generating random heterographs.
32+
33+
```julia-repl
34+
julia> g = rand_bipartite_heterograph((10, 15), 20)
35+
GNNHeteroGraph:
36+
num_nodes: Dict(:A => 10, :B => 15)
37+
num_edges: Dict((:A, :to, :B) => 20, (:B, :to, :A) => 20)
38+
```
39+
40+
## Basic Queries
41+
42+
Basic queries are similar to those for homogeneous graphs:
43+
```julia-repl
44+
julia> g = GNNHeteroGraph((:user, :rate, :movie) => ([1,1,2,3], [7,13,5,7]))
45+
GNNHeteroGraph:
46+
num_nodes: Dict(:movie => 13, :user => 3)
47+
num_edges: Dict((:user, :rate, :movie) => 4)
48+
49+
julia> g.num_nodes
50+
Dict{Symbol, Int64} with 2 entries:
51+
:user => 3
52+
:movie => 13
53+
54+
julia> g.num_edges
55+
Dict{Tuple{Symbol, Symbol, Symbol}, Int64} with 1 entry:
56+
(:user, :rate, :movie) => 4
57+
58+
# source and target node for a given relation
59+
julia> edge_index(g, (:user, :rate, :movie))
60+
([1, 1, 2, 3], [7, 13, 5, 7])
61+
62+
# node types
63+
julia> g.ntypes
64+
2-element Vector{Symbol}:
65+
:user
66+
:movie
67+
68+
# edge types
69+
julia> g.etypes
70+
1-element Vector{Tuple{Symbol, Symbol, Symbol}}:
71+
(:user, :rate, :movie)
72+
```
73+
74+
## Data Features
75+
76+
Node, edge, and graph features can be added at constuction time or later using:
77+
```julia-repl
78+
# equivalent to g.ndata[:user][:x] = ...
79+
julia> g[:user].x = rand(Float32, 64, 3);
80+
81+
julia> g[:movie].z = rand(Float32, 64, 13);
82+
83+
# equivalent to g.edata[(:user, :rate, :movie)][:e] = ...
84+
julia> g[:user, :rate, :movie].e = rand(Float32, 64, 4);
85+
86+
julia> g
87+
GNNHeteroGraph:
88+
num_nodes: Dict(:movie => 13, :user => 3)
89+
num_edges: Dict((:user, :rate, :movie) => 4)
90+
ndata:
91+
:movie => DataStore(z = [64×13 Matrix{Float32}])
92+
:user => DataStore(x = [64×3 Matrix{Float32}])
93+
edata:
94+
(:user, :rate, :movie) => DataStore(e = [64×4 Matrix{Float32}])
95+
```
96+
97+

src/GNNGraphs/gnnheterograph.jl

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -145,9 +145,9 @@ function GNNHeteroGraph(data::EDict;
145145
end
146146

147147
function show_sorted_dict(io::IO, d::Dict, compact::Bool)
148-
if compact
148+
# if compact
149149
print(io, "Dict")
150-
end
150+
# end
151151
print(io, "(")
152152
if !isempty(d)
153153
_keys = sort!(collect(keys(d)))
@@ -156,6 +156,9 @@ function show_sorted_dict(io::IO, d::Dict, compact::Bool)
156156
end
157157
print(io, "$(_str(_keys[end])) => $(d[_keys[end]])")
158158
end
159+
# if length(d) == 1
160+
# print(io, ",")
161+
# end
159162
print(io, ")")
160163
end
161164

@@ -180,15 +183,17 @@ function Base.show(io::IO, ::MIME"text/plain", g::GNNHeteroGraph)
180183
print(io, "\n num_edges: ")
181184
show_sorted_dict(io, g.num_edges, false)
182185
g.num_graphs > 1 && print(io, "\n num_graphs: $(g.num_graphs)")
183-
if !isempty(g.ndata)
186+
if !isempty(g.ndata) && !all(isempty, values(g.ndata))
184187
print(io, "\n ndata:")
185188
for k in sort(collect(keys(g.ndata)))
189+
isempty(g.ndata[k]) && continue
186190
print(io, "\n\t", _str(k), " => $(shortsummary(g.ndata[k]))")
187191
end
188192
end
189-
if !isempty(g.edata)
193+
if !isempty(g.edata) && !all(isempty, values(g.edata))
190194
print(io, "\n edata:")
191195
for k in sort(collect(keys(g.edata)))
196+
isempty(g.edata[k]) && continue
192197
print(io, "\n\t$k => $(shortsummary(g.edata[k]))")
193198
end
194199
end
@@ -271,20 +276,12 @@ end
271276

272277
@non_differentiable _ntypes_from_edges(::Any...)
273278

274-
275279
function Base.getindex(g::GNNHeteroGraph, node_t::NType)
276-
if !haskey(g.ndata, node_t) && node_t in g.ntypes
277-
g.ndata[node_t] = DataStore(g.num_nodes[node_t])
278-
end
279280
return g.ndata[node_t]
280281
end
281282

282-
Base.setindex!(g::GNNHeteroGraph, node_t::NType, x) = g.ndata[node_t] = x
283+
Base.getindex(g::GNNHeteroGraph, n1_t::Symbol, rel::Symbol, n2_t::Symbol) = g[(n1_t, rel, n2_t)]
283284

284285
function Base.getindex(g::GNNHeteroGraph, edge_t::EType)
285-
if !haskey(g.edata, node_t) && edge_t in g.etypes
286-
g.ndata[node_t] = DataStore(g.num_edges[edge_t])
287-
end
288286
return g.edata[edge_t]
289287
end
290-
Base.setindex!(g::GNNHeteroGraph, edge_t::EType, x) = g.edata[edge_t] = x

src/GNNGraphs/transform.jl

Lines changed: 46 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,8 @@ end
120120
add_edges(g::GNNGraph, s::AbstractVector, t::AbstractVector; [edata])
121121
122122
Add to graph `g` the edges with source nodes `s` and target nodes `t`.
123-
Optionally, pass the features `edata` for the new edges.
123+
Optionally, pass the features `edata` for the new edges.
124+
Returns a new graph sharing part of the underlying data with `g`.
124125
"""
125126
function add_edges(g::GNNGraph{<:COO_T},
126127
snew::AbstractVector{<:Integer},
@@ -143,14 +144,33 @@ function add_edges(g::GNNGraph{<:COO_T},
143144
g.ndata, edata, g.gdata)
144145
end
145146

147+
148+
"""
149+
add_edges(g::GNNHeteroGraph, edge_t, s, t; [edata, num_nodes])
150+
add_edges(g::GNNHeteroGraph, edge_t => (s, t); [edata, num_nodes])
151+
152+
Add to heterograph `g` the releation of type `edge_t` with source node vector `s` and target node vector `t`.
153+
Optionally, pass the features `edata` for the new edges.
154+
`edge_t` is a triplet of symbols `(srctype, etype, dsttype)`.
155+
156+
If the edge type is not already present in the graph, it is added. If it involves new node types, they are added to the graph as well.
157+
In this case, a dictionary or named tuple of `num_nodes` can be passed to specify the number of nodes of the new types,
158+
otherwise the number of nodes is inferred from the maximum node id in `s` and `t`.
159+
"""
160+
add_edges(g::GNNHeteroGraph{<:COO_T}, data::Pair{EType, <:Tuple}; kws...) = add_edges(g, data.first, data.second...; kws...)
161+
146162
function add_edges(g::GNNHeteroGraph{<:COO_T},
147163
edge_t::EType,
148164
snew::AbstractVector{<:Integer},
149165
tnew::AbstractVector{<:Integer};
150-
edata = nothing)
166+
edata = nothing,
167+
num_nodes = Dict{Symbol,Int}())
151168
@assert length(snew) == length(tnew)
169+
is_existing_rel = haskey(g.graph, edge_t)
152170
# TODO remove this constraint
153-
@assert get_edge_weight(g, edge_t) === nothing
171+
if is_existing_rel
172+
@assert get_edge_weight(g, edge_t) === nothing
173+
end
154174

155175
edata = normalize_graphdata(edata, default_name = :e, n = length(snew))
156176
g_edata = g.edata |> copy
@@ -164,23 +184,41 @@ function add_edges(g::GNNHeteroGraph{<:COO_T},
164184

165185
graph = g.graph |> copy
166186
etypes = g.etypes |> copy
167-
if !haskey(graph, edge_t)
168-
@assert edge_t[1] g.ntypes && edge_t[3] g.ntypes
169-
push!(g.etypes, edge_t)
187+
ntypes = g.ntypes |> copy
188+
_num_nodes = g.num_nodes |> copy
189+
ndata = g.ndata |> copy
190+
if !is_existing_rel
191+
for (node_t, st) in [(edge_t[1], snew), (edge_t[3], tnew)]
192+
if node_t ntypes
193+
push!(ntypes, node_t)
194+
if haskey(num_nodes, node_t)
195+
_num_nodes[node_t] == num_nodes[node_t]
196+
else
197+
_num_nodes[node_t] = maximum(st)
198+
end
199+
ndata[node_t] = DataStore(_num_nodes[node_t])
200+
end
201+
end
202+
push!(etypes, edge_t)
170203
else
171204
s, t = edge_index(g, edge_t)
172205
snew = [s; snew]
173206
tnew = [t; tnew]
174207
end
208+
@assert maximum(snew) <= _num_nodes[edge_t[1]]
209+
@assert maximum(tnew) <= _num_nodes[edge_t[3]]
210+
@assert minimum(snew) >= 1
211+
@assert minimum(tnew) >= 1
212+
175213
graph[edge_t] = (snew, tnew, nothing)
176214
num_edges = g.num_edges |> copy
177215
num_edges[edge_t] = length(graph[edge_t][1])
178216

179217
return GNNHeteroGraph(graph,
180-
g.num_nodes, num_edges, g.num_graphs,
218+
_num_nodes, num_edges, g.num_graphs,
181219
g.graph_indicator,
182220
g.ndata, g_edata, g.gdata,
183-
g.ntypes, etypes)
221+
ntypes, etypes)
184222
end
185223

186224

test/GNNGraphs/gnnheterograph.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,20 @@ end
3434
@test size(hg.ndata[:B].y) == (4, 20)
3535
@test size(hg.edata[(:A, :rel1, :B)].e) == (5, 30)
3636
@test hg.gdata == DataStore(u = 1)
37+
38+
end
39+
40+
@testset "indexing syntax" begin
41+
g = GNNHeteroGraph((:user, :rate, :movie) => ([1,1,2,3], [7,13,5,7]))
42+
g[:movie].z = rand(Float32, 64, 13);
43+
g[:user, :rate, :movie].e = rand(Float32, 64, 4);
44+
g[:user].x = rand(Float32, 64, 3);
45+
@test size(g.ndata[:user].x) == (64, 3)
46+
@test size(g.ndata[:movie].z) == (64, 13)
47+
@test size(g.edata[(:user, :rate, :movie)].e) == (64, 4)
3748
end
3849

50+
3951
@testset "simplified constructor" begin
4052
hg = rand_heterograph((:A => 10, :B => 20),
4153
((:A, :rel1, :B) => 30, (:B, :rel2, :A) => 10),

test/GNNGraphs/transform.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,23 @@ end
121121
@test has_edge(hg, (:B,:to,:A), 1, 2)
122122
@test !has_edge(hg, (:B,:to,:A), 2, 1)
123123
@test !has_edge(hg, (:B,:to,:A), 2, 2)
124+
125+
@testset "new nodes" begin
126+
hg = rand_bipartite_heterograph((2, 2), 3)
127+
hg = add_edges(hg, (:C,:rel,:B) => ([1, 3], [1,2]))
128+
@test hg.num_nodes == Dict(:A => 2, :B => 2, :C => 3)
129+
@test hg.num_edges == Dict((:A,:to,:B) => 3, (:B,:to,:A) => 3, (:C,:rel,:B) => 2)
130+
s, t = edge_index(hg, (:C,:rel,:B))
131+
@test s == [1, 3]
132+
@test t == [1, 2]
133+
134+
hg = add_edges(hg, (:D,:rel,:F) => ([1, 3], [1,2]))
135+
@test hg.num_nodes == Dict(:A => 2, :B => 2, :C => 3, :D => 3, :F => 2)
136+
@test hg.num_edges == Dict((:A,:to,:B) => 3, (:B,:to,:A) => 3, (:C,:rel,:B) => 2, (:D,:rel,:F) => 2)
137+
s, t = edge_index(hg, (:D,:rel,:F))
138+
@test s == [1, 3]
139+
@test t == [1, 2]
140+
end
124141
end
125142
end
126143
end

0 commit comments

Comments
 (0)