Skip to content

Commit 4383a1c

Browse files
finish porting
1 parent b3d528b commit 4383a1c

17 files changed

+643
-612
lines changed

GNNGraphs/Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
1010
KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77"
1111
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1212
MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
13+
MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
1314
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
1415
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
1516
NearestNeighbors = "b8a86587-4115-5ab1-83bc-aa920d37bbce"
@@ -34,6 +35,7 @@ Graphs = "1.4"
3435
KrylovKit = "0.8"
3536
LinearAlgebra = "1"
3637
MLDataDevices = "1.0"
38+
MLDatasets = "0.7.18"
3739
MLUtils = "0.4"
3840
NNlib = "0.9"
3941
NearestNeighbors = "0.4"

GNNGraphs/docs/make.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,6 @@ makedocs(;
4545
"GNNGraph" => "api/gnngraph.md",
4646
"GNNHeteroGraph" => "api/heterograph.md",
4747
"TemporalSnapshotsGNNGraph" => "api/temporalgraph.md",
48-
"Samplers" => "api/samplers.md",
4948
"Datasets" => "api/datasets.md",
5049
],
5150
]

GNNGraphs/docs/src/api/gnngraph.md

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -77,12 +77,8 @@ Base.intersect
7777

7878
## Sampling
7979

80-
```@autodocs; canonical = true
81-
Modules = [GNNGraphs]
82-
Pages = ["src/sampling.jl"]
83-
Private = false
84-
```
85-
86-
```@docs; canonical = true
80+
```@docs
81+
NeighborLoader
82+
sample_neighbors
8783
Graphs.induced_subgraph(::GNNGraph, ::Vector{Int})
88-
```
84+
```

GNNGraphs/docs/src/api/samplers.md

Lines changed: 0 additions & 12 deletions
This file was deleted.

GNNGraphs/src/GNNGraphs.jl

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ using LinearAlgebra, Random, Statistics
1313
import MLUtils
1414
using MLUtils: getobs, numobs, ones_like, zeros_like, chunk, batch, rand_like
1515
using MLDataDevices: get_device, cpu_device, CPUDevice
16+
using Functors: @functor
1617

1718
include("chainrules.jl") # hacks for differentiability
1819

@@ -54,7 +55,7 @@ export adjacency_list,
5455
normalized_laplacian,
5556
scaled_laplacian,
5657
laplacian_lambda_max,
57-
# from Graphs
58+
# from Graphs.jl
5859
adjacency_matrix,
5960
degree,
6061
has_edge,
@@ -82,7 +83,7 @@ export add_nodes,
8283
perturb_edges,
8384
remove_nodes,
8485
ppr_diffusion,
85-
# from MLUtils
86+
# from MLUtils.jl
8687
batch,
8788
unbatch,
8889
# from SparseArrays
@@ -98,9 +99,6 @@ export rand_graph,
9899
rand_temporal_radius_graph,
99100
rand_temporal_hyperbolic_graph
100101

101-
include("sampling.jl")
102-
export sample_neighbors
103-
104102
include("operators.jl")
105103
# Base.intersect
106104

@@ -117,7 +115,8 @@ export mldataset2gnngraph
117115

118116
include("deprecations.jl")
119117

120-
include("samplers.jl")
121-
export NeighborLoader
118+
include("sampling.jl")
119+
export NeighborLoader, sample_neighbors,
120+
induced_subgraph # from Graphs.jl
122121

123122
end #module

GNNGraphs/src/samplers.jl

Lines changed: 0 additions & 105 deletions
This file was deleted.

GNNGraphs/src/sampling.jl

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,110 @@
"""
    NeighborLoader(graph; num_neighbors, num_layers, [input_nodes, batch_size])

A data structure for sampling neighbors from a graph for training Graph Neural Networks (GNNs).
It supports multi-layer sampling of neighbors for a batch of input nodes, useful for
mini-batch training, as originally introduced in the
["Inductive Representation Learning on Large Graphs"](https://arxiv.org/abs/1706.02216) paper.

Iterating over a `NeighborLoader` yields one mini-batch graph per batch of input nodes.

# Fields
- `graph::GNNGraph`: The input graph.
- `num_neighbors::Vector{Int}`: A vector specifying the number of neighbors to sample per node at each GNN layer.
- `input_nodes::Vector{Int}`: A vector containing the starting nodes for neighbor sampling.
  If not specified, all nodes `1:graph.num_nodes` are used.
- `num_layers::Int`: The number of layers for neighborhood expansion (how far to sample neighbors).
- `batch_size::Union{Int, Nothing}`: The size of the batch. If not specified, it defaults to the number of `input_nodes`.
- `neighbors_cache::Dict{Int, Vector{Int}}`: Per-node cache of neighbor lists, filled lazily while sampling.

# Examples

```julia
julia> loader = NeighborLoader(graph; num_neighbors=[10, 5], input_nodes=[1, 2, 3], num_layers=2)

julia> batch_counter = 0

julia> for mini_batch_gnn in loader
           batch_counter += 1
           println("Batch ", batch_counter, ": Nodes in mini-batch graph: ", nv(mini_batch_gnn))
       end
```
"""
struct NeighborLoader
    graph::GNNGraph                          # The input GNNGraph (graph + features from GraphNeuralNetworks.jl)
    num_neighbors::Vector{Int}               # Number of neighbors to sample per node, for each layer
    input_nodes::Vector{Int}                 # Set of input nodes (starting nodes for sampling)
    num_layers::Int                          # Number of layers for neighborhood expansion
    batch_size::Union{Int, Nothing}          # Optional batch size, defaults to the length of input_nodes if not given
    neighbors_cache::Dict{Int, Vector{Int}}  # Cache neighbors to avoid recomputation
