Skip to content

Commit 2b2ffad

Browse files
committed
Test
1 parent af0455a commit 2b2ffad

File tree

9 files changed

+493
-2
lines changed

9 files changed

+493
-2
lines changed

GNNGraphs/docs/make.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ mathengine = MathJax3(Dict(:loader => Dict("load" => ["[tex]/require", "[tex]/ma
1919

2020
makedocs(;
2121
modules = [GNNGraphs],
22-
doctest = true, # TODO enable doctest
22+
doctest = false, # TODO enable doctest
2323
format = Documenter.HTML(; mathengine,
2424
prettyurls = get(ENV, "CI", nothing) == "true",
2525
assets = [],

GNNGraphs/docs/src/api/gnngraph.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ Pages = ["query.jl"]
3535
Private = false
3636
```
3737

38+
3839
```@docs
3940
Graphs.neighbors(::GNNGraph, ::Integer)
4041
```

GNNGraphs/docs/src/api/heterograph.md

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,3 +20,26 @@ Private = false
2020
Graphs.has_edge(::GNNHeteroGraph, ::Tuple{Symbol, Symbol, Symbol}, ::Integer, ::Integer)
2121
```
2222

23+
## Query
24+
25+
```@autodocs
26+
Modules = [GNNGraphs]
27+
Pages = ["gnnheterograph/query.jl"]
28+
Private = false
29+
```
30+
31+
## Transform
32+
33+
```@autodocs
34+
Modules = [GNNGraphs]
35+
Pages = ["gnnheterograph/transform.jl"]
36+
Private = false
37+
```
38+
39+
## Generate
40+
41+
```@autodocs
42+
Modules = [GNNGraphs]
43+
Pages = ["gnnheterograph/generate.jl"]
44+
Private = false
45+
```

GNNGraphs/src/GNNGraphs.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ export GNNGraph,
3030
edge_features,
3131
graph_features
3232

33-
include("gnnheterograph.jl")
33+
include("gnnheterograph/gnnheterograph.jl")
3434
export GNNHeteroGraph,
3535
num_edge_types,
3636
num_node_types,
@@ -44,6 +44,7 @@ export TemporalSnapshotsGNNGraph,
4444
# remove_snapshot!
4545

4646
include("query.jl")
47+
include("gnnheterograph/query.jl")
4748
export adjacency_list,
4849
edge_index,
4950
get_edge_weight,
@@ -65,6 +66,7 @@ export adjacency_list,
6566
khop_adj
6667

6768
include("transform.jl")
69+
include("gnnheterograph/transform.jl")
6870
export add_nodes,
6971
add_edges,
7072
add_self_loops,
@@ -88,6 +90,7 @@ export add_nodes,
8890
blockdiag
8991

9092
include("generate.jl")
93+
include("gnnheterograph/generate.jl")
9194
export rand_graph,
9295
rand_heterograph,
9396
rand_bipartite_heterograph,
@@ -104,6 +107,7 @@ include("operators.jl")
104107

105108
include("convert.jl")
106109
include("utils.jl")
110+
include("gnnheterograph/utils.jl")
107111
export sort_edge_index, color_refinement
108112

109113
include("gatherscatter.jl")
Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
"""
2+
rand_heterograph([rng,] n, m; bidirected=false, kws...)
3+
4+
Construct an [`GNNHeteroGraph`](@ref) with random edges and with number of nodes and edges
5+
specified by `n` and `m` respectively. `n` and `m` can be any iterable of pairs
6+
specifing node/edge types and their numbers.
7+
8+
Pass a random number generator as a first argument to make the generation reproducible.
9+
10+
Setting `bidirected=true` will generate a bidirected graph, i.e. each edge will have a reverse edge.
11+
Therefore, for each edge type `(:A, :rel, :B)` a corresponding reverse edge type `(:B, :rel, :A)`
12+
will be generated.
13+
14+
Additional keyword arguments will be passed to the [`GNNHeteroGraph`](@ref) constructor.
15+
16+
# Examples
17+
18+
```jldoctest
19+
julia> g = rand_heterograph((:user => 10, :movie => 20),
20+
(:user, :rate, :movie) => 30)
21+
GNNHeteroGraph:
22+
num_nodes: (:user => 10, :movie => 20)
23+
num_edges: ((:user, :rate, :movie) => 30,)
24+
```
25+
"""
26+
function rand_heterograph end
27+
28+
# for generic iterators of pairs
29+
rand_heterograph(n, m; kws...) = rand_heterograph(Dict(n), Dict(m); kws...)
30+
rand_heterograph(rng::AbstractRNG, n, m; kws...) = rand_heterograph(rng, Dict(n), Dict(m); kws...)
31+
32+
function rand_heterograph(n::NDict, m::EDict; seed=-1, kws...)
33+
if seed != -1
34+
Base.depwarn("Keyword argument `seed` is deprecated, pass an rng as first argument instead.", :rand_heterograph)
35+
rng = MersenneTwister(seed)
36+
else
37+
rng = Random.default_rng()
38+
end
39+
return rand_heterograph(rng, n, m; kws...)
40+
end
41+
42+
function rand_heterograph(rng::AbstractRNG, n::NDict, m::EDict; bidirected::Bool = false, kws...)
43+
if bidirected
44+
return _rand_bidirected_heterograph(rng, n, m; kws...)
45+
end
46+
graphs = Dict(k => _rand_edges(rng, (n[k[1]], n[k[3]]), m[k]) for k in keys(m))
47+
return GNNHeteroGraph(graphs; num_nodes = n, kws...)
48+
end
49+
50+
function _rand_bidirected_heterograph(rng::AbstractRNG, n::NDict, m::EDict; kws...)
51+
for k in keys(m)
52+
if reverse(k) keys(m)
53+
@assert m[k] == m[reverse(k)] "Number of edges must be the same in reverse edge types for bidirected graphs."
54+
else
55+
m[reverse(k)] = m[k]
56+
end
57+
end
58+
graphs = Dict{EType, Tuple{Vector{Int}, Vector{Int}, Nothing}}()
59+
for k in keys(m)
60+
reverse(k) keys(graphs) && continue
61+
s, t, val = _rand_edges(rng, (n[k[1]], n[k[3]]), m[k])
62+
graphs[k] = s, t, val
63+
graphs[reverse(k)] = t, s, val
64+
end
65+
return GNNHeteroGraph(graphs; num_nodes = n, kws...)
66+
end
67+
68+
69+
"""
70+
rand_bipartite_heterograph([rng,]
71+
(n1, n2), (m12, m21);
72+
bidirected = true,
73+
node_t = (:A, :B),
74+
edge_t = :to,
75+
kws...)
76+
77+
Construct an [`GNNHeteroGraph`](@ref) with random edges representing a bipartite graph.
78+
The graph will have two types of nodes, and edges will only connect nodes of different types.
79+
80+
The first argument is a tuple `(n1, n2)` specifying the number of nodes of each type.
81+
The second argument is a tuple `(m12, m21)` specifying the number of edges connecting nodes of type `1` to nodes of type `2`
82+
and vice versa.
83+
84+
The type of nodes and edges can be specified with the `node_t` and `edge_t` keyword arguments,
85+
which default to `(:A, :B)` and `:to` respectively.
86+
87+
If `bidirected=true` (default), the reverse edge of each edge will be present. In this case
88+
`m12 == m21` is required.
89+
90+
A random number generator can be passed as the first argument to make the generation reproducible.
91+
92+
Additional keyword arguments will be passed to the [`GNNHeteroGraph`](@ref) constructor.
93+
94+
See [`rand_heterograph`](@ref) for a more general version.
95+
96+
# Examples
97+
98+
```julia-repl
99+
julia> g = rand_bipartite_heterograph((10, 15), 20)
100+
GNNHeteroGraph:
101+
num_nodes: (:A => 10, :B => 15)
102+
num_edges: ((:A, :to, :B) => 20, (:B, :to, :A) => 20)
103+
104+
julia> g = rand_bipartite_heterograph((10, 15), (20, 0), node_t=(:user, :item), edge_t=:-, bidirected=false)
105+
GNNHeteroGraph:
106+
num_nodes: Dict(:item => 15, :user => 10)
107+
num_edges: Dict((:item, :-, :user) => 0, (:user, :-, :item) => 20)
108+
```
109+
"""
110+
rand_bipartite_heterograph(n, m; kws...) = rand_bipartite_heterograph(Random.default_rng(), n, m; kws...)
111+
112+
function rand_bipartite_heterograph(rng::AbstractRNG, (n1, n2)::NTuple{2,Int}, m; bidirected=true,
113+
node_t = (:A, :B), edge_t::Symbol = :to, kws...)
114+
if m isa Integer
115+
m12 = m21 = m
116+
else
117+
m12, m21 = m
118+
end
119+
120+
return rand_heterograph(rng, Dict(node_t[1] => n1, node_t[2] => n2),
121+
Dict((node_t[1], edge_t, node_t[2]) => m12, (node_t[2], edge_t, node_t[1]) => m21);
122+
bidirected, kws...)
123+
end
124+
File renamed without changes.
Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
"""
2+
edge_index(g::GNNHeteroGraph, [edge_t])
3+
4+
Return a tuple containing two vectors, respectively storing the source and target nodes
5+
for each edges in `g` of type `edge_t = (src_t, rel_t, trg_t)`.
6+
7+
If `edge_t` is not provided, it will error if `g` has more than one edge type.
8+
"""
9+
edge_index(g::GNNHeteroGraph{<:COO_T}, edge_t::EType) = g.graph[edge_t][1:2]
10+
edge_index(g::GNNHeteroGraph{<:COO_T}) = only(g.graph)[2][1:2]
11+
12+
get_edge_weight(g::GNNHeteroGraph{<:COO_T}, edge_t::EType) = g.graph[edge_t][3]
13+
14+
"""
15+
has_edge(g::GNNHeteroGraph, edge_t, i, j)
16+
17+
Return `true` if there is an edge of type `edge_t` from node `i` to node `j` in `g`.
18+
19+
# Examples
20+
21+
```jldoctest
22+
julia> g = rand_bipartite_heterograph((2, 2), (4, 0), bidirected=false)
23+
GNNHeteroGraph:
24+
num_nodes: (:A => 2, :B => 2)
25+
num_edges: ((:A, :to, :B) => 4, (:B, :to, :A) => 0)
26+
27+
julia> has_edge(g, (:A,:to,:B), 1, 1)
28+
true
29+
30+
julia> has_edge(g, (:B,:to,:A), 1, 1)
31+
false
32+
```
33+
"""
34+
function Graphs.has_edge(g::GNNHeteroGraph, edge_t::EType, i::Integer, j::Integer)
35+
s, t = edge_index(g, edge_t)
36+
return any((s .== i) .& (t .== j))
37+
end
38+
39+
40+
"""
41+
degree(g::GNNHeteroGraph, edge_type::EType; dir = :in)
42+
43+
Return a vector containing the degrees of the nodes in `g` GNNHeteroGraph
44+
given `edge_type`.
45+
46+
# Arguments
47+
48+
- `g`: A graph.
49+
- `edge_type`: A tuple of symbols `(source_t, edge_t, target_t)` representing the edge type.
50+
- `T`: Element type of the returned vector. If `nothing`, is
51+
chosen based on the graph type. Default `nothing`.
52+
- `dir`: For `dir = :out` the degree of a node is counted based on the outgoing edges.
53+
For `dir = :in`, the ingoing edges are used. If `dir = :both` we have the sum of the two.
54+
Default `dir = :out`.
55+
56+
"""
57+
function Graphs.degree(g::GNNHeteroGraph, edge::EType,
58+
T::TT = nothing; dir = :out) where {
59+
TT <: Union{Nothing, Type{<:Number}}}
60+
61+
s, t = edge_index(g, edge)
62+
63+
T = isnothing(T) ? eltype(s) : T
64+
65+
n_type = dir == :in ? g.ntypes[2] : g.ntypes[1]
66+
67+
return _degree((s, t), T, dir, nothing, g.num_nodes[n_type])
68+
end
69+
70+
"""
71+
graph_indicator(g::GNNHeteroGraph, [node_t])
72+
73+
Return a Dict of vectors containing the graph membership
74+
(an integer from `1` to `g.num_graphs`) of each node in the graph for each node type.
75+
If `node_t` is provided, return the graph membership of each node of type `node_t` instead.
76+
77+
See also [`batch`](@ref).
78+
"""
79+
function graph_indicator(g::GNNHeteroGraph)
80+
return g.graph_indicator
81+
end
82+
83+
function graph_indicator(g::GNNHeteroGraph, node_t::Symbol)
84+
@assert node_t g.ntypes
85+
if isnothing(g.graph_indicator)
86+
gi = ones_like(edge_index(g, first(g.etypes))[1], Int, g.num_nodes[node_t])
87+
else
88+
gi = g.graph_indicator[node_t]
89+
end
90+
return gi
91+
end

0 commit comments

Comments
 (0)