Skip to content

Commit 6b58b75

Browse files
rng instead of seed for rand_graph (#482)
* rng instead of seed for rand_graph * add tests * fix tests * rand_bipartite * more * relu -> tanh in tests
1 parent 3ce025b commit 6b58b75

File tree

16 files changed

+256
-114
lines changed

16 files changed

+256
-114
lines changed

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,5 @@ Manifest.toml
99
.vscode
1010
LocalPreferences.toml
1111
.DS_Store
12-
docs/src/democards/gridtheme.css
12+
docs/src/democards/gridtheme.css
13+
test.jl

GNNGraphs/src/abstracttypes.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11

2-
const COO_T = Tuple{T, T, V} where {T <: AbstractVector{<:Integer}, V}
2+
const COO_T = Tuple{T, T, V} where {T <: AbstractVector{<:Integer}, V <: Union{Nothing, AbstractVector}}
33
const ADJLIST_T = AbstractVector{T} where {T <: AbstractVector{<:Integer}}
44
const ADJMAT_T = AbstractMatrix
55
const SPARSE_T = AbstractSparseMatrix # subset of ADJMAT_T

GNNGraphs/src/convert.jl

Lines changed: 17 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -4,27 +4,24 @@ function to_coo(data::EDict; num_nodes = nothing, kws...)
44
graph = EDict{COO_T}()
55
_num_nodes = NDict{Int}()
66
num_edges = EDict{Int}()
7-
if !isempty(data)
8-
for k in keys(data)
9-
d = data[k]
10-
@assert d isa Tuple
11-
if length(d) == 2
12-
d = (d..., nothing)
13-
end
14-
if num_nodes !== nothing
15-
n1 = get(num_nodes, k[1], nothing)
16-
n2 = get(num_nodes, k[3], nothing)
17-
else
18-
n1 = nothing
19-
n2 = nothing
20-
end
21-
g, nnodes, nedges = to_coo(d; hetero = true, num_nodes = (n1, n2), kws...)
22-
graph[k] = g
23-
num_edges[k] = nedges
24-
_num_nodes[k[1]] = max(get(_num_nodes, k[1], 0), nnodes[1])
25-
_num_nodes[k[3]] = max(get(_num_nodes, k[3], 0), nnodes[2])
7+
for k in keys(data)
8+
d = data[k]
9+
@assert d isa Tuple
10+
if length(d) == 2
11+
d = (d..., nothing)
2612
end
27-
graph = Dict([k => v for (k, v) in pairs(graph)]...) # try to restrict the key/value types
13+
if num_nodes !== nothing
14+
n1 = get(num_nodes, k[1], nothing)
15+
n2 = get(num_nodes, k[3], nothing)
16+
else
17+
n1 = nothing
18+
n2 = nothing
19+
end
20+
g, nnodes, nedges = to_coo(d; hetero = true, num_nodes = (n1, n2), kws...)
21+
graph[k] = g
22+
num_edges[k] = nedges
23+
_num_nodes[k[1]] = max(get(_num_nodes, k[1], 0), nnodes[1])
24+
_num_nodes[k[3]] = max(get(_num_nodes, k[3], 0), nnodes[2])
2825
end
2926
return graph, _num_nodes, num_edges
3027
end

GNNGraphs/src/generate.jl

Lines changed: 84 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
"""
2-
rand_graph(n, m; bidirected=true, seed=-1, edge_weight = nothing, kws...)
2+
rand_graph([rng,] n, m; bidirected=true, edge_weight = nothing, kws...)
33
44
Generate a random (Erdős-Rényi) `GNNGraph` with `n` nodes and `m` edges.
55
@@ -10,7 +10,7 @@ In any case, the output graph will contain no self-loops or multi-edges.
1010
A vector can be passed as `edge_weight`. Its length has to be equal to `m`
1111
in the directed case, and `m÷2` in the bidirected one.
1212
13-
Use a `seed > 0` for reproducibility.
13+
Pass a random number generator as the first argument to make the generation reproducible.
1414
1515
Additional keyword arguments will be passed to the [`GNNGraph`](@ref) constructor.
1616
@@ -36,25 +36,42 @@ GNNGraph:
3636
# Each edge has a reverse
3737
julia> edge_index(g)
3838
([1, 3, 3, 4], [3, 4, 1, 3])
39-
4039
```
4140
"""
42-
function rand_graph(n::Integer, m::Integer; bidirected = true, seed = -1, edge_weight = nothing, kws...)
41+
function rand_graph(n::Integer, m::Integer; seed=-1, kws...)
42+
if seed != -1
43+
Base.depwarn("Keyword argument `seed` is deprecated, pass an rng as first argument instead.", :rand_graph)
44+
rng = MersenneTwister(seed)
45+
else
46+
rng = Random.default_rng()
47+
end
48+
return rand_graph(rng, n, m; kws...)
49+
end
50+
51+
function rand_graph(rng::AbstractRNG, n::Integer, m::Integer;
52+
bidirected::Bool = true,
53+
edge_weight::Union{AbstractVector, Nothing} = nothing, kws...)
4354
if bidirected
44-
@assert iseven(m) "Need even number of edges for bidirected graphs, given m=$m."
55+
@assert iseven(m) lazy"Need even number of edges for bidirected graphs, given m=$m."
56+
s, t, _ = _rand_edges(rng, n, m ÷ 2; directed=false, self_loops=false)
57+
s, t = vcat(s, t), vcat(t, s)
58+
if edge_weight !== nothing
59+
edge_weight = vcat(edge_weight, edge_weight)
60+
end
61+
else
62+
s, t, _ = _rand_edges(rng, n, m; directed=true, self_loops=false)
4563
end
46-
m2 = bidirected ? m ÷ 2 : m
47-
return GNNGraph(Graphs.erdos_renyi(n, m2; is_directed = !bidirected, seed); edge_weight, kws...)
64+
return GNNGraph((s, t, edge_weight); num_nodes=n, kws...)
4865
end
4966

5067
"""
51-
rand_heterograph(n, m; seed=-1, bidirected=false, kws...)
68+
rand_heterograph([rng,] n, m; bidirected=false, kws...)
5269
53-
Construct an [`GNNHeteroGraph`](@ref) with number of nodes and edges
70+
Construct a [`GNNHeteroGraph`](@ref) with random edges and with the number of nodes and edges
5471
specified by `n` and `m` respectively. `n` and `m` can be any iterable of pairs
5572
specifying node/edge types and their numbers.
5673
57-
Use a `seed > 0` for reproducibility.
74+
Pass a random number generator as a first argument to make the generation reproducible.
5875
5976
Setting `bidirected=true` will generate a bidirected graph, i.e. each edge will have a reverse edge.
6077
Therefore, for each edge type `(:A, :rel, :B)` a corresponding reverse edge type `(:B, :rel, :A)`
@@ -76,17 +93,27 @@ function rand_heterograph end
7693

7794
# for generic iterators of pairs
7895
rand_heterograph(n, m; kws...) = rand_heterograph(Dict(n), Dict(m); kws...)
96+
rand_heterograph(rng::AbstractRNG, n, m; kws...) = rand_heterograph(rng, Dict(n), Dict(m); kws...)
7997

80-
function rand_heterograph(n::NDict, m::EDict; bidirected = false, seed = -1, kws...)
81-
rng = seed > 0 ? MersenneTwister(seed) : Random.GLOBAL_RNG
98+
function rand_heterograph(n::NDict, m::EDict; seed=-1, kws...)
99+
if seed != -1
100+
Base.depwarn("Keyword argument `seed` is deprecated, pass an rng as first argument instead.", :rand_heterograph)
101+
rng = MersenneTwister(seed)
102+
else
103+
rng = Random.default_rng()
104+
end
105+
return rand_heterograph(rng, n, m; kws...)
106+
end
107+
108+
function rand_heterograph(rng::AbstractRNG, n::NDict, m::EDict; bidirected::Bool = false, kws...)
82109
if bidirected
83110
return _rand_bidirected_heterograph(rng, n, m; kws...)
84111
end
85112
graphs = Dict(k => _rand_edges(rng, (n[k[1]], n[k[3]]), m[k]) for k in keys(m))
86113
return GNNHeteroGraph(graphs; num_nodes = n, kws...)
87114
end
88115

89-
function _rand_bidirected_heterograph(rng, n::NDict, m::EDict; kws...)
116+
function _rand_bidirected_heterograph(rng::AbstractRNG, n::NDict, m::EDict; kws...)
90117
for k in keys(m)
91118
if reverse(k) ∈ keys(m)
92119
@assert m[k] == m[reverse(k)] "Number of edges must be the same in reverse edge types for bidirected graphs."
@@ -104,43 +131,60 @@ function _rand_bidirected_heterograph(rng, n::NDict, m::EDict; kws...)
104131
return GNNHeteroGraph(graphs; num_nodes = n, kws...)
105132
end
106133

107-
function _rand_edges(rng, (n1, n2), m)
108-
idx = StatsBase.sample(rng, 1:(n1 * n2), m, replace = false)
109-
s, t = edge_decoding(idx, n1, n2)
110-
val = nothing
111-
return s, t, val
112-
end
113134

114135
"""
115-
rand_bipartite_heterograph(n1, n2, m; [bidirected, seed, node_t, edge_t, kws...])
116-
rand_bipartite_heterograph((n1, n2), m; ...)
117-
rand_bipartite_heterograph((n1, n2), (m1, m2); ...)
136+
rand_bipartite_heterograph([rng,]
137+
(n1, n2), (m12, m21);
138+
bidirected = true,
139+
node_t = (:A, :B),
140+
edge_t = :to,
141+
kws...)
118142
119-
Construct an [`GNNHeteroGraph`](@ref) with number of nodes and edges
120-
specified by `n1`, `n2` and `m1` and `m2` respectively.
143+
Construct a [`GNNHeteroGraph`](@ref) with random edges representing a bipartite graph.
144+
The graph will have two types of nodes, and edges will only connect nodes of different types.
121145
122-
See [`rand_heterograph`](@ref) for a more general version.
146+
The first argument is a tuple `(n1, n2)` specifying the number of nodes of each type.
147+
The second argument is a tuple `(m12, m21)` specifying the number of edges connecting nodes of type `1` to nodes of type `2`
148+
and vice versa.
123149
124-
# Keyword arguments
150+
The type of nodes and edges can be specified with the `node_t` and `edge_t` keyword arguments,
151+
which default to `(:A, :B)` and `:to` respectively.
125152
126-
- `bidirected`: whether to generate a bidirected graph. Default is `true`.
127-
- `seed`: random seed. Default is `-1` (no seed).
128-
- `node_t`: node types. If `bipartite=true`, this should be a tuple of two node types, otherwise it should be a single node type.
129-
- `edge_t`: edge types. If `bipartite=true`, this should be a tuple of two edge types, otherwise it should be a single edge type.
130-
"""
131-
function rand_bipartite_heterograph end
153+
If `bidirected=true` (default), the reverse edge of each edge will be present. In this case
154+
`m12 == m21` is required.
155+
156+
A random number generator can be passed as the first argument to make the generation reproducible.
157+
158+
Additional keyword arguments will be passed to the [`GNNHeteroGraph`](@ref) constructor.
159+
160+
See [`rand_heterograph`](@ref) for a more general version.
161+
162+
# Examples
132163
133-
rand_bipartite_heterograph(n1::Int, n2::Int, m::Int; kws...) = rand_bipartite_heterograph((n1, n2), (m, m); kws...)
164+
```julia-repl
165+
julia> g = rand_bipartite_heterograph((10, 15), 20)
166+
GNNHeteroGraph:
167+
num_nodes: (:A => 10, :B => 15)
168+
num_edges: ((:A, :to, :B) => 20, (:B, :to, :A) => 20)
134169
135-
rand_bipartite_heterograph((n1, n2)::NTuple{2,Int}, m::Int; kws...) = rand_bipartite_heterograph((n1, n2), (m, m); kws...)
170+
julia> g = rand_bipartite_heterograph((10, 15), (20, 0), node_t=(:user, :item), edge_t=:-, bidirected=false)
171+
GNNHeteroGraph:
172+
num_nodes: Dict(:item => 15, :user => 10)
173+
num_edges: Dict((:item, :-, :user) => 0, (:user, :-, :item) => 20)
174+
```
175+
"""
176+
rand_bipartite_heterograph(n, m; kws...) = rand_bipartite_heterograph(Random.default_rng(), n, m; kws...)
136177

137-
function rand_bipartite_heterograph((n1, n2)::NTuple{2,Int}, (m1, m2)::NTuple{2,Int}; bidirected=true,
138-
node_t = (:A, :B), edge_t = :to, kws...)
139-
if edge_t isa Symbol
140-
edge_t = (edge_t, edge_t)
178+
function rand_bipartite_heterograph(rng::AbstractRNG, (n1, n2)::NTuple{2,Int}, m; bidirected=true,
179+
node_t = (:A, :B), edge_t::Symbol = :to, kws...)
180+
if m isa Integer
181+
m12 = m21 = m
182+
else
183+
m12, m21 = m
141184
end
142-
return rand_heterograph(Dict(node_t[1] => n1, node_t[2] => n2),
143-
Dict((node_t[1], edge_t[1], node_t[2]) => m1, (node_t[2], edge_t[2], node_t[1]) => m2);
185+
186+
return rand_heterograph(rng, Dict(node_t[1] => n1, node_t[2] => n2),
187+
Dict((node_t[1], edge_t, node_t[2]) => m12, (node_t[2], edge_t, node_t[1]) => m21);
144188
bidirected, kws...)
145189
end
146190

GNNGraphs/src/transform.jl

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,8 @@ then all new self loops will have no weight.
5757
If `edge_t` is not passed as an argument, a self-loop is added to each node for every edge type in the graph whose source and destination node types are the same.
5858
This iterates over all edge types present in the graph, applying the self-loop addition logic to each applicable edge type.
5959
"""
60-
function add_self_loops(g::GNNHeteroGraph{Tuple{T, T, V}}, edge_t::EType) where {T <: AbstractVector{<:Integer}, V}
60+
function add_self_loops(g::GNNHeteroGraph{<:COO_T}, edge_t::EType)
61+
6162
function get_edge_weight_nullable(g::GNNHeteroGraph{<:COO_T}, edge_t::EType)
6263
get(g.graph, edge_t, (nothing, nothing, nothing))[3]
6364
end
@@ -69,13 +70,17 @@ function add_self_loops(g::GNNHeteroGraph{Tuple{T, T, V}}, edge_t::EType) where
6970
n = get(g.num_nodes, src_t, 0)
7071

7172
if haskey(g.graph, edge_t)
72-
x = g.graph[edge_t]
73-
s, t = x[1:2]
73+
s, t = g.graph[edge_t][1:2]
7474
nodes = convert(typeof(s), [1:n;])
7575
s = [s; nodes]
7676
t = [t; nodes]
7777
else
78-
nodes = convert(T, [1:n;])
78+
if !isempty(g.graph)
79+
T = typeof(first(values(g.graph))[1])
80+
nodes = convert(T, [1:n;])
81+
else
82+
nodes = [1:n;]
83+
end
7984
s = nodes
8085
t = nodes
8186
end
@@ -518,7 +523,6 @@ end
518523
Return a new graph obtained from `g` by adding random edges, based on a specified `perturb_ratio`.
519524
The `perturb_ratio` determines the fraction of new edges to add relative to the current number of edges in the graph.
520525
These new edges are added without creating self-loops.
521-
Optionally, a random `seed` can be provided to ensure reproducible perturbations.
522526
523527
The function returns a new `GNNGraph` instance that shares some of the underlying data with `g` but includes the additional edges.
524528
The nodes for the new edges are selected randomly, and no edge data (`edata`) or weights (`w`) are assigned to these new edges.

0 commit comments

Comments
 (0)