Skip to content

Commit 713996c

Browse files
improve add_edges for graph and heterographs (#337)
* improve add_edges for graph and heterographs * constructor test * jldoctest * add tests * fix tests
1 parent 17a1e7f commit 713996c

17 files changed

+229
-95
lines changed

docs/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,3 +17,4 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1717

1818
[compat]
1919
DemoCards = "0.5.0"
20+
Documenter = "0.27"

docs/src/heterograph.md

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
# Heterogeneous Graphs
22

33
Heterogeneous graphs (also called heterographs), are graphs where each node has a type,
4-
that we denote with symbols such as `:user` and `:movie`,
5-
and edges also represent different relations identified
6-
by a triplet of symbols, `(source_node_type, edge_type, target_node_type)`, as in `(:user, :rate, :movie)`.
4+
that we denote with symbols such as `:user` and `:movie`.
5+
Also edges have a type, such as `:rate` or `:like`, and they can connect nodes of different types. We call a triplet `(source_node_type, edge_type, target_node_type)` the type of a *relation*, e.g. `(:user, :rate, :movie)`.
76

87
Different node/edge types can store different groups of features
98
and this makes heterographs a very flexible modeling tools
@@ -13,15 +12,21 @@ the type [`GNNHeteroGraph`](@ref).
1312

1413
## Creating a Heterograph
1514

16-
A heterograph can be created by passing pairs of edge types and data to the constructor.
17-
```julia-repl
15+
A heterograph can be created by passing pairs of relation type and data to the constructor.
16+
```jldoctest
17+
julia> g = GNNHeteroGraph((:user, :like, :actor) => ([1,2,2,3], [1,3,2,9]),
18+
(:user, :rate, :movie) => ([1,1,2,3], [7,13,5,7]))
19+
GNNHeteroGraph:
20+
num_nodes: Dict(:actor => 9, :movie => 13, :user => 3)
21+
num_edges: Dict((:user, :like, :actor) => 4, (:user, :rate, :movie) => 4)
22+
1823
julia> g = GNNHeteroGraph((:user, :rate, :movie) => ([1,1,2,3], [7,13,5,7]))
1924
GNNHeteroGraph:
2025
num_nodes: Dict(:movie => 13, :user => 3)
2126
num_edges: Dict((:user, :rate, :movie) => 4)
2227
```
2328
New relations, possibly with new node types, can be added with the function [`add_edges`](@ref).
24-
```julia-repl
29+
```jldoctest
2530
julia> g = add_edges(g, (:user, :like, :actor) => ([1,2,3,3,3], [3,5,1,9,4]))
2631
GNNHeteroGraph:
2732
num_nodes: Dict(:actor => 9, :movie => 13, :user => 3)
@@ -30,7 +35,7 @@ GNNHeteroGraph:
3035
See [`rand_heterograph`](@ref), [`rand_bipartite_heterograph`](@ref)
3136
for generating random heterographs.
3237

33-
```julia-repl
38+
```jldoctest
3439
julia> g = rand_bipartite_heterograph((10, 15), 20)
3540
GNNHeteroGraph:
3641
num_nodes: Dict(:A => 10, :B => 15)
@@ -40,7 +45,7 @@ GNNHeteroGraph:
4045
## Basic Queries
4146

4247
Basic queries are similar to those for homogeneous graphs:
43-
```julia-repl
48+
```jldoctest
4449
julia> g = GNNHeteroGraph((:user, :rate, :movie) => ([1,1,2,3], [7,13,5,7]))
4550
GNNHeteroGraph:
4651
num_nodes: Dict(:movie => 13, :user => 3)
@@ -74,7 +79,7 @@ julia> g.etypes
7479
## Data Features
7580

7681
Node, edge, and graph features can be added at construction time or later using:
77-
```julia-repl
82+
```jldoctest
7883
# equivalent to g.ndata[:user][:x] = ...
7984
julia> g[:user].x = rand(Float32, 64, 3);
8085
@@ -96,7 +101,7 @@ GNNHeteroGraph:
96101

97102
## Batching
98103
Similarly to graphs, also heterographs can be batched together.
99-
```julia-repl
104+
```jldoctest
100105
julia> gs = [rand_bipartite_heterograph((5, 10), 20) for _ in 1:32];
101106
102107
julia> Flux.batch(gs)
@@ -108,7 +113,7 @@ GNNHeteroGraph:
108113
Batching is automatically performed by the [`DataLoader`](@ref) iterator
109114
when the `collate` option is set to `true`.
110115

111-
```julia-repl
116+
```jldoctest
112117
using Flux: DataLoader
113118
114119
data = [rand_bipartite_heterograph((5, 10), 20,

docs/src/temporalgraph.md

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ Temporal Graphs are graphs with time varying topologies and node features. In Gr
66

77
A temporal graph can be created by passing a list of snapshots to the constructor. Each snapshot is a [`GNNGraph`](@ref).
88

9-
```julia-repl
9+
```jldoctest
1010
julia> snapshots = [rand_graph(10,20) for i in 1:5];
1111
1212
julia> tg = TemporalSnapshotsGNNGraph(snapshots)
@@ -18,14 +18,14 @@ TemporalSnapshotsGNNGraph:
1818

1919
A new temporal graph can be created by adding or removing snapshots to an existing temporal graph.
2020

21-
```julia-repl
21+
```jldoctest
2222
julia> new_tg = add_snapshot(tg, 3, rand_graph(10, 16)) # add a new snapshot at time 3
2323
TemporalSnapshotsGNNGraph:
2424
num_nodes: [10, 10, 10, 10, 10, 10]
2525
num_edges: [20, 20, 16, 20, 20, 20]
2626
num_snapshots: 6
2727
```
28-
```julia-repl
28+
```jldoctest
2929
julia> snapshots = [rand_graph(10,20), rand_graph(10,14), rand_graph(10,22)];
3030
3131
julia> tg = TemporalSnapshotsGNNGraph(snapshots)
@@ -43,7 +43,7 @@ TemporalSnapshotsGNNGraph:
4343

4444
See [`rand_temporal_radius_graph`](@ref) and ['rand_temporal_hyperbolic_graph'](@ref) for generating random temporal graphs.
4545

46-
```julia-repl
46+
```jldoctest
4747
julia> tg = rand_temporal_radius_graph(10, 3, 0.1, 0.5)
4848
TemporalSnapshotsGNNGraph:
4949
num_nodes: [10, 10, 10]
@@ -54,7 +54,7 @@ TemporalSnapshotsGNNGraph:
5454
## Basic Queries
5555

5656
Basic queries are similar to those for [`GNNGraph`](@ref)s:
57-
```julia-repl
57+
```jldoctest
5858
julia> snapshots = [rand_graph(10,20), rand_graph(10,14), rand_graph(10,22)];
5959
6060
julia> tg = TemporalSnapshotsGNNGraph(snapshots)
@@ -94,7 +94,7 @@ GNNGraph:
9494

9595
Node, edge, and graph features can be added at construction time or later using:
9696

97-
```julia-repl
97+
```jldoctest
9898
julia> snapshots = [rand_graph(10,20; ndata = rand(3,10)), rand_graph(10,14; ndata = rand(4,10)), rand_graph(10,22; ndata = rand(5,10))]; # node features at construction time
9999
100100
julia> tg = TemporalSnapshotsGNNGraph(snapshots);

src/GNNGraphs/datastore.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ A container for feature arrays. The optional argument `n` enforces that
88
At construction time, the `data` can be provided as any iterables of pairs
99
of symbols and arrays or as keyword arguments:
1010
11-
```julia-repl
11+
```jldoctest
1212
julia> ds = DataStore(3, x = rand(2, 3), y = rand(3))
1313
DataStore(3) with 2 elements:
1414
y = 3-element Vector{Float64}
@@ -35,7 +35,7 @@ DataStore() with 2 elements:
3535
The `DataStore` has an interface similar to both dictionaries and named tuples.
3636
Arrays can be accessed and added using either the indexing or the property syntax:
3737
38-
```julia-repl
38+
```jldoctest
3939
julia> ds = DataStore(x = ones(2, 3), y = zeros(3))
4040
DataStore() with 2 elements:
4141
y = 3-element Vector{Float64}
@@ -57,7 +57,7 @@ The `DataStore` can be iterated over, and the keys and values can be accessed
5757
using `keys(ds)` and `values(ds)`. `map(f, ds)` applies the function `f`
5858
to each feature array:
5959
60-
```julia-repl
60+
```jldoctest
6161
julia> ds = DataStore(a = zeros(2), b = zeros(2));
6262
6363
julia> ds2 = map(x -> x .+ 1, ds)
@@ -142,6 +142,8 @@ function Base.show(io::IO, ds::DataStore)
142142
for (k, v) in getdata(ds)
143143
print(io, "\n $(k) = $(summary(v))")
144144
end
145+
else
146+
print(io, " with no elements")
145147
end
146148
end
147149

src/GNNGraphs/generate.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ Additional keyword arguments will be passed to the [`GNNGraph`](@ref) constructo
1616
1717
# Examples
1818
19-
```juliarepl
19+
```jldoctest
2020
julia> g = rand_graph(5, 4, bidirected=false)
2121
GNNGraph:
2222
num_nodes = 5
@@ -64,7 +64,7 @@ Additional keyword arguments will be passed to the [`GNNHeteroGraph`](@ref) cons
6464
6565
# Examples
6666
67-
```julia-repl
67+
```jldoctest
6868
julia> g = rand_heterograph((:user => 10, :movie => 20),
6969
(:user, :rate, :movie) => 30)
7070
GNNHeteroGraph:
@@ -170,7 +170,7 @@ to its `k` closest `points`.
170170
171171
# Examples
172172
173-
```juliarepl
173+
```jldoctest
174174
julia> n, k = 10, 3;
175175
176176
julia> x = rand(3, n);
@@ -251,7 +251,7 @@ to its neighbors within a given distance `r`.
251251
252252
# Examples
253253
254-
```juliarepl
254+
```jldoctest
255255
julia> n, r = 10, 0.75;
256256
257257
julia> x = rand(3, n);
@@ -331,7 +331,7 @@ If a point happens to move outside the boundary, its position is updated as if i
331331
332332
# Example
333333
334-
```julia-repl
334+
```jldoctest
335335
julia> n, snaps, s, r = 10, 5, 0.1, 1.5;
336336
337337
julia> tg = rand_temporal_radius_graph(n,snaps,s,r) # complete graph at each snapshot
@@ -403,7 +403,7 @@ First, the positions of the nodes are generated with a quasi-uniform distributio
403403
404404
# Example
405405
406-
```julia-repl
406+
```jldoctest
407407
julia> n, snaps, α, R, speed, ζ = 10, 5, 1.0, 4.0, 0.1, 1.0;
408408
409409
julia> thg = rand_temporal_hyperbolic_graph(n, snaps; α, R, speed, ζ)

src/GNNGraphs/gnngraph.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,8 @@ function GNNGraph(data::D;
153153
ndata, edata, gdata)
154154
end
155155

156+
GNNGraph(; kws...) = GNNGraph(0; kws...)
157+
156158
function (::Type{<:GNNGraph})(num_nodes::T; kws...) where {T <: Integer}
157159
s, t = T[], T[]
158160
return GNNGraph(s, t; num_nodes, kws...)

src/GNNGraphs/gnnheterograph.jl

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,24 +6,26 @@ const NDict{T} = Dict{NType, T}
66

77
"""
88
GNNHeteroGraph(data; [ndata, edata, gdata, num_nodes])
9+
GNNHeteroGraph(pairs...; [ndata, edata, gdata, num_nodes])
910
1011
A type representing a heterogeneous graph structure.
1112
It is similar to [`GNNGraph`](@ref) but nodes and edges are of different types.
1213
1314
# Constructor Arguments
1415
15-
- `data`: A dictionary or an iterable object that maps (source_type, edge_type, target_type)
16-
triples to (source, target) index vectors.
16+
- `data`: A dictionary or an iterable object that maps `(source_type, edge_type, target_type)`
17+
triples to `(source, target)` index vectors (or to `(source, target, weight)` if also edge weights are present).
18+
- `pairs`: Passing multiple relations as pairs is equivalent to passing `data=Dict(pairs...)`.
1719
- `ndata`: Node features. A dictionary of arrays or named tuple of arrays.
1820
The size of the last dimension of each array must be given by `g.num_nodes`.
19-
- `edata`: Edge features. A dictionary of arrays or named tuple of arrays.
20-
The size of the last dimension of each array must be given by `g.num_edges`.
21-
- `gdata`: Graph features. An array or named tuple of arrays whose last dimension has size `num_graphs`.
21+
- `edata`: Edge features. A dictionary of arrays or named tuple of arrays. Default `nothing`.
22+
The size of the last dimension of each array must be given by `g.num_edges`. Default `nothing`.
23+
- `gdata`: Graph features. An array or named tuple of arrays whose last dimension has size `num_graphs`. Default `nothing`.
2224
- `num_nodes`: The number of nodes for each type. If not specified, inferred from `data`. Default `nothing`.
2325
2426
# Fields
2527
26-
- `graph`: A dictionary that maps `(source_type, edge_type, target_type)`` triples to (source, target) index vectors.
28+
- `graph`: A dictionary that maps (source_type, edge_type, target_type) triples to (source, target) index vectors.
2729
- `num_nodes`: The number of nodes for each type.
2830
- `num_edges`: The number of edges for each type.
2931
- `ndata`: Node features.
@@ -41,15 +43,15 @@ julia> nA, nB = 10, 20;
4143
4244
julia> num_nodes = Dict(:A => nA, :B => nB);
4345
44-
julia> edges1 = rand(1:nA, 20), rand(1:nB, 20)
46+
julia> edges1 = (rand(1:nA, 20), rand(1:nB, 20))
4547
([4, 8, 6, 3, 4, 7, 2, 7, 3, 2, 3, 4, 9, 4, 2, 9, 10, 1, 3, 9], [6, 4, 20, 8, 16, 7, 12, 16, 5, 4, 6, 20, 11, 19, 17, 9, 12, 2, 18, 12])
4648
47-
julia> edges2 = rand(1:nB, 30), rand(1:nA, 30)
49+
julia> edges2 = (rand(1:nB, 30), rand(1:nA, 30))
4850
([17, 5, 2, 4, 5, 3, 8, 7, 9, 7 … 19, 8, 20, 7, 16, 2, 9, 15, 8, 13], [1, 1, 3, 1, 1, 3, 2, 7, 4, 4 … 7, 10, 6, 3, 4, 9, 1, 5, 8, 5])
4951
50-
julia> eindex = ((:A, :rel1, :B) => edges1, (:B, :rel2, :A) => edges2);
52+
julia> data = ((:A, :rel1, :B) => edges1, (:B, :rel2, :A) => edges2);
5153
52-
julia> hg = GNNHeteroGraph(eindex; num_nodes)
54+
julia> hg = GNNHeteroGraph(data; num_nodes)
5355
GNNHeteroGraph:
5456
num_nodes: (:A => 10, :B => 20)
5557
num_edges: ((:A, :rel1, :B) => 20, (:B, :rel2, :A) => 30)
@@ -63,7 +65,7 @@ Dict{Tuple{Symbol, Symbol, Symbol}, Int64} with 2 entries:
6365
julia> ndata = Dict(:A => (x = rand(2, nA), y = rand(3, num_nodes[:A])),
6466
:B => rand(10, nB));
6567
66-
julia> hg = GNNHeteroGraph(eindex; num_nodes, ndata)
68+
julia> hg = GNNHeteroGraph(data; num_nodes, ndata)
6769
GNNHeteroGraph:
6870
num_nodes: (:A => 10, :B => 20)
6971
num_edges: ((:A, :rel1, :B) => 20, (:B, :rel2, :A) => 30)
@@ -96,9 +98,10 @@ end
9698
@functor GNNHeteroGraph
9799

98100
GNNHeteroGraph(data; kws...) = GNNHeteroGraph(Dict(data); kws...)
101+
GNNHeteroGraph(data::Pair...; kws...) = GNNHeteroGraph(Dict(data...); kws...)
99102

100103
function GNNHeteroGraph(data::Dict; kws...)
101-
all(k -> k isa EType, keys(data)) || throw(ArgumentError("Keys of data must be tuples of the form (source_type, edge_type, target_type)"))
104+
all(k -> k isa EType, keys(data)) || throw(ArgumentError("Keys of data must be tuples of the form `(source_type, edge_type, target_type)`"))
102105
return GNNHeteroGraph(Dict([k => v for (k, v) in pairs(data)]...); kws...)
103106
end
104107

src/GNNGraphs/query.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ Return `true` if there is an edge of type `edge_t` from node `i` to node `j` in
6262
6363
# Examples
6464
65-
```julia-repl
65+
```jldoctest
6666
julia> g = rand_bipartite_heterograph((2, 2), (4, 0), bidirected=false)
6767
GNNHeteroGraph:
6868
num_nodes: (:A => 2, :B => 2)

src/GNNGraphs/temporalsnapshotsgnngraph.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ Return a `TemporalSnapshotsGNNGraph` created starting from `tg` by adding the sn
7777
7878
# Examples
7979
80-
```julia-repl
80+
```jldoctest
8181
julia> using GraphNeuralNetworks
8282
8383
julia> snapshots = [rand_graph(10, 20) for i in 1:5];
@@ -137,7 +137,7 @@ Return a [`TemporalSnapshotsGNNGraph`](@ref) created starting from `tg` by remov
137137
138138
# Examples
139139
140-
```julia-repl
140+
```jldoctest
141141
julia> using GraphNeuralNetworks
142142
143143
julia> snapshots = [rand_graph(10,20), rand_graph(10,14), rand_graph(10,22)];

0 commit comments

Comments
 (0)