Skip to content

Commit c13af48

Browse files
add sample_neighbors (#93)
* add GNNGraph(num_nodes) * sample_neighbors * fix dense constructor
1 parent dbd64d5 commit c13af48

File tree

9 files changed

+269
-15
lines changed

9 files changed

+269
-15
lines changed

docs/src/api/gnngraph.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,3 +67,11 @@ Private = false
6767
```@docs
6868
Graphs.intersect
6969
```
70+
71+
## Sampling
72+
73+
```@autodocs
74+
Modules = [GraphNeuralNetworks.GNNGraphs]
75+
Pages = ["sampling.jl"]
76+
Private = false
77+
```

src/GNNGraphs/GNNGraphs.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,9 @@ include("generate.jl")
5757
export rand_graph,
5858
knn_graph
5959

60+
include("sampling.jl")
61+
export sample_neighbors
62+
6063
include("operators.jl")
6164
# Base.intersect
6265

src/GNNGraphs/gnngraph.jl

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -117,25 +117,25 @@ end
117117

118118
@functor GNNGraph
119119

120-
function GNNGraph(data;
120+
function GNNGraph(data::D;
121121
num_nodes = nothing,
122122
graph_indicator = nothing,
123123
graph_type = :coo,
124124
dir = :out,
125125
ndata = (;),
126126
edata = (;),
127127
gdata = (;),
128-
)
128+
) where D <: Union{COO_T, ADJMAT_T, ADJLIST_T}
129129

130130
@assert graph_type [:coo, :dense, :sparse] "Invalid graph_type $graph_type requested"
131131
@assert dir [:in, :out]
132132

133133
if graph_type == :coo
134134
graph, num_nodes, num_edges = to_coo(data; num_nodes, dir)
135135
elseif graph_type == :dense
136-
graph, num_nodes, num_edges = to_dense(data; dir)
136+
graph, num_nodes, num_edges = to_dense(data; num_nodes, dir)
137137
elseif graph_type == :sparse
138-
graph, num_nodes, num_edges = to_sparse(data; dir)
138+
graph, num_nodes, num_edges = to_sparse(data; num_nodes, dir)
139139
end
140140

141141
num_graphs = !isnothing(graph_indicator) ? maximum(graph_indicator) : 1
@@ -150,9 +150,9 @@ function GNNGraph(data;
150150
ndata, edata, gdata)
151151
end
152152

153-
function GNNGraph(n::T; graph_type=:coo, kws...) where {T<:Integer}
153+
function (::Type{<:GNNGraph})(num_nodes::T; kws...) where {T<:Integer}
154154
s, t = T[], T[]
155-
return GNNGraph(s, t; graph_type, num_nodes=n, kws...)
155+
return GNNGraph(s, t; num_nodes, kws...)
156156
end
157157

158158
# COO convenience constructors
@@ -182,10 +182,10 @@ function GNNGraph(g::GNNGraph; ndata=g.ndata, edata=g.edata, gdata=g.gdata, grap
182182
if graph_type == :coo
183183
graph, num_nodes, num_edges = to_coo(g.graph; g.num_nodes)
184184
elseif graph_type == :dense
185-
graph, num_nodes, num_edges = to_dense(g.graph)
185+
graph, num_nodes, num_edges = to_dense(g.graph; g.num_nodes)
186186
elseif graph_type == :sparse
187-
graph, num_nodes, num_edges = to_sparse(g.graph)
188-
end
187+
graph, num_nodes, num_edges = to_sparse(g.graph; g.num_nodes)
188+
end
189189
@assert num_nodes == g.num_nodes
190190
@assert num_edges == g.num_edges
191191
else

src/GNNGraphs/query.jl

Lines changed: 51 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,15 @@ Graphs.ne(g::GNNGraph) = g.num_edges
5151
Graphs.has_vertex(g::GNNGraph, i::Int) = 1 <= i <= g.num_nodes
5252
Graphs.vertices(g::GNNGraph) = 1:g.num_nodes
5353

54+
function Graphs.neighbors(g::GNNGraph, i; dir=:out)
55+
@assert dir (:in, :out)
56+
if dir == :out
57+
outneighbors(g, i)
58+
else
59+
inneighbors(g, i)
60+
end
61+
end
62+
5463
function Graphs.outneighbors(g::GNNGraph{<:COO_T}, i::Integer)
5564
s, t = edge_index(g)
5665
return t[s .== i]
@@ -76,6 +85,7 @@ Graphs.is_directed(::Type{<:GNNGraph}) = true
7685

7786
"""
7887
adjacency_list(g; dir=:out)
88+
adjacency_list(g, nodes; dir=:out)
7989
8090
Return the adjacency list representation (a vector of vectors)
8191
of the graph `g`.
@@ -84,13 +94,44 @@ Calling `a` the adjacency list, if `dir=:out` than
8494
`a[i]` will contain the neighbors of node `i` through
8595
outgoing edges. If `dir=:in`, it will contain neighbors from
8696
incoming edges instead.
97+
98+
If `nodes` is given, return the neighborhood of the nodes in `nodes` only.
8799
"""
88-
function adjacency_list(g::GNNGraph; dir=:out)
100+
function adjacency_list(g::GNNGraph, nodes; dir=:out, with_eid=false)
89101
@assert dir [:out, :in]
90-
fneighs = dir == :out ? outneighbors : inneighbors
91-
return [fneighs(g, i) for i in 1:g.num_nodes]
102+
s, t = edge_index(g)
103+
if dir == :in
104+
s, t = t, s
105+
end
106+
T = eltype(s)
107+
idict = 0
108+
dmap = Dict(n => (idict += 1) for n in nodes)
109+
adjlist = [T[] for _=1:length(dmap)]
110+
eidlist = [T[] for _=1:length(dmap)]
111+
for (eid, (i, j)) in enumerate(zip(s, t))
112+
inew = get(dmap, i, 0)
113+
inew == 0 && continue
114+
push!(adjlist[inew], j)
115+
push!(eidlist[inew], eid)
116+
end
117+
if with_eid
118+
return adjlist, eidlist
119+
else
120+
return adjlist
121+
end
92122
end
93123

124+
# function adjacency_list(g::GNNGraph, nodes; dir=:out)
125+
# @assert dir ∈ [:out, :in]
126+
# fneighs = dir == :out ? outneighbors : inneighbors
127+
# return [fneighs(g, i) for i in nodes]
128+
# end
129+
130+
131+
132+
adjacency_list(g::GNNGraph; dir=:out) = adjacency_list(g, 1:g.num_nodes; dir)
133+
134+
94135
function Graphs.adjacency_matrix(g::GNNGraph{<:COO_T}, T::DataType=eltype(g); dir=:out)
95136
if g.graph[1] isa CuVector
96137
# TODO revisit after https://github.com/JuliaGPU/CUDA.jl/pull/1152
@@ -140,7 +181,8 @@ Return a vector containing the degrees of the nodes in `g`.
140181
Alternatively, you can also pass a vector of weights to be used
141182
instead of the graph's own weights.
142183
"""
143-
function Graphs.degree(g::GNNGraph{<:COO_T}, T=nothing; dir=:out, edge_weight=true)
184+
function Graphs.degree(g::GNNGraph{<:COO_T}, T::TT=nothing; dir=:out, edge_weight=true) where
185+
TT<:Union{Nothing, Type{<:Number}}
144186
s, t = edge_index(g)
145187

146188
edge_weight = _get_edge_weight(g, edge_weight)
@@ -157,7 +199,11 @@ function Graphs.degree(g::GNNGraph{<:COO_T}, T=nothing; dir=:out, edge_weight=tr
157199
return degs
158200
end
159201

160-
function Graphs.degree(g::GNNGraph{<:ADJMAT_T}, T=nothing; dir=:out, edge_weight=true)
202+
# TODO:: Make efficient
203+
Graphs.degree(g::GNNGraph, i::Union{Int, AbstractVector}; dir=:out) = degree(g; dir)[i]
204+
205+
function Graphs.degree(g::GNNGraph{<:ADJMAT_T}, T::TT=nothing; dir=:out, edge_weight=true) where TT
206+
TT<:Union{Nothing, Type{<:Number}}
161207
# edge_weight=true or edge_weight=nothing act the same here
162208
@assert !(edge_weight isa AbstractArray) "passing the edge weights is not support by adjacency matrix representations"
163209
@assert dir (:in, :out, :both)

src/GNNGraphs/sampling.jl

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
"""
2+
sample_neighbors(g, nodes, K=-1; dir=:in, replace=false, dropnodes=false)
3+
4+
Sample neighboring edges of the given nodes and return the induced subgraph.
5+
For each node, a number of inbound (or outbound when `dir = :out`) edges will be randomly chosen.
6+
If `dropnodes=false`, the graph returned will then contain all the nodes in the original graph,
7+
but only the sampled edges.
8+
9+
The returned graph will contain an edge feature `EID` corresponding to the id of the edge
10+
in the original graph. If `dropnodes=true`, it will also contain a node feature `NID` with
11+
the node ids in the original graph.
12+
13+
# Arguments
14+
15+
- `g`. The graph.
16+
- `nodes`. A list of node IDs to sample neighbors from.
17+
- `K`. The maximum number of edges to be sampled for each node.
18+
If -1, all the neighboring edges will be selected.
19+
- `dir`. Determines whether to sample inbound (`:in`) or outbound (`:out`) edges (Default `:in`).
20+
- `replace`. If `true`, sample with replacement.
21+
- `dropnodes`. If `true`, the resulting subgraph will contain only the nodes involved in the sampled edges.
22+
23+
# Examples
24+
25+
```julia
26+
julia> g = rand_graph(20, 100)
27+
GNNGraph:
28+
num_nodes = 20
29+
num_edges = 100
30+
31+
julia> sample_neighbors(g, 2:3)
32+
GNNGraph:
33+
num_nodes = 20
34+
num_edges = 9
35+
edata:
36+
EID => (9,)
37+
38+
julia> sg = sample_neighbors(g, 2:3, dropnodes=true)
39+
GNNGraph:
40+
num_nodes = 10
41+
num_edges = 9
42+
ndata:
43+
NID => (10,)
44+
edata:
45+
EID => (9,)
46+
47+
julia> sg.ndata.NID
48+
10-element Vector{Int64}:
49+
2
50+
3
51+
17
52+
14
53+
18
54+
15
55+
16
56+
20
57+
7
58+
10
59+
60+
julia> sample_neighbors(g, 2:3, 5, replace=true)
61+
GNNGraph:
62+
num_nodes = 20
63+
num_edges = 10
64+
edata:
65+
EID => (10,)
66+
```
67+
"""
68+
function sample_neighbors(g::GNNGraph{<:COO_T}, nodes, K=-1;
69+
dir=:in, replace=false, dropnodes=false)
70+
@assert dir (:in, :out)
71+
_, eidlist = adjacency_list(g, nodes; dir, with_eid=true)
72+
for i in 1:length(eidlist)
73+
if replace
74+
k = K > 0 ? K : length(eidlist[i])
75+
else
76+
k = K > 0 ? min(length(eidlist[i]), K) : length(eidlist[i])
77+
end
78+
eidlist[i] = StatsBase.sample(eidlist[i], k; replace)
79+
end
80+
eids = reduce(vcat, eidlist)
81+
s, t = edge_index(g)
82+
w = get_edge_weight(g)
83+
s = s[eids]
84+
t = t[eids]
85+
w = isnothing(w) ? nothing : w[eids]
86+
87+
edata = getobs(g.edata, eids)
88+
edata = (edata..., EID = eids)
89+
90+
num_edges = length(eids)
91+
92+
if !dropnodes
93+
graph = (s, t, w)
94+
95+
gnew = GNNGraph(graph,
96+
g.num_nodes, num_edges, g.num_graphs,
97+
g.graph_indicator,
98+
g.ndata, edata, g.gdata)
99+
else
100+
nodes_other = dir == :in ? setdiff(s, nodes) : setdiff(t, nodes)
101+
nodes_all = [nodes; nodes_other]
102+
nodemap = Dict(n => i for (i, n) in enumerate(nodes_all))
103+
s = [nodemap[s] for s in s]
104+
t = [nodemap[t] for t in t]
105+
graph = (s, t, w)
106+
graph_indicator = g.graph_indicator !== nothing ? g.graph_indicator[nodes_all] : nothing
107+
num_nodes = length(nodes_all)
108+
ndata = getobs(g.ndata, nodes_all)
109+
ndata = (ndata..., NID = nodes_all)
110+
111+
gnew = GNNGraph(graph,
112+
num_nodes, num_edges, g.num_graphs,
113+
graph_indicator,
114+
ndata, edata, g.gdata)
115+
end
116+
return gnew
117+
end

src/GNNGraphs/transform.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,23 @@ function add_edges(g::GNNGraph{<:COO_T},
120120
g.ndata, edata, g.gdata)
121121
end
122122

123+
### TODO Cannot implement this since GNNGraph is immutable (cannot change num_edges)
124+
# function Graphs.add_edge!(g::GNNGraph{<:COO_T}, snew::T, tnew::T; edata=nothing) where T<:Union{Integer, AbstractVector}
125+
# s, t = edge_index(g)
126+
# @assert length(snew) == length(tnew)
127+
# # TODO remove this constraint
128+
# @assert get_edge_weight(g) === nothing
129+
130+
# edata = normalize_graphdata(edata, default_name=:e, n=length(snew))
131+
# edata = cat_features(g.edata, edata)
132+
133+
# s, t = edge_index(g)
134+
# append!(s, snew)
135+
# append!(t, tnew)
136+
# g.num_edges += length(snew)
137+
# return true
138+
# end
139+
123140

124141
"""
125142
add_nodes(g::GNNGraph, n; [ndata])

test/GNNGraphs/gnngraph.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,18 @@
3131
@test g.num_edges == 2
3232
end
3333

34+
@testset "Constructor: integer" begin
35+
g = GNNGraph(10, graph_type=GRAPH_T)
36+
@test g.num_nodes == 10
37+
@test g.num_edges == 0
38+
39+
g2 = rand_graph(10, 30, graph_type=GRAPH_T)
40+
G = typeof(g2)
41+
g = G(10)
42+
@test g.num_nodes == 10
43+
@test g.num_edges == 0
44+
end
45+
3446
@testset "symmetric graph" begin
3547
s = [1, 1, 2, 2, 3, 3, 4, 4]
3648
t = [2, 4, 1, 3, 2, 4, 1, 3]

test/GNNGraphs/sampling.jl

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
@testset "sampling.jl" begin
2+
if GRAPH_T == :coo
3+
@testset "sample_neighbors" begin
4+
# replace = false
5+
dir = :in
6+
nodes = 2:3
7+
g = rand_graph(10, 40, bidirected=false, graph_type=GRAPH_T)
8+
sg = sample_neighbors(g, nodes; dir)
9+
@test sg.num_nodes == 10
10+
@test sg.num_edges == sum(degree(g, i; dir) for i in nodes)
11+
@test size(sg.edata.EID) == (sg.num_edges,)
12+
@test length(union(sg.edata.EID)) == length(sg.edata.EID)
13+
adjlist = adjacency_list(g; dir)
14+
s, t = edge_index(sg)
15+
@test all(t .∈ Ref(nodes))
16+
for i in nodes
17+
@test sort(neighbors(sg, i; dir)) == sort(neighbors(g, i; dir))
18+
end
19+
20+
# replace = true
21+
dir = :out
22+
nodes = 2:3
23+
K = 2
24+
g = rand_graph(10, 40, bidirected=false, graph_type=GRAPH_T)
25+
sg = sample_neighbors(g, nodes, K; dir, replace=true)
26+
@test sg.num_nodes == 10
27+
@test sg.num_edges == sum(K for i in nodes)
28+
@test size(sg.edata.EID) == (sg.num_edges,)
29+
adjlist = adjacency_list(g; dir)
30+
s, t = edge_index(sg)
31+
@test all(s .∈ Ref(nodes))
32+
for i in nodes
33+
@test issubset(neighbors(sg, i; dir), adjlist[i])
34+
end
35+
36+
# dropnodes = true
37+
dir = :in
38+
nodes = 2:3
39+
g = rand_graph(10, 40, bidirected=false, graph_type=GRAPH_T)
40+
g = GNNGraph(g, ndata=(x1=rand(10),), edata=(e1=rand(40),))
41+
sg = sample_neighbors(g, nodes; dir, dropnodes=true)
42+
@test sg.num_edges == sum(degree(g, i; dir) for i in nodes)
43+
@test size(sg.edata.EID) == (sg.num_edges,)
44+
@test size(sg.ndata.NID) == (sg.num_nodes,)
45+
@test sg.edata.e1 == g.edata.e1[sg.edata.EID]
46+
@test sg.ndata.x1 == g.ndata.x1[sg.ndata.NID]
47+
@test length(union(sg.ndata.NID)) == length(sg.ndata.NID)
48+
end
49+
end
50+
end

0 commit comments

Comments
 (0)