Skip to content

Commit cb01c4b

Browse files
rand_bipartite
1 parent 87c062b commit cb01c4b

File tree

6 files changed

+68
-34
lines changed

6 files changed

+68
-34
lines changed

GNNGraphs/src/generate.jl

Lines changed: 62 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,9 @@ function rand_graph(n::Integer, m::Integer; seed=-1, kws...)
4848
return rand_graph(rng, n, m; kws...)
4949
end
5050

51-
function rand_graph(rng::AbstractRNG, n::Integer, m::Integer; bidirected = true, edge_weight = nothing, kws...)
51+
function rand_graph(rng::AbstractRNG, n::Integer, m::Integer;
52+
bidirected::Bool = true,
53+
edge_weight::Union{AbstractVector, Nothing} = nothing, kws...)
5254
if bidirected
5355
@assert iseven(m) lazy"Need even number of edges for bidirected graphs, given m=$m."
5456
s, t, _ = _rand_edges(rng, n, m ÷ 2; directed=false, self_loops=false)
@@ -63,13 +65,13 @@ function rand_graph(rng::AbstractRNG, n::Integer, m::Integer; bidirected = true,
6365
end
6466

6567
"""
66-
rand_heterograph(n, m; seed=-1, bidirected=false, kws...)
68+
rand_heterograph([rng,] n, m; bidirected=false, kws...)
6769
68-
Construct an [`GNNHeteroGraph`](@ref) with number of nodes and edges
70+
Construct an [`GNNHeteroGraph`](@ref) with random edges and with number of nodes and edges
6971
specified by `n` and `m` respectively. `n` and `m` can be any iterable of pairs
7072
specifing node/edge types and their numbers.
7173
72-
Use a `seed > 0` for reproducibility.
74+
Pass a random number generator as a first argument to make the generation reproducible.
7375
7476
Setting `bidirected=true` will generate a bidirected graph, i.e. each edge will have a reverse edge.
7577
Therefore, for each edge type `(:A, :rel, :B)` a corresponding reverse edge type `(:B, :rel, :A)`
@@ -92,16 +94,25 @@ function rand_heterograph end
9294
# for generic iterators of pairs
9395
rand_heterograph(n, m; kws...) = rand_heterograph(Dict(n), Dict(m); kws...)
9496

95-
function rand_heterograph(n::NDict, m::EDict; bidirected = false, seed = -1, kws...)
96-
rng = seed > 0 ? MersenneTwister(seed) : Random.GLOBAL_RNG
97+
function rand_heterograph(n::NDict, m::EDict; seed=-1, kws...)
98+
if seed != -1
99+
Base.depwarn("Keyword argument `seed` is deprecated, pass an rng as first argument instead.", :rand_heterograph)
100+
rng = MersenneTwister(seed)
101+
else
102+
rng = Random.default_rng()
103+
end
104+
return rand_heterograph(rng, n, m; kws...)
105+
end
106+
107+
function rand_heterograph(rng::AbstractRNG, n::NDict, m::EDict; bidirected::Bool = false, kws...)
97108
if bidirected
98109
return _rand_bidirected_heterograph(rng, n, m; kws...)
99110
end
100111
graphs = Dict(k => _rand_edges(rng, (n[k[1]], n[k[3]]), m[k]) for k in keys(m))
101112
return GNNHeteroGraph(graphs; num_nodes = n, kws...)
102113
end
103114

104-
function _rand_bidirected_heterograph(rng, n::NDict, m::EDict; kws...)
115+
function _rand_bidirected_heterograph(rng::AbstractRNG, n::NDict, m::EDict; kws...)
105116
for k in keys(m)
106117
if reverse(k) keys(m)
107118
@assert m[k] == m[reverse(k)] "Number of edges must be the same in reverse edge types for bidirected graphs."
@@ -121,35 +132,58 @@ end
121132

122133

123134
"""
124-
rand_bipartite_heterograph(n1, n2, m; [bidirected, seed, node_t, edge_t, kws...])
125-
rand_bipartite_heterograph((n1, n2), m; ...)
126-
rand_bipartite_heterograph((n1, n2), (m1, m2); ...)
135+
rand_bipartite_heterograph([rng,]
136+
(n1, n2), (m12, m21);
137+
bidirected = true,
138+
node_t = (:A, :B),
139+
edge_t = :to,
140+
kws...)
127141
128-
Construct an [`GNNHeteroGraph`](@ref) with number of nodes and edges
129-
specified by `n1`, `n2` and `m1` and `m2` respectively.
142+
Construct an [`GNNHeteroGraph`](@ref) with random edges representing a bipartite graph.
143+
The graph will have two types of nodes, and edges will only connect nodes of different types.
130144
131-
See [`rand_heterograph`](@ref) for a more general version.
145+
The first argument is a tuple `(n1, n2)` specifying the number of nodes of each type.
146+
The second argument is a tuple `(m12, m21)` specifying the number of edges connecting nodes of type `1` to nodes of type `2`
147+
and vice versa.
132148
133-
# Keyword arguments
149+
The type of nodes and edges can be specified with the `node_t` and `edge_t` keyword arguments,
150+
which default to `(:A, :B)` and `:to` respectively.
134151
135-
- `bidirected`: whether to generate a bidirected graph. Default is `true`.
136-
- `seed`: random seed. Default is `-1` (no seed).
137-
- `node_t`: node types. If `bipartite=true`, this should be a tuple of two node types, otherwise it should be a single node type.
138-
- `edge_t`: edge types. If `bipartite=true`, this should be a tuple of two edge types, otherwise it should be a single edge type.
139-
"""
140-
function rand_bipartite_heterograph end
152+
If `bidirected=true` (default), the reverse edge of each edge will be present. In this case
153+
`m12 == m21` is required.
141154
142-
rand_bipartite_heterograph(n1::Int, n2::Int, m::Int; kws...) = rand_bipartite_heterograph((n1, n2), (m, m); kws...)
155+
A random number generator can be passed as the first argument to make the generation reproducible.
143156
144-
rand_bipartite_heterograph((n1, n2)::NTuple{2,Int}, m::Int; kws...) = rand_bipartite_heterograph((n1, n2), (m, m); kws...)
157+
Additional keyword arguments will be passed to the [`GNNHeteroGraph`](@ref) constructor.
145158
146-
function rand_bipartite_heterograph((n1, n2)::NTuple{2,Int}, (m1, m2)::NTuple{2,Int}; bidirected=true,
147-
node_t = (:A, :B), edge_t = :to, kws...)
148-
if edge_t isa Symbol
149-
edge_t = (edge_t, edge_t)
159+
See [`rand_heterograph`](@ref) for a more general version.
160+
161+
# Examples
162+
163+
```julia-repl
164+
julia> g = rand_bipartite_heterograph((10, 15), 20)
165+
GNNHeteroGraph:
166+
num_nodes: (:A => 10, :B => 15)
167+
num_edges: ((:A, :to, :B) => 20, (:B, :to, :A) => 20)
168+
169+
julia> g = rand_bipartite_heterograph((10, 15), (20, 0), node_t=(:user, :item), edge_t=:-, bidirected=false)
170+
GNNHeteroGraph:
171+
num_nodes: Dict(:item => 15, :user => 10)
172+
num_edges: Dict((:item, :-, :user) => 0, (:user, :-, :item) => 20)
173+
```
174+
"""
175+
rand_bipartite_heterograph(n, m; kws...) = rand_bipartite_heterograph(Random.default_rng(), n, m; kws...)
176+
177+
function rand_bipartite_heterograph(rng::AbstractRNG, (n1, n2)::NTuple{2,Int}, m; bidirected=true,
178+
node_t = (:A, :B), edge_t::Symbol = :to, kws...)
179+
if m isa Integer
180+
m12 = m21 = m
181+
else
182+
m12, m21 = m
150183
end
151-
return rand_heterograph(Dict(node_t[1] => n1, node_t[2] => n2),
152-
Dict((node_t[1], edge_t[1], node_t[2]) => m1, (node_t[2], edge_t[2], node_t[1]) => m2);
184+
185+
return rand_heterograph(rng, Dict(node_t[1] => n1, node_t[2] => n2),
186+
Dict((node_t[1], edge_t, node_t[2]) => m12, (node_t[2], edge_t, node_t[1]) => m21);
153187
bidirected, kws...)
154188
end
155189

GNNGraphs/test/generate.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ end
8181
end
8282

8383
@testset "rand_bipartite_heterograph" begin
84-
g = rand_bipartite_heterograph(10, 15, 20)
84+
g = rand_bipartite_heterograph((10, 15), (20, 20))
8585
@test g.num_nodes == Dict(:A => 10, :B => 15)
8686
@test g.num_edges == Dict((:A, :to, :B) => 20, (:B, :to, :A) => 20)
8787
sA, tB = edge_index(g, (:A, :to, :B))

GNNGraphs/test/gnnheterograph.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ end
123123

124124
@testset "get/set node features" begin
125125
d, n = 3, 5
126-
g = rand_bipartite_heterograph(n, 2*n, 15)
126+
g = rand_bipartite_heterograph((n, 2*n), 15)
127127
g[:A].x = rand(Float32, d, n)
128128
g[:B].y = rand(Float32, d, 2*n)
129129

@@ -133,7 +133,7 @@ end
133133

134134
@testset "add_edges" begin
135135
d, n = 3, 5
136-
g = rand_bipartite_heterograph(n, 2 * n, 15)
136+
g = rand_bipartite_heterograph((n, 2 * n), 15)
137137
s, t = [1, 2, 3], [3, 2, 1]
138138
## Keep the same ntypes - construct with args
139139
g1 = add_edges(g, (:A, :rel1, :B), s, t)

GNNLux/test/layers/basic_tests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
@testitem "layers/basic" setup=[SharedTestSetup] begin
22
rng = StableRNG(17)
3-
g = rand_graph(10, 40, seed=17)
3+
g = rand_graph(rng, 10, 40)
44
x = randn(rng, Float32, 3, 10)
55

66
@testset "GNNLayer" begin

GNNLux/test/layers/conv_tests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
@testitem "layers/conv" setup=[SharedTestSetup] begin
22
rng = StableRNG(1234)
3-
g = rand_graph(10, 40, seed=1234)
3+
g = rand_graph(rng, 10, 40)
44
in_dims = 3
55
out_dims = 5
66
x = randn(rng, Float32, in_dims, 10)

test/layers/heteroconv.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
@testset "HeteroGraphConv" begin
22
d, n = 3, 5
3-
g = rand_bipartite_heterograph(n, 2*n, 15)
3+
g = rand_bipartite_heterograph((n, 2*n), 15)
44
hg = rand_bipartite_heterograph((2,3), 6)
55

66
model = HeteroGraphConv([(:A,:to,:B) => GraphConv(d => d),

0 commit comments

Comments
 (0)