diff --git a/GNNGraphs/src/sampling.jl b/GNNGraphs/src/sampling.jl index 7e723182a..6e38730f0 100644 --- a/GNNGraphs/src/sampling.jl +++ b/GNNGraphs/src/sampling.jl @@ -177,6 +177,8 @@ function Graphs.induced_subgraph(graph::GNNGraph, nodes::Vector{Int}) node_map = Dict(node => i for (i, node) in enumerate(nodes)) + edge_list = [collect(t) for t in zip(edge_index(graph)[1],edge_index(graph)[2])] + # Collect edges to add source = Int[] target = Int[] @@ -187,8 +189,7 @@ function Graphs.induced_subgraph(graph::GNNGraph, nodes::Vector{Int}) if neighbor in keys(node_map) push!(target, node_map[node]) push!(source, node_map[neighbor]) - - eindex = findfirst(x -> x == [neighbor, node], edge_index(graph)) + eindex = findfirst(x -> x == [neighbor, node], edge_list) push!(eindices, eindex) end end diff --git a/GraphNeuralNetworks/Project.toml b/GraphNeuralNetworks/Project.toml index 89979ff69..5c6389479 100644 --- a/GraphNeuralNetworks/Project.toml +++ b/GraphNeuralNetworks/Project.toml @@ -9,6 +9,7 @@ Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" GNNGraphs = "aed8fd31-079b-4b5a-b342-a13352159b8c" GNNlib = "a6a84749-d869-43f8-aacc-be26a1996e48" +Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" @@ -28,6 +29,7 @@ CUDA = "4, 5" ChainRulesCore = "1" Flux = "0.14" Functors = "0.4.1" +Graphs = "1.12" GNNGraphs = "1.0" GNNlib = "0.2" LinearAlgebra = "1" diff --git a/GraphNeuralNetworks/docs/make.jl b/GraphNeuralNetworks/docs/make.jl index 869aa94f1..8cb762166 100644 --- a/GraphNeuralNetworks/docs/make.jl +++ b/GraphNeuralNetworks/docs/make.jl @@ -46,6 +46,7 @@ makedocs(; "Message Passing" => "api/messagepassing.md", "Heterogeneous Graphs" => "api/heterograph.md", "Temporal Graphs" => "api/temporalgraph.md", + "Samplers" => "api/samplers.md", "Utils" => "api/utils.md", ], "Developer Notes" => "dev.md", diff --git a/GraphNeuralNetworks/docs/src/api/samplers.md b/GraphNeuralNetworks/docs/src/api/samplers.md new file mode 100644 index 000000000..f4285562c --- /dev/null +++ b/GraphNeuralNetworks/docs/src/api/samplers.md @@ -0,0 +1,14 @@ +```@meta +CurrentModule = GraphNeuralNetworks +``` + +# Samplers + + +## Docs + +```@autodocs +Modules = [GraphNeuralNetworks] +Pages = ["samplers.jl"] +Private = false +``` diff --git a/GraphNeuralNetworks/src/GraphNeuralNetworks.jl b/GraphNeuralNetworks/src/GraphNeuralNetworks.jl index c9a227b8d..9ac46e8b1 100644 --- a/GraphNeuralNetworks/src/GraphNeuralNetworks.jl +++ b/GraphNeuralNetworks/src/GraphNeuralNetworks.jl @@ -10,6 +10,7 @@ using NNlib: scatter, gather using ChainRulesCore using Reexport using MLUtils: zeros_like +using Graphs: Graphs using GNNGraphs: COO_T, ADJMAT_T, SPARSE_T, check_num_nodes, check_num_edges, @@ -66,4 +67,7 @@ export GlobalPool, include("deprecations.jl") +include("samplers.jl") +export NeighborLoader + end diff --git a/GraphNeuralNetworks/src/samplers.jl b/GraphNeuralNetworks/src/samplers.jl new file mode 100644 index 000000000..5c06c1681 --- /dev/null +++ b/GraphNeuralNetworks/src/samplers.jl @@ -0,0 +1,103 @@ +""" + NeighborLoader(graph; num_neighbors, input_nodes, num_layers, [batch_size]) + +A data structure for sampling neighbors from a graph for training Graph Neural Networks (GNNs). +It supports multi-layer sampling of neighbors for a batch of input nodes, useful for mini-batch training +originally introduced in "Inductive Representation Learning on Large Graphs" paper. +[see https://arxiv.org/abs/1706.02216] + +# Fields +- `graph::GNNGraph`: The input graph. +- `num_neighbors::Vector{Int}`: A vector specifying the number of neighbors to sample per node at each GNN layer. +- `input_nodes::Vector{Int}`: A vector containing the starting nodes for neighbor sampling. +- `num_layers::Int`: The number of layers for neighborhood expansion (how far to sample neighbors). +- `batch_size::Union{Int, Nothing}`: The size of the batch. If not specified, it defaults to the number of `input_nodes`. + +# Usage +```jldoctest +julia> loader = NeighborLoader(graph; num_neighbors=[10, 5], input_nodes=[1, 2, 3], num_layers=2) + +julia> batch_counter = 0 +julia> for mini_batch_gnn in loader + batch_counter += 1 + println("Batch ", batch_counter, ": Nodes in mini-batch graph: ", nv(mini_batch_gnn)) +``` +""" +struct NeighborLoader + graph::GNNGraph # The input GNNGraph (graph + features from GraphNeuralNetworks.jl) + num_neighbors::Vector{Int} # Number of neighbors to sample per node, for each layer + input_nodes::Vector{Int} # Set of input nodes (starting nodes for sampling) + num_layers::Int # Number of layers for neighborhood expansion + batch_size::Union{Int, Nothing} # Optional batch size, defaults to the length of input_nodes if not given + neighbors_cache::Dict{Int, Vector{Int}} # Cache neighbors to avoid recomputation +end + +function NeighborLoader(graph::GNNGraph; num_neighbors::Vector{Int}, input_nodes::Vector{Int}=nothing, + num_layers::Int, batch_size::Union{Int, Nothing}=nothing) + return NeighborLoader(graph, num_neighbors, input_nodes === nothing ? collect(1:graph.num_nodes) : input_nodes, num_layers, + batch_size === nothing ? length(input_nodes) : batch_size, Dict{Int, Vector{Int}}()) +end + +# Function to get cached neighbors or compute them +function get_neighbors(loader::NeighborLoader, node::Int) + if haskey(loader.neighbors_cache, node) + return loader.neighbors_cache[node] + else + neighbors = Graphs.neighbors(loader.graph, node, dir = :in) # Get neighbors from graph + loader.neighbors_cache[node] = neighbors + return neighbors + end +end + +# Function to sample neighbors for a given node at a specific layer +function sample_nbrs(loader::NeighborLoader, node::Int, layer::Int) + neighbors = get_neighbors(loader, node) + if isempty(neighbors) + return Int[] + else + num_samples = min(loader.num_neighbors[layer], length(neighbors)) # Limit to required samples for this layer + return rand(neighbors, num_samples) # Randomly sample neighbors + end +end + +# Iterator protocol for NeighborLoader with lazy batch loading +function Base.iterate(loader::NeighborLoader, state=1) + if state > length(loader.input_nodes) + return nothing # End of iteration if batches are exhausted (state larger than amount of input nodes or current batch no >= batch number) + end + + # Determine the size of the current batch + batch_size = min(loader.batch_size, length(loader.input_nodes) - state + 1) # Conditional in case there is not enough nodes to fill the last batch + batch_nodes = loader.input_nodes[state:state + batch_size - 1] # Each mini-batch uses different set of input nodes + + # Set for tracking the subgraph nodes + subgraph_nodes = Set(batch_nodes) + + for node in batch_nodes + # Initialize current layer of nodes (starting with the node itself) + sampled_neighbors = Set([node]) + + # For each GNN layer, sample the neighborhood + for layer in 1:loader.num_layers + new_neighbors = Set{Int}() + for n in sampled_neighbors + neighbors = sample_nbrs(loader, n, layer) # Sample neighbors of the node for this layer + new_neighbors = union(new_neighbors, neighbors) # Avoid duplicates in the neighbor set + end + sampled_neighbors = new_neighbors + subgraph_nodes = union(subgraph_nodes, sampled_neighbors) # Expand the subgraph with the new neighbors + end + end + + # Collect subgraph nodes and their features + subgraph_node_list = collect(subgraph_nodes) + + if isempty(subgraph_node_list) + return GNNGraph(), state + batch_size + end + + mini_batch_gnn = Graphs.induced_subgraph(loader.graph, subgraph_node_list) # Create a subgraph of the nodes + + # Continue iteration for the next batch + return mini_batch_gnn, state + batch_size +end diff --git a/GraphNeuralNetworks/test/runtests.jl b/GraphNeuralNetworks/test/runtests.jl index 05cb6fd5f..f796651bb 100644 --- a/GraphNeuralNetworks/test/runtests.jl +++ b/GraphNeuralNetworks/test/runtests.jl @@ -30,6 +30,7 @@ tests = [ "layers/temporalconv", "layers/pool", "examples/node_classification_cora", + "samplers" ] !CUDA.functional() && @warn("CUDA unavailable, not testing GPU support") diff --git a/GraphNeuralNetworks/test/samplers.jl b/GraphNeuralNetworks/test/samplers.jl new file mode 100644 index 000000000..546291717 --- /dev/null +++ b/GraphNeuralNetworks/test/samplers.jl @@ -0,0 +1,125 @@ +# Helper function to create a simple graph with node features using GNNGraph +function create_test_graph() + source = [1, 2, 3, 4] # Define source nodes of edges + target = [2, 3, 4, 5] # Define target nodes of edges + node_features = rand(Float32, 5, 5) # Create random node features (5 features for 5 nodes) + + return GNNGraph(source, target, ndata = node_features) # Create a GNNGraph with edges and features +end + +# Tests for NeighborLoader structure and its functionalities +@testset "NeighborLoader tests" begin + + # 1. Basic functionality: Check neighbor sampling and subgraph creation + @testset "Basic functionality" begin + g = create_test_graph() + + # Define NeighborLoader with 2 neighbors per layer, 2 layers, batch size 2 + loader = NeighborLoader(g; num_neighbors=[2, 2], input_nodes=[1, 2], num_layers=2, batch_size=2) + + mini_batch_gnn, next_state = iterate(loader) + + # Test if the mini-batch graph is not empty + @test !isempty(mini_batch_gnn.graph) + + num_sampled_nodes = mini_batch_gnn.num_nodes + println("Number of nodes in mini-batch: ", num_sampled_nodes) + + @test num_sampled_nodes == 2 + + # Test if there are edges in the subgraph + @test mini_batch_gnn.num_edges > 0 + end + + # 2. Edge case: Single node with no neighbors + @testset "Single node with no neighbors" begin + g = SimpleDiGraph(1) # A graph with a single node and no edges + node_features = rand(Float32, 5, 1) + graph = GNNGraph(g, ndata = node_features) + + loader = NeighborLoader(graph; num_neighbors=[2], input_nodes=[1], num_layers=1) + + mini_batch_gnn, next_state = iterate(loader) + + # Test if the mini-batch graph contains only one node + @test size(mini_batch_gnn.x, 2) == 1 + end + + # 3. Edge case: A node with no outgoing edges (isolated node) + @testset "Node with no outgoing edges" begin + g = SimpleDiGraph(2) # Graph with 2 nodes, no edges + node_features = rand(Float32, 5, 2) + graph = GNNGraph(g, ndata = node_features) + + loader = NeighborLoader(graph; num_neighbors=[1], input_nodes=[1, 2], num_layers=1) + + mini_batch_gnn, next_state = iterate(loader) + + # Test if the mini-batch graph contains the input nodes only (as no neighbors can be sampled) + @test size(mini_batch_gnn.x, 2) == 2 # Only two isolated nodes + end + + # 4. Edge case: A fully connected graph + @testset "Fully connected graph" begin + g = SimpleDiGraph(3) + add_edge!(g, 1, 2) + add_edge!(g, 2, 3) + add_edge!(g, 3, 1) + node_features = rand(Float32, 5, 3) + graph = GNNGraph(g, ndata = node_features) + + loader = NeighborLoader(graph; num_neighbors=[2, 2], input_nodes=[1], num_layers=2) + + mini_batch_gnn, next_state = iterate(loader) + + # Test if all nodes are included in the mini-batch since it's fully connected + @test size(mini_batch_gnn.x, 2) == 3 # All nodes should be included + end + + # 5. Edge case: More layers than the number of neighbors + @testset "More layers than available neighbors" begin + g = SimpleDiGraph(3) + add_edge!(g, 1, 2) + add_edge!(g, 2, 3) + node_features = rand(Float32, 5, 3) + graph = GNNGraph(g, ndata = node_features) + + # Test with 3 layers but only enough connections for 2 layers + loader = NeighborLoader(graph; num_neighbors=[1, 1, 1], input_nodes=[1], num_layers=3) + + mini_batch_gnn, next_state = iterate(loader) + + # Test if the mini-batch graph contains all available nodes + @test size(mini_batch_gnn.x, 2) == 1 + end + + # 6. Edge case: Large batch size greater than the number of input nodes + @testset "Large batch size" begin + g = create_test_graph() + + # Define NeighborLoader with a larger batch size than input nodes + loader = NeighborLoader(g; num_neighbors=[2], input_nodes=[1, 2], num_layers=1, batch_size=10) + + mini_batch_gnn, next_state = iterate(loader) + + # Test if the mini-batch graph is not empty + @test !isempty(mini_batch_gnn.graph) + + # Test if the correct number of nodes are sampled + @test size(mini_batch_gnn.x, 2) == length(unique([1, 2])) # Nodes [1, 2] are expected + end + + # 7. Edge case: No neighbors sampled (num_neighbors = [0]) and 1 layer + @testset "No neighbors sampled" begin + g = create_test_graph() + + # Define NeighborLoader with 0 neighbors per layer, 1 layer, batch size 2 + loader = NeighborLoader(g; num_neighbors=[0], input_nodes=[1, 2], num_layers=1, batch_size=2) + + mini_batch_gnn, next_state = iterate(loader) + + # Test if the mini-batch graph contains only the input nodes + @test size(mini_batch_gnn.x, 2) == 2 # No neighbors should be sampled, only nodes 1 and 2 should be in the graph + end + +end \ No newline at end of file