end

# Keyword constructor: fills in the defaults for `input_nodes` and `batch_size` and starts
# with an empty neighbor cache.
function NeighborLoader(graph::GNNGraph; num_neighbors::Vector{Int},
                        input_nodes::Union{Vector{Int}, Nothing} = nothing,
                        num_layers::Int,
                        batch_size::Union{Int, Nothing} = nothing)
    # A non-positive batch size would make iteration never advance past the first batch.
    if batch_size !== nothing && batch_size <= 0
        throw(ArgumentError("batch_size must be positive, got $batch_size"))
    end
    # Resolve the node list first: the default batch size depends on the resolved list,
    # not on the raw (possibly `nothing`) keyword argument.
    nodes = input_nodes === nothing ? collect(1:graph.num_nodes) : input_nodes
    resolved_batch_size = batch_size === nothing ? length(nodes) : batch_size
    return NeighborLoader(graph, num_neighbors, nodes, num_layers, resolved_batch_size,
                          Dict{Int, Vector{Int}}())
end
42+
43+
# Return the in-neighbors of `node` in `loader.graph`, memoized in `loader.neighbors_cache`.
#
# On a cache miss the neighbor list is computed with `Graphs.neighbors(...; dir = :in)` and
# stored; later calls for the same node return the stored vector. `get!` performs the
# lookup and the insertion in one step, and always hands back the value as stored in the
# cache, so hits and misses return the same type.
function get_neighbors(loader::NeighborLoader, node::Int)
    return get!(loader.neighbors_cache, node) do
        Graphs.neighbors(loader.graph, node, dir = :in)
    end
end
53+
54+
# Sample neighbors of `node` for the GNN layer with index `layer`.
#
# Returns an empty `Int[]` when the node has no neighbors. Otherwise returns
# `min(loader.num_neighbors[layer], number_of_neighbors)` entries picked at random from
# the (cached) neighbor list.
# NOTE(review): `rand(collection, k)` draws each entry independently, so the result may
# contain the same neighbor more than once -- confirm that this is intended.
function sample_nbrs(loader::NeighborLoader, node::Int, layer::Int)
    candidates = get_neighbors(loader, node)
    isempty(candidates) && return Int[]
    # Never ask for more draws than there are neighbors available
    k = min(loader.num_neighbors[layer], length(candidates))
    return rand(candidates, k)
end
64+
65+
# Iterator protocol for NeighborLoader with lazy batch loading.
#
# `state` is the 1-based position in `loader.input_nodes` of the first node of the next
# batch. Each call samples a neighborhood around the next batch of input nodes and returns
# `(mini_batch_graph, next_state)`, or `nothing` once all input nodes have been consumed.
function Base.iterate(loader::NeighborLoader, state = 1)
    num_inputs = length(loader.input_nodes)
    # End of iteration: every input node has already been served
    state > num_inputs && return nothing

    # The field type allows `nothing`; fall back to the documented default (one batch
    # containing all input nodes) instead of failing inside `min`.
    max_batch = something(loader.batch_size, num_inputs)
    # The last batch may be smaller when there are not enough nodes left to fill it
    batch_size = min(max_batch, num_inputs - state + 1)
    # Each mini-batch uses a different slice of the input nodes
    batch_nodes = loader.input_nodes[state:(state + batch_size - 1)]

    # Set for tracking the subgraph nodes (seeded with the batch nodes themselves)
    subgraph_nodes = Set(batch_nodes)

    for node in batch_nodes
        # Current layer of nodes, starting with the node itself
        frontier = Set([node])

        # For each GNN layer, sample the neighborhood of the current frontier
        for layer in 1:loader.num_layers
            next_frontier = Set{Int}()
            for n in frontier
                # In-place union: avoids copying the whole set for every visited node
                union!(next_frontier, sample_nbrs(loader, n, layer))
            end
            frontier = next_frontier
            # Expand the subgraph with the newly sampled neighbors
            union!(subgraph_nodes, frontier)
        end
    end

    subgraph_node_list = collect(subgraph_nodes)

    if isempty(subgraph_node_list)
        return GNNGraph(), state + batch_size
    end

    # Build the mini-batch graph induced by the collected nodes
    mini_batch_gnn = Graphs.induced_subgraph(loader.graph, subgraph_node_list)

    # Continue iteration with the next batch
    return mini_batch_gnn, state + batch_size
end
106+
107+
1108
"""
2109
sample_neighbors(g, nodes, K=-1; dir=:in, replace=false, dropnodes=false)
3110

GNNGraphs/src/temporalsnapshotsgnngraph.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,9 @@ struct TemporalSnapshotsGNNGraph{G<:GNNGraph, D<:DataStore}
6161
tgdata::D
6262
end
6363

64+
# Only `snapshots` and `tgdata` are registered as functor children, so that
# `num_nodes` and `num_edges` are not moved to the GPU.
65+
@functor TemporalSnapshotsGNNGraph (snapshots, tgdata)
66+
6467
function TemporalSnapshotsGNNGraph(snapshots)
6568
snapshots = collect(snapshots)
6669
return TemporalSnapshotsGNNGraph(

0 commit comments

Comments
 (0)