Skip to content

Commit e7f0efc

Browse files
HeteroGraph implementation (#146)
* update based on Dicts * suggestion * add rand_heterograph * normalize data * simplify construction * cleanup docs * update show graph * fix test * renaming * cl/hetero * fix test
1 parent 637549e commit e7f0efc

File tree

13 files changed

+462
-49
lines changed

13 files changed

+462
-49
lines changed

Project.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,10 +47,11 @@ julia = "1.6"
4747
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
4848
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
4949
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
50+
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
5051
InlineStrings = "842dd82b-1e85-43dc-bf29-5d0ee9dffc48"
5152
MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
5253
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
5354
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
5455

5556
[targets]
56-
test = ["Test", "Adapt", "InlineStrings", "Zygote", "FiniteDifferences", "ChainRulesTestUtils", "MLDatasets"]
57+
test = ["Test", "Adapt", "DataFrames", "InlineStrings", "Zygote", "FiniteDifferences", "ChainRulesTestUtils", "MLDatasets"]

docs/make.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ makedocs(;
2424
"Message Passing" => "messagepassing.md",
2525
"Model Building" => "models.md",
2626
"Datasets" => "datasets.md",
27+
"HeteroGraphs" => "gnnheterograph.md",
2728
"Tutorials" => tutorials,
2829
"API Reference" =>
2930
[

docs/src/gnngraph.md

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,8 @@ julia> target = [2,3,1,3,1,2,4,3];
4747

4848
julia> g = GNNGraph(source, target)
4949
GNNGraph:
50-
num_nodes = 4
51-
num_edges = 8
50+
num_nodes: 4
51+
num_edges: 8
5252

5353

5454
julia> @assert g.num_nodes == 4 # number of nodes
@@ -88,7 +88,7 @@ g = rand_graph(10, 60, ndata = rand(Float32, 32, 10))
8888

8989
g.ndata.x # `:x` is the default name for node features
9090

91-
# For convinience, we can access the features through the shortcut
91+
# For convenience, we can access the features through the shortcut
9292
g.x
9393

9494
# You can have multiple feature arrays
@@ -131,8 +131,8 @@ julia> weight = [1.0, 0.5, 2.1, 2.3, 4, 4.1];
131131

132132
julia> g = GNNGraph(source, target, weight)
133133
GNNGraph:
134-
num_nodes = 3
135-
num_edges = 6
134+
num_nodes: 3
135+
num_edges: 6
136136

137137
julia> get_edge_weight(g)
138138
6-element Vector{Float64}:
@@ -221,10 +221,10 @@ using Flux: gpu
221221
g_gpu = g |> gpu
222222
```
223223

224-
## JuliaGraphs/Graphs.jl integration
224+
## Integration with Graphs.jl
225225

226-
Since `GNNGraph <: Graphs.AbstractGraph`, we can use any functionality from Graphs.jl.
227-
Moreover, `GNNGraph`s can be constructed from `Graphs.Graph` and `Graphs.DiGraph`.
226+
Since `GNNGraph <: Graphs.AbstractGraph`, we can use any functionality from [Graphs.jl](https://github.com/JuliaGraphs/Graphs.jl) for querying and analyzing the graph structure.
227+
Moreover, a `GNNGraph` can be easily constructed from a `Graphs.Graph` or a `Graphs.DiGraph`:
228228

229229
```julia
230230
julia> import Graphs
@@ -237,17 +237,17 @@ julia> gu = Graphs.erdos_renyi(10, 20)
237237

238238
# Since GNNGraphs are directed, the edges are doubled when converting
239239
# to GNNGraph
240-
julia> GNNGraph(gu) # Since GNNGraphs are
240+
julia> GNNGraph(gu)
241241
GNNGraph:
242-
num_nodes = 10
243-
num_edges = 40
242+
num_nodes: 10
243+
num_edges: 40
244244

245245
# A Graphs.jl directed graph
246246
julia> gd = Graphs.erdos_renyi(10, 20, is_directed=true)
247247
{10, 20} directed simple Int64 graph
248248

249249
julia> GNNGraph(gd)
250250
GNNGraph:
251-
num_nodes = 10
252-
num_edges = 20
251+
num_nodes: 10
252+
num_edges: 20
253253
```

docs/src/gnnheterograph.md

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
# Heterogeneous Graphs
2+
3+
!!! warning
4+
Heterograph support is still experimental.
5+
The interface could be subject to change in the future.
6+
7+
Heterogeneous graphs (also called heterographs) are graphs where each node has a type,
8+
that we denote with symbols such as `:user` and `:movie`,
9+
and edges also represent different relations identified
10+
by a triple of symbols, `(source_nodes, edge_type, target_nodes)`, as in `(:user, :rate, :movie)`.
11+
12+
Different node/edge types can store different groups of features
13+
and this makes heterographs very flexible modeling tools
14+
and data containers.
15+
16+
In GraphNeuralNetworks.jl heterographs are implemented in
17+
the type [`GNNHeteroGraph`](@ref).
18+
19+
20+
```@docs
21+
GNNHeteroGraph
22+
rand_heterograph
23+
```
24+
25+

examples/node_classification_cora.jl

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -44,10 +44,9 @@ function train(; kws...)
4444
dataset = Cora()
4545
classes = dataset.metadata["classes"]
4646
g = mldataset2gnngraph(dataset) |> device
47-
X = g.ndata.features
48-
y = onehotbatch(g.ndata.targets |> cpu, classes) |> device # remove when https://github.com/FluxML/Flux.jl/pull/1959 tagged
49-
(; train_mask, val_mask, test_mask) = g.ndata
50-
ytrain = y[:,train_mask]
47+
X = g.features
48+
y = onehotbatch(g.targets |> cpu, classes) |> device # remove when https://github.com/FluxML/Flux.jl/pull/1959 tagged
49+
ytrain = y[:, g.train_mask]
5150

5251
nin, nhidden, nout = size(X,1), args.nhidden, length(classes)
5352

@@ -63,8 +62,8 @@ function train(; kws...)
6362

6463
## LOGGING FUNCTION
6564
function report(epoch)
66-
train = eval_loss_accuracy(X, y, train_mask, model, g)
67-
test = eval_loss_accuracy(X, y, test_mask, model, g)
65+
train = eval_loss_accuracy(X, y, g.train_mask, model, g)
66+
test = eval_loss_accuracy(X, y, g.test_mask, model, g)
6867
println("Epoch: $epoch Train: $(train) Test: $(test)")
6968
end
7069

@@ -73,7 +72,7 @@ function train(; kws...)
7372
for epoch in 1:args.epochs
7473
gs = Flux.gradient(ps) do
7574
ŷ = model(g, X)
76-
logitcrossentropy(ŷ[:,train_mask], ytrain)
75+
logitcrossentropy(ŷ[:, g.train_mask], ytrain)
7776
end
7877

7978
Flux.Optimise.update!(opt, ps, gs)

src/GNNGraphs/GNNGraphs.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,9 @@ export GNNGraph,
2222
edge_features,
2323
graph_features
2424

25+
include("gnnheterograph.jl")
26+
export GNNHeteroGraph
27+
2528
include("query.jl")
2629
export adjacency_list,
2730
edge_index,
@@ -60,6 +63,7 @@ export add_nodes,
6063

6164
include("generate.jl")
6265
export rand_graph,
66+
rand_heterograph,
6367
knn_graph,
6468
radius_graph
6569

@@ -75,4 +79,6 @@ include("utils.jl")
7579
include("gatherscatter.jl")
7680
# _gather, _scatter
7781

82+
83+
7884
end #module

src/GNNGraphs/convert.jl

Lines changed: 46 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,55 @@
11
### CONVERT_TO_COO REPRESENTATION ########
22

3-
function to_coo(coo::COO_T; dir=:out, num_nodes=nothing, weighted=true)
3+
function to_coo(data::EDict; num_nodes = nothing, kws...)
4+
graph = EDict{Any}()
5+
_num_nodes = NDict{Int}()
6+
num_edges = EDict{Int}()
7+
for k in keys(data)
8+
d = data[k]
9+
@assert d isa Tuple
10+
if length(d) == 2
11+
d = (d..., nothing)
12+
end
13+
if num_nodes !== nothing
14+
n1 = get(num_nodes, k[1], nothing)
15+
n2 = get(num_nodes, k[3], nothing)
16+
else
17+
n1 = nothing
18+
n2 = nothing
19+
end
20+
g, nnodes, nedges = to_coo(d; hetero=true, num_nodes=(n1,n2), kws...)
21+
graph[k] = g
22+
num_edges[k] = nedges
23+
_num_nodes[k[1]] = max(get(_num_nodes, k[1], 0), nnodes[1])
24+
_num_nodes[k[3]] = max(get(_num_nodes, k[3], 0), nnodes[2])
25+
end
26+
return graph, _num_nodes, num_edges
27+
end
28+
29+
function to_coo(coo::COO_T; dir=:out, num_nodes=nothing, weighted=true, hetero=false)
430
s, t, val = coo
5-
num_nodes::Int = isnothing(num_nodes) ? max(maximum(s), maximum(t)) : num_nodes
31+
32+
if isnothing(num_nodes)
33+
ns = maximum(s)
34+
nt = maximum(t)
35+
num_nodes = hetero ? (ns, nt) : max(ns, nt)
36+
elseif num_nodes isa Integer
37+
ns = num_nodes
38+
nt = num_nodes
39+
elseif num_nodes isa Tuple
40+
ns = isnothing(num_nodes[1]) ? maximum(s) : num_nodes[1]
41+
nt = isnothing(num_nodes[2]) ? maximum(t) : num_nodes[2]
42+
num_nodes = (ns, nt)
43+
else
44+
error("Invalid num_nodes $num_nodes")
45+
end
646
@assert isnothing(val) || length(val) == length(s)
747
@assert length(s) == length(t)
848
if !isempty(s)
9-
@assert min(minimum(s), minimum(t)) >= 1
10-
@assert max(maximum(s), maximum(t)) <= num_nodes
49+
@assert minimum(s) >= 1
50+
@assert minimum(t) >= 1
51+
@assert maximum(s) <= ns
52+
@assert maximum(t) <= nt
1153
end
1254
num_edges = length(s)
1355
if !weighted

src/GNNGraphs/generate.jl

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,47 @@ function rand_graph(n::Integer, m::Integer; bidirected=true, seed=-1, kws...)
4545
return GNNGraph(Graphs.erdos_renyi(n, m2; is_directed=!bidirected, seed); kws...)
4646
end
4747

48+
"""
49+
rand_heterograph(n, m; seed=-1, kws...)
50+
51+
Construct a [`GNNHeteroGraph`](@ref) with number of nodes and edges
52+
specified by `n` and `m` respectively.
53+
`n` and `m` can be any iterable of pairs.
54+
55+
Use a `seed > 0` for reproducibility.
56+
57+
Additional keyword arguments will be passed to the [`GNNHeteroGraph`](@ref) constructor.
58+
59+
# Examples
60+
61+
```julia-repl
62+
63+
64+
julia> g = rand_heterograph((:user => 10, :movie => 20),
65+
(:user, :rate, :movie) => 30)
66+
GNNHeteroGraph:
67+
num_nodes: (:user => 10, :movie => 20)
68+
num_edges: ((:user, :rate, :movie) => 30,)
69+
```
70+
"""
71+
rand_heterograph
72+
73+
# for generic iterators of pairs
74+
rand_heterograph(n, m; kws...) = rand_heterograph(Dict(n), Dict(m); kws...)
75+
76+
function rand_heterograph(n::NDict, m::EDict; bidirected=false, seed=-1, kws...)
77+
@assert !bidirected "Bidirected graphs not supported yet."
78+
rng = seed > 0 ? MersenneTwister(seed) : Random.GLOBAL_RNG
79+
graphs = Dict(k => _rand_edges(rng, (n[k[1]], n[k[3]]), m[k]) for k in keys(m))
80+
return GNNHeteroGraph(graphs; num_nodes=n, kws...)
81+
end
82+
83+
function _rand_edges(rng, (n1, n2), m)
84+
idx = StatsBase.sample(rng, 1:n1*n2, m, replace=false)
85+
s, t = edge_decoding(idx, n1, n2)
86+
val = nothing
87+
return s, t, val
88+
end
4889

4990
"""
5091
knn_graph(points::AbstractMatrix,

src/GNNGraphs/gnngraph.jl

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -230,27 +230,27 @@ function Base.show(io::IO, ::MIME"text/plain", g::GNNGraph)
230230
print(io, "GNNGraph($(g.num_nodes), $(g.num_edges))")
231231
else # if the following block is indented the printing is ruined
232232
print(io, "GNNGraph:
233-
num_nodes = $(g.num_nodes)
234-
num_edges = $(g.num_edges)")
235-
g.num_graphs > 1 && print(io, "\n num_graphs = $(g.num_graphs)")
236-
if !isempty(g.ndata)
237-
print(io, "\n ndata:")
238-
for k in keys(g.ndata)
239-
print(io, "\n $k => $(summary(g.ndata[k]))")
240-
end
241-
end
242-
if !isempty(g.edata)
243-
print(io, "\n edata:")
244-
for k in keys(g.edata)
245-
print(io, "\n $k => $(summary(g.edata[k]))")
246-
end
247-
end
248-
if !isempty(g.gdata)
249-
print(io, "\n gdata:")
250-
for k in keys(g.gdata)
251-
print(io, "\n $k => $(summary(g.gdata[k]))")
252-
end
253-
end
233+
num_nodes: $(g.num_nodes)
234+
num_edges: $(g.num_edges)")
235+
g.num_graphs > 1 && print(io, "\n num_graphs = $(g.num_graphs)")
236+
if !isempty(g.ndata)
237+
print(io, "\n ndata:")
238+
for k in keys(g.ndata)
239+
print(io, "\n $k = $(shortsummary(g.ndata[k]))")
240+
end
241+
end
242+
if !isempty(g.edata)
243+
print(io, "\n edata:")
244+
for k in keys(g.edata)
245+
print(io, "\n $k = $(shortsummary(g.edata[k]))")
246+
end
247+
end
248+
if !isempty(g.gdata)
249+
print(io, "\n gdata:")
250+
for k in keys(g.gdata)
251+
print(io, "\n $k = $(shortsummary(g.gdata[k]))")
252+
end
253+
end
254254
end #else
255255
end
256256

0 commit comments

Comments
 (0)