Skip to content

Commit 47a17ce

Browse files
add seed to rand_graph
1 parent da7b1e4 commit 47a17ce

File tree

2 files changed

+11
-4
lines changed

2 files changed

+11
-4
lines changed

src/GNNGraphs/generate.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
"""
2-
rand_graph(n, m; bidirected=true, kws...)
2+
rand_graph(n, m; bidirected=true, seed=-1, kws...)
33
44
Generate a random (Erdós-Renyi) `GNNGraph` with `n` nodes
55
and `m` edges.
@@ -8,6 +8,8 @@ If `bidirected=true` the reverse edge of each edge will be present.
88
If `bidirected=false` instead, `m` unrelated edges are generated.
99
In any case, the output graph will contain no self-loops or multi-edges.
1010
11+
Use a `seed > 0` for reproducibility.
12+
1113
Additional keyword arguments will be passed to the [`GNNGraph`](@ref) constructor.
1214
1315
# Usage
@@ -43,10 +45,10 @@ julia> edge_index(g)
4345
4446
```
4547
"""
46-
function rand_graph(n::Integer, m::Integer; bidirected=true, kws...)
48+
function rand_graph(n::Integer, m::Integer; bidirected=true, seed=-1, kws...)
4749
if bidirected
4850
@assert iseven(m) "Need even number of edges for bidirected graphs, given m=$m."
4951
end
5052
m2 = bidirected ? m÷2 : m
51-
return GNNGraph(Graphs.erdos_renyi(n, m2, is_directed=!bidirected); kws...)
53+
return GNNGraph(Graphs.erdos_renyi(n, m2; is_directed=!bidirected, seed); kws...)
5254
end

test/GNNGraphs/generate.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
m2 = m ÷ 2
55
x = rand(3, n)
66
e = rand(4, m2)
7+
78
g = rand_graph(n, m, ndata=x, edata=e, graph_type=GRAPH_T)
89
@test g.num_nodes == n
910
@test g.num_edges == m
@@ -15,8 +16,12 @@
1516
@test g.edata.e[:,1:m2] == e
1617
@test g.edata.e[:,m2+1:end] == e
1718
end
18-
g = rand_graph(n, m, bidirected=false, graph_type=GRAPH_T)
19+
20+
g = rand_graph(n, m, bidirected=false, seed=17, graph_type=GRAPH_T)
1921
@test g.num_nodes == n
2022
@test g.num_edges == m
23+
24+
g2 = rand_graph(n, m, bidirected=false, seed=17, graph_type=GRAPH_T)
25+
@test edge_index(g2) == edge_index(g)
2126
end
2227
end

0 commit comments

Comments
 (0)