|
| 1 | +""" |
| 2 | + rand_heterograph([rng,] n, m; bidirected=false, kws...) |
| 3 | +
|
| 4 | +Construct an [`GNNHeteroGraph`](@ref) with random edges and with number of nodes and edges |
| 5 | +specified by `n` and `m` respectively. `n` and `m` can be any iterable of pairs |
| 6 | +specifing node/edge types and their numbers. |
| 7 | +
|
| 8 | +Pass a random number generator as a first argument to make the generation reproducible. |
| 9 | +
|
| 10 | +Setting `bidirected=true` will generate a bidirected graph, i.e. each edge will have a reverse edge. |
| 11 | +Therefore, for each edge type `(:A, :rel, :B)` a corresponding reverse edge type `(:B, :rel, :A)` |
| 12 | +will be generated. |
| 13 | +
|
| 14 | +Additional keyword arguments will be passed to the [`GNNHeteroGraph`](@ref) constructor. |
| 15 | +
|
| 16 | +# Examples |
| 17 | +
|
| 18 | +```jldoctest |
| 19 | +julia> g = rand_heterograph((:user => 10, :movie => 20), |
| 20 | + (:user, :rate, :movie) => 30) |
| 21 | +GNNHeteroGraph: |
| 22 | + num_nodes: (:user => 10, :movie => 20) |
| 23 | + num_edges: ((:user, :rate, :movie) => 30,) |
| 24 | +``` |
| 25 | +""" |
| 26 | +function rand_heterograph end |
| 27 | + |
| 28 | +# for generic iterators of pairs |
| 29 | +rand_heterograph(n, m; kws...) = rand_heterograph(Dict(n), Dict(m); kws...) |
| 30 | +rand_heterograph(rng::AbstractRNG, n, m; kws...) = rand_heterograph(rng, Dict(n), Dict(m); kws...) |
| 31 | + |
| 32 | +function rand_heterograph(n::NDict, m::EDict; seed=-1, kws...) |
| 33 | + if seed != -1 |
| 34 | + Base.depwarn("Keyword argument `seed` is deprecated, pass an rng as first argument instead.", :rand_heterograph) |
| 35 | + rng = MersenneTwister(seed) |
| 36 | + else |
| 37 | + rng = Random.default_rng() |
| 38 | + end |
| 39 | + return rand_heterograph(rng, n, m; kws...) |
| 40 | +end |
| 41 | + |
| 42 | +function rand_heterograph(rng::AbstractRNG, n::NDict, m::EDict; bidirected::Bool = false, kws...) |
| 43 | + if bidirected |
| 44 | + return _rand_bidirected_heterograph(rng, n, m; kws...) |
| 45 | + end |
| 46 | + graphs = Dict(k => _rand_edges(rng, (n[k[1]], n[k[3]]), m[k]) for k in keys(m)) |
| 47 | + return GNNHeteroGraph(graphs; num_nodes = n, kws...) |
| 48 | +end |
| 49 | + |
| 50 | +function _rand_bidirected_heterograph(rng::AbstractRNG, n::NDict, m::EDict; kws...) |
| 51 | + for k in keys(m) |
| 52 | + if reverse(k) ∈ keys(m) |
| 53 | + @assert m[k] == m[reverse(k)] "Number of edges must be the same in reverse edge types for bidirected graphs." |
| 54 | + else |
| 55 | + m[reverse(k)] = m[k] |
| 56 | + end |
| 57 | + end |
| 58 | + graphs = Dict{EType, Tuple{Vector{Int}, Vector{Int}, Nothing}}() |
| 59 | + for k in keys(m) |
| 60 | + reverse(k) ∈ keys(graphs) && continue |
| 61 | + s, t, val = _rand_edges(rng, (n[k[1]], n[k[3]]), m[k]) |
| 62 | + graphs[k] = s, t, val |
| 63 | + graphs[reverse(k)] = t, s, val |
| 64 | + end |
| 65 | + return GNNHeteroGraph(graphs; num_nodes = n, kws...) |
| 66 | +end |
| 67 | + |
| 68 | + |
| 69 | +""" |
| 70 | + rand_bipartite_heterograph([rng,] |
| 71 | + (n1, n2), (m12, m21); |
| 72 | + bidirected = true, |
| 73 | + node_t = (:A, :B), |
| 74 | + edge_t = :to, |
| 75 | + kws...) |
| 76 | +
|
| 77 | +Construct an [`GNNHeteroGraph`](@ref) with random edges representing a bipartite graph. |
| 78 | +The graph will have two types of nodes, and edges will only connect nodes of different types. |
| 79 | +
|
| 80 | +The first argument is a tuple `(n1, n2)` specifying the number of nodes of each type. |
| 81 | +The second argument is a tuple `(m12, m21)` specifying the number of edges connecting nodes of type `1` to nodes of type `2` |
| 82 | +and vice versa. |
| 83 | +
|
| 84 | +The type of nodes and edges can be specified with the `node_t` and `edge_t` keyword arguments, |
| 85 | +which default to `(:A, :B)` and `:to` respectively. |
| 86 | +
|
| 87 | +If `bidirected=true` (default), the reverse edge of each edge will be present. In this case |
| 88 | +`m12 == m21` is required. |
| 89 | +
|
| 90 | +A random number generator can be passed as the first argument to make the generation reproducible. |
| 91 | +
|
| 92 | +Additional keyword arguments will be passed to the [`GNNHeteroGraph`](@ref) constructor. |
| 93 | +
|
| 94 | +See [`rand_heterograph`](@ref) for a more general version. |
| 95 | +
|
| 96 | +# Examples |
| 97 | +
|
| 98 | +```julia-repl |
| 99 | +julia> g = rand_bipartite_heterograph((10, 15), 20) |
| 100 | +GNNHeteroGraph: |
| 101 | + num_nodes: (:A => 10, :B => 15) |
| 102 | + num_edges: ((:A, :to, :B) => 20, (:B, :to, :A) => 20) |
| 103 | +
|
| 104 | +julia> g = rand_bipartite_heterograph((10, 15), (20, 0), node_t=(:user, :item), edge_t=:-, bidirected=false) |
| 105 | +GNNHeteroGraph: |
| 106 | + num_nodes: Dict(:item => 15, :user => 10) |
| 107 | + num_edges: Dict((:item, :-, :user) => 0, (:user, :-, :item) => 20) |
| 108 | +``` |
| 109 | +""" |
| 110 | +rand_bipartite_heterograph(n, m; kws...) = rand_bipartite_heterograph(Random.default_rng(), n, m; kws...) |
| 111 | + |
| 112 | +function rand_bipartite_heterograph(rng::AbstractRNG, (n1, n2)::NTuple{2,Int}, m; bidirected=true, |
| 113 | + node_t = (:A, :B), edge_t::Symbol = :to, kws...) |
| 114 | + if m isa Integer |
| 115 | + m12 = m21 = m |
| 116 | + else |
| 117 | + m12, m21 = m |
| 118 | + end |
| 119 | + |
| 120 | + return rand_heterograph(rng, Dict(node_t[1] => n1, node_t[2] => n2), |
| 121 | + Dict((node_t[1], edge_t, node_t[2]) => m12, (node_t[2], edge_t, node_t[1]) => m21); |
| 122 | + bidirected, kws...) |
| 123 | +end |
| 124 | + |
0 commit comments