Skip to content

Commit 6b8a7fb

Browse files
authored
feat: add NeighborLoader (#497)
1 parent 9d9e8d0 commit 6b8a7fb

File tree

8 files changed

+253
-2
lines changed

8 files changed

+253
-2
lines changed

GNNGraphs/src/sampling.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,8 @@ function Graphs.induced_subgraph(graph::GNNGraph, nodes::Vector{Int})
177177

178178
node_map = Dict(node => i for (i, node) in enumerate(nodes))
179179

180+
edge_list = [collect(t) for t in zip(edge_index(graph)[1],edge_index(graph)[2])]
181+
180182
# Collect edges to add
181183
source = Int[]
182184
target = Int[]
@@ -187,8 +189,7 @@ function Graphs.induced_subgraph(graph::GNNGraph, nodes::Vector{Int})
187189
if neighbor in keys(node_map)
188190
push!(target, node_map[node])
189191
push!(source, node_map[neighbor])
190-
191-
eindex = findfirst(x -> x == [neighbor, node], edge_index(graph))
192+
eindex = findfirst(x -> x == [neighbor, node], edge_list)
192193
push!(eindices, eindex)
193194
end
194195
end

GraphNeuralNetworks/Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
99
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
1010
GNNGraphs = "aed8fd31-079b-4b5a-b342-a13352159b8c"
1111
GNNlib = "a6a84749-d869-43f8-aacc-be26a1996e48"
12+
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
1213
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1314
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
1415
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
@@ -28,6 +29,7 @@ CUDA = "4, 5"
2829
ChainRulesCore = "1"
2930
Flux = "0.14"
3031
Functors = "0.4.1"
32+
Graphs = "1.12"
3133
GNNGraphs = "1.0"
3234
GNNlib = "0.2"
3335
LinearAlgebra = "1"

GraphNeuralNetworks/docs/make.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ makedocs(;
4646
"Message Passing" => "api/messagepassing.md",
4747
"Heterogeneous Graphs" => "api/heterograph.md",
4848
"Temporal Graphs" => "api/temporalgraph.md",
49+
"Samplers" => "api/samplers.md",
4950
"Utils" => "api/utils.md",
5051
],
5152
"Developer Notes" => "dev.md",
GraphNeuralNetworks/docs/src/api/samplers.md

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
```@meta
2+
CurrentModule = GraphNeuralNetworks
3+
```
4+
5+
# Samplers
6+
7+
8+
## Docs
9+
10+
```@autodocs
11+
Modules = [GraphNeuralNetworks]
12+
Pages = ["samplers.jl"]
13+
Private = false
14+
```

GraphNeuralNetworks/src/GraphNeuralNetworks.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ using NNlib: scatter, gather
1010
using ChainRulesCore
1111
using Reexport
1212
using MLUtils: zeros_like
13+
using Graphs: Graphs
1314

1415
using GNNGraphs: COO_T, ADJMAT_T, SPARSE_T,
1516
check_num_nodes, check_num_edges,
@@ -66,4 +67,7 @@ export GlobalPool,
6667

6768
include("deprecations.jl")
6869

70+
include("samplers.jl")
71+
export NeighborLoader
72+
6973
end
GraphNeuralNetworks/src/samplers.jl

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
"""
    NeighborLoader(graph; num_neighbors, num_layers, [input_nodes], [batch_size])

A data structure for sampling neighbors from a graph for training Graph Neural Networks (GNNs).
It supports multi-layer sampling of neighbors for a batch of input nodes, which is useful for
mini-batch training as originally introduced in the paper
["Inductive Representation Learning on Large Graphs"](https://arxiv.org/abs/1706.02216).

Iterating over a `NeighborLoader` yields one mini-batch graph per batch of input nodes: the
subgraph of `graph` induced by the batch nodes together with their sampled neighbors.

# Fields
- `graph::GNNGraph`: The input graph.
- `num_neighbors::Vector{Int}`: A vector specifying the number of neighbors to sample per node
  at each GNN layer.
- `input_nodes::Vector{Int}`: A vector containing the starting nodes for neighbor sampling.
  When omitted in the keyword constructor, all nodes `1:graph.num_nodes` are used.
- `num_layers::Int`: The number of layers for neighborhood expansion (how far to sample
  neighbors).
- `batch_size::Union{Int, Nothing}`: The size of the batch. If not specified, it defaults to
  the number of `input_nodes` (i.e. a single batch).
- `neighbors_cache::Dict{Int, Vector{Int}}`: Per-node cache of neighbor lists, filled lazily
  while sampling so that the graph is queried at most once per node.

The keyword constructor throws an `ArgumentError` if an explicit `batch_size` is not positive.

# Usage
```julia
using GraphNeuralNetworks

graph = GNNGraph([1, 2, 3, 4], [2, 3, 4, 5], ndata = rand(Float32, 5, 5))
loader = NeighborLoader(graph; num_neighbors = [10, 5], input_nodes = [1, 2, 3], num_layers = 2)

for (batch_counter, mini_batch_gnn) in enumerate(loader)
    println("Batch ", batch_counter, ": Nodes in mini-batch graph: ", mini_batch_gnn.num_nodes)
end
```
"""
struct NeighborLoader
    graph::GNNGraph                          # The input GNNGraph (topology + features)
    num_neighbors::Vector{Int}               # Number of neighbors to sample per node, per layer
    input_nodes::Vector{Int}                 # Seed nodes the sampling starts from
    num_layers::Int                          # Number of layers for neighborhood expansion
    batch_size::Union{Int, Nothing}          # Number of seed nodes per mini-batch
    neighbors_cache::Dict{Int, Vector{Int}}  # Cache neighbors to avoid recomputation
end

function NeighborLoader(graph::GNNGraph; num_neighbors::Vector{Int},
                        input_nodes::Union{Vector{Int}, Nothing} = nothing,
                        num_layers::Int,
                        batch_size::Union{Int, Nothing} = nothing)
    # A non-positive batch size would never advance the iteration state.
    if batch_size !== nothing && batch_size <= 0
        throw(ArgumentError("batch_size must be positive, got $batch_size"))
    end

    # Resolve the defaults first, so that the batch size is derived from the node list that
    # is actually used (and never from `nothing`).
    nodes = input_nodes === nothing ? collect(1:graph.num_nodes) : input_nodes
    resolved_batch_size = batch_size === nothing ? length(nodes) : batch_size

    return NeighborLoader(graph, num_neighbors, nodes, num_layers, resolved_batch_size,
                          Dict{Int, Vector{Int}}())
end
40+
41+
# Look up the neighbors of `node` (queried with `dir = :in`), answering from the loader's
# cache when the node has been seen before and memoizing the graph query otherwise.
function get_neighbors(loader::NeighborLoader, node::Int)
    cache = loader.neighbors_cache
    haskey(cache, node) && return cache[node]

    # Cache miss: ask the graph once and remember the answer for later calls.
    nbrs = Graphs.neighbors(loader.graph, node, dir = :in)
    cache[node] = nbrs
    return nbrs
end
51+
52+
# Sample neighbors of `node` for the given `layer`.
#
# Returns a vector with `min(loader.num_neighbors[layer], number of neighbors)` *distinct*
# neighbors of `node`, chosen uniformly at random without replacement, or an empty `Int[]`
# when the node has no neighbors. Sampling without replacement is what makes the `min` cap
# meaningful: a node whose neighbors all fit in the budget contributes all of them.
function sample_nbrs(loader::NeighborLoader, node::Int, layer::Int)
    neighbors = get_neighbors(loader, node)
    isempty(neighbors) && return Int[]

    total = length(neighbors)
    num_samples = min(loader.num_neighbors[layer], total)  # Limit to this layer's budget

    # Partial Fisher-Yates shuffle on a copy: the first `num_samples` slots end up holding a
    # uniform sample without replacement. The copy keeps the cached neighbor list intact.
    pool = copy(neighbors)
    for i in 1:num_samples
        j = rand(i:total)
        pool[i], pool[j] = pool[j], pool[i]
    end
    return pool[1:num_samples]
end
62+
63+
# Iterator protocol for NeighborLoader with lazy batch loading.

# Number of mini-batches produced by one full pass over the loader. Defining it lets
# `length(loader)` and `collect(loader)` work, since iterators are assumed to have a length
# by default.
function Base.length(loader::NeighborLoader)
    num_inputs = length(loader.input_nodes)
    num_inputs == 0 && return 0
    full_batch = something(loader.batch_size, num_inputs)
    full_batch > 0 || throw(ArgumentError("batch_size must be positive, got $full_batch"))
    return cld(num_inputs, full_batch)
end

# Produce the next mini-batch.
#
# `state` is the index (into `loader.input_nodes`) of the first seed node of the batch.
# Returns `nothing` once all input nodes have been consumed; otherwise returns the subgraph
# induced by the batch nodes plus their sampled multi-layer neighborhood, together with the
# state of the following batch. Throws an `ArgumentError` for a non-positive batch size,
# which would otherwise never advance the state.
function Base.iterate(loader::NeighborLoader, state=1)
    num_inputs = length(loader.input_nodes)
    state > num_inputs && return nothing  # All input nodes have been consumed

    # A `nothing` batch size means "all input nodes in a single batch".
    full_batch = something(loader.batch_size, num_inputs)
    full_batch > 0 || throw(ArgumentError("batch_size must be positive, got $full_batch"))

    # The last batch may be smaller when the input nodes do not divide evenly.
    batch_size = min(full_batch, num_inputs - state + 1)
    batch_nodes = loader.input_nodes[state:(state + batch_size - 1)]

    # Nodes of the mini-batch subgraph; always contains the seed nodes themselves.
    subgraph_nodes = Set(batch_nodes)

    for node in batch_nodes
        # Frontier of the expansion, starting with the seed node itself.
        frontier = Set([node])

        for layer in 1:loader.num_layers
            next_frontier = Set{Int}()
            for n in frontier
                # In-place union de-duplicates without reallocating the set each time.
                union!(next_frontier, sample_nbrs(loader, n, layer))
            end
            frontier = next_frontier
            isempty(frontier) && break  # Nothing left to expand from
            union!(subgraph_nodes, frontier)
        end
    end

    # Create the subgraph induced by the collected nodes.
    mini_batch_gnn = Graphs.induced_subgraph(loader.graph, collect(subgraph_nodes))

    # Continue iteration with the next batch.
    return mini_batch_gnn, state + batch_size
end

GraphNeuralNetworks/test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ tests = [
3030
"layers/temporalconv",
3131
"layers/pool",
3232
"examples/node_classification_cora",
33+
"samplers"
3334
]
3435

3536
!CUDA.functional() && @warn("CUDA unavailable, not testing GPU support")
GraphNeuralNetworks/test/samplers.jl

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
# Fixture shared by the tests below: a directed chain 1 -> 2 -> 3 -> 4 -> 5 wrapped in a
# GNNGraph, with a random 5x5 Float32 matrix attached as node data.
function create_test_graph()
    edges = [(1, 2), (2, 3), (3, 4), (4, 5)]
    source = first.(edges)   # Edge origins
    target = last.(edges)    # Edge destinations
    node_features = rand(Float32, 5, 5)

    return GNNGraph(source, target, ndata = node_features)
end
9+
10+
# Tests for the NeighborLoader structure and its iteration behavior.
# Neighbors are looked up along incoming edges (`dir = :in`), so in the chain fixture
# 1 -> 2 -> 3 -> 4 -> 5 node 1 has no neighbors and node k > 1 has the single neighbor k - 1.
@testset "NeighborLoader tests" begin

    # 1. Basic functionality: Check neighbor sampling and subgraph creation
    @testset "Basic functionality" begin
        g = create_test_graph()

        # Define NeighborLoader with 2 neighbors per layer, 2 layers, batch size 2
        loader = NeighborLoader(g; num_neighbors=[2, 2], input_nodes=[1, 2], num_layers=2, batch_size=2)

        mini_batch_gnn, next_state = iterate(loader)

        # Test if the mini-batch graph is not empty
        @test !isempty(mini_batch_gnn.graph)

        # Node 1 has no incoming edge and node 2 is only reached from node 1,
        # so the mini-batch consists of exactly the two input nodes.
        @test mini_batch_gnn.num_nodes == 2

        # Test if there are edges in the subgraph
        @test mini_batch_gnn.num_edges > 0
    end

    # 2. Edge case: Single node with no neighbors
    @testset "Single node with no neighbors" begin
        g = SimpleDiGraph(1) # A graph with a single node and no edges
        node_features = rand(Float32, 5, 1)
        graph = GNNGraph(g, ndata = node_features)

        loader = NeighborLoader(graph; num_neighbors=[2], input_nodes=[1], num_layers=1)

        mini_batch_gnn, next_state = iterate(loader)

        # Test if the mini-batch graph contains only one node
        @test size(mini_batch_gnn.x, 2) == 1
    end

    # 3. Edge case: Isolated nodes (a graph without any edges)
    @testset "Node with no outgoing edges" begin
        g = SimpleDiGraph(2) # Graph with 2 nodes, no edges
        node_features = rand(Float32, 5, 2)
        graph = GNNGraph(g, ndata = node_features)

        loader = NeighborLoader(graph; num_neighbors=[1], input_nodes=[1, 2], num_layers=1)

        mini_batch_gnn, next_state = iterate(loader)

        # Test if the mini-batch graph contains the input nodes only (as no neighbors can be sampled)
        @test size(mini_batch_gnn.x, 2) == 2 # Only two isolated nodes
    end

    # 4. Edge case: A directed cycle 1 -> 2 -> 3 -> 1
    @testset "Directed cycle graph" begin
        g = SimpleDiGraph(3)
        add_edge!(g, 1, 2)
        add_edge!(g, 2, 3)
        add_edge!(g, 3, 1)
        node_features = rand(Float32, 5, 3)
        graph = GNNGraph(g, ndata = node_features)

        loader = NeighborLoader(graph; num_neighbors=[2, 2], input_nodes=[1], num_layers=2)

        mini_batch_gnn, next_state = iterate(loader)

        # Two hops backwards along the cycle from node 1 reach nodes 3 and 2,
        # so every node ends up in the mini-batch.
        @test size(mini_batch_gnn.x, 2) == 3 # All nodes should be included
    end

    # 5. Edge case: More layers than there are neighbors to expand into
    @testset "More layers than available neighbors" begin
        g = SimpleDiGraph(3)
        add_edge!(g, 1, 2)
        add_edge!(g, 2, 3)
        node_features = rand(Float32, 5, 3)
        graph = GNNGraph(g, ndata = node_features)

        # Test with 3 layers starting from node 1, which has no incoming edges
        loader = NeighborLoader(graph; num_neighbors=[1, 1, 1], input_nodes=[1], num_layers=3)

        mini_batch_gnn, next_state = iterate(loader)

        # Nothing can be sampled at any layer, so only the input node is returned
        @test size(mini_batch_gnn.x, 2) == 1
    end

    # 6. Edge case: Large batch size greater than the number of input nodes
    @testset "Large batch size" begin
        g = create_test_graph()

        # Define NeighborLoader with a larger batch size than input nodes
        loader = NeighborLoader(g; num_neighbors=[2], input_nodes=[1, 2], num_layers=1, batch_size=10)

        mini_batch_gnn, next_state = iterate(loader)

        # Test if the mini-batch graph is not empty
        @test !isempty(mini_batch_gnn.graph)

        # The batch is truncated to the available input nodes: nodes [1, 2] are expected
        @test size(mini_batch_gnn.x, 2) == 2
    end

    # 7. Edge case: No neighbors sampled (num_neighbors = [0]) and 1 layer
    @testset "No neighbors sampled" begin
        g = create_test_graph()

        # Define NeighborLoader with 0 neighbors per layer, 1 layer, batch size 2
        loader = NeighborLoader(g; num_neighbors=[0], input_nodes=[1, 2], num_layers=1, batch_size=2)

        mini_batch_gnn, next_state = iterate(loader)

        # Test if the mini-batch graph contains only the input nodes
        @test size(mini_batch_gnn.x, 2) == 2 # No neighbors should be sampled, only nodes 1 and 2 should be in the graph
    end

    # 8. Iteration: the loader is exhausted after ceil(#input_nodes / batch_size) batches
    @testset "Iteration over multiple batches" begin
        g = create_test_graph()

        # 3 input nodes with batch size 2 -> one full batch and one partial batch
        loader = NeighborLoader(g; num_neighbors=[2], input_nodes=[1, 2, 3], num_layers=1, batch_size=2)

        num_batches = 0
        for mini_batch_gnn in loader
            num_batches += 1
            @test mini_batch_gnn.num_nodes >= 1
        end
        @test num_batches == 2
    end

end

0 commit comments

Comments
 (0)