-
Couldn't load subscription status.
- Fork 56
feat: add NeighborLoader #497
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 27 commits
Commits
Show all changes
33 commits
Select commit
Hold shift + click to select a range
349d99f
feat: init neighbor loader
askorupka 07db4d0
feat: init neighbor loader
askorupka 25945c7
feat: refine neighborloader
askorupka 506d4c7
fix: refine neighborloader
askorupka 991bf61
fix: refine neighborloader
askorupka c25bc1e
fix: refine neighborloader
askorupka fde10bb
fix: refine neighborloader
askorupka 3656691
fix: refine neighborloader
askorupka 9997fab
chore: add some comments
askorupka acf209c
chore: add TODO comments
askorupka 0c4a653
feat: add tests, refine code
askorupka 2035c5e
fix: add samplers.jl after rebase
askorupka ebebce9
chore: add docstrings
askorupka abf31cd
chore: Graphs to deps
askorupka bcdfa5e
chore: move using Graphs to main file
askorupka 970d297
chore: readd Graphs to extras
askorupka b4c1ad7
chore: delete src/samplers.jl created by mistake
askorupka 5e7544c
fix: add sampling.jl to docs
askorupka c9d412b
fix: add sampling.jl to docs
askorupka 2d7bd0b
fix: add sampling.jl to docs
askorupka 65aa564
fix: deduplicate function
askorupka 61c5e39
fix: fix broken tests
askorupka aec5574
chore: remove printlns
askorupka e675086
Update GraphNeuralNetworks/src/samplers.jl
askorupka 62f5d87
fix: remove docstrings where not needed
askorupka 3ed22bf
chore: add ref to the paper
askorupka e4dc977
Update GraphNeuralNetworks/src/samplers.jl
askorupka 962a97f
Update GraphNeuralNetworks/src/GraphNeuralNetworks.jl
askorupka d552de4
Update GraphNeuralNetworks/src/samplers.jl
askorupka 6f26713
chore: add compat for Graphs
askorupka a4b6e15
refactor: allow input_nodes to be nothing
askorupka 9af384f
chore: add loader iterate example to docstring
askorupka aa18520
fix: fix tests (docstring error)
askorupka File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,14 @@ | ||
| ```@meta | ||
| CurrentModule = GraphNeuralNetworks | ||
| ``` | ||
|
|
||
| # Samplers | ||
|
|
||
|
|
||
| ## Docs | ||
|
|
||
| ```@autodocs | ||
| Modules = [GraphNeuralNetworks] | ||
| Pages = ["samplers.jl"] | ||
| Private = false | ||
| ``` |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,110 @@ | ||
| """ | ||
| struct NeighborLoader | ||
askorupka marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| A data structure for sampling neighbors from a graph for training Graph Neural Networks (GNNs). | ||
| It supports multi-layer sampling of neighbors for a batch of input nodes, useful for mini-batch training | ||
| originally introduced in "Inductive Representation Learning on Large Graphs" paper. | ||
| [see https://arxiv.org/abs/1706.02216] | ||
| # Fields | ||
| - `graph::GNNGraph`: The input graph. | ||
| - `num_neighbors::Vector{Int}`: A vector specifying the number of neighbors to sample per node at each GNN layer. | ||
| - `input_nodes::Vector{Int}`: A vector containing the starting nodes for neighbor sampling. | ||
| - `num_layers::Int`: The number of layers for neighborhood expansion (how far to sample neighbors). | ||
| - `batch_size::Union{Int, Nothing}`: The size of the batch. If not specified, it defaults to the number of `input_nodes`. | ||
askorupka marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| # Usage | ||
| ```julia | ||
| loader = NeighborLoader(graph; num_neighbors=[10, 5], input_nodes=[1, 2, 3], num_layers=2) | ||
| ``` | ||
askorupka marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| """ | ||
| struct NeighborLoader | ||
| graph::GNNGraph # The input GNNGraph (graph + features from GraphNeuralNetworks.jl) | ||
| num_neighbors::Vector{Int} # Number of neighbors to sample per node, for each layer | ||
| input_nodes::Vector{Int} # Set of input nodes (starting nodes for sampling) | ||
| num_layers::Int # Number of layers for neighborhood expansion | ||
| batch_size::Union{Int, Nothing} # Optional batch size, defaults to the length of input_nodes if not given | ||
| neighbors_cache::Dict{Int, Vector{Int}} # Cache neighbors to avoid recomputation | ||
| end | ||
|
|
||
| function NeighborLoader(graph::GNNGraph; num_neighbors::Vector{Int}, input_nodes::Vector{Int}, num_layers::Int, batch_size::Union{Int, Nothing}=nothing) | ||
askorupka marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| return NeighborLoader(graph, num_neighbors, input_nodes, num_layers, batch_size === nothing ? length(input_nodes) : batch_size, Dict{Int, Vector{Int}}()) | ||
askorupka marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| end | ||
|
|
||
| # Function to get cached neighbors or compute them | ||
| function get_neighbors(loader::NeighborLoader, node::Int) | ||
| if haskey(loader.neighbors_cache, node) | ||
| return loader.neighbors_cache[node] | ||
| else | ||
| neighbors = Graphs.neighbors(loader.graph, node, dir = :in) # Get neighbors from graph | ||
| loader.neighbors_cache[node] = neighbors | ||
| return neighbors | ||
| end | ||
| end | ||
|
|
||
| """ | ||
| sample_nbrs(loader::NeighborLoader, node::Int, layer::Int) | ||
| Samples a specified number of neighbors for the given `node` at a particular `layer` of the GNN. | ||
| The number of neighbors sampled is defined in `loader.num_neighbors`. | ||
| # Arguments: | ||
| - `loader::NeighborLoader`: The `NeighborLoader` instance. | ||
| - `node::Int`: The node to sample neighbors for. | ||
| - `layer::Int`: The current GNN layer (used to determine how many neighbors to sample). | ||
| # Returns: | ||
| A vector of sampled neighbor node indices. | ||
| """ | ||
askorupka marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| # Function to sample neighbors for a given node at a specific layer | ||
| function sample_nbrs(loader::NeighborLoader, node::Int, layer::Int) | ||
| neighbors = get_neighbors(loader, node) | ||
| if isempty(neighbors) | ||
| return Int[] | ||
| else | ||
| num_samples = min(loader.num_neighbors[layer], length(neighbors)) # Limit to required samples for this layer | ||
| return rand(neighbors, num_samples) # Randomly sample neighbors | ||
| end | ||
| end | ||
|
|
||
| # Iterator protocol for NeighborLoader with lazy batch loading | ||
| function Base.iterate(loader::NeighborLoader, state=1) | ||
| if state > length(loader.input_nodes) | ||
| return nothing # End of iteration if batches are exhausted (state larger than amount of input nodes or current batch no >= batch number) | ||
| end | ||
|
|
||
| # Determine the size of the current batch | ||
| batch_size = min(loader.batch_size, length(loader.input_nodes) - state + 1) # Conditional in case there is not enough nodes to fill the last batch | ||
| batch_nodes = loader.input_nodes[state:state + batch_size - 1] # Each mini-batch uses different set of input nodes | ||
|
|
||
| # Set for tracking the subgraph nodes | ||
| subgraph_nodes = Set(batch_nodes) | ||
|
|
||
| for node in batch_nodes | ||
| # Initialize current layer of nodes (starting with the node itself) | ||
| sampled_neighbors = Set([node]) | ||
|
|
||
| # For each GNN layer, sample the neighborhood | ||
| for layer in 1:loader.num_layers | ||
| new_neighbors = Set{Int}() | ||
| for n in sampled_neighbors | ||
| neighbors = sample_nbrs(loader, n, layer) # Sample neighbors of the node for this layer | ||
| new_neighbors = union(new_neighbors, neighbors) # Avoid duplicates in the neighbor set | ||
| end | ||
| sampled_neighbors = new_neighbors | ||
| subgraph_nodes = union(subgraph_nodes, sampled_neighbors) # Expand the subgraph with the new neighbors | ||
| end | ||
| end | ||
|
|
||
| # Collect subgraph nodes and their features | ||
| subgraph_node_list = collect(subgraph_nodes) | ||
|
|
||
| if isempty(subgraph_node_list) | ||
| return GNNGraph(), state + batch_size | ||
| end | ||
|
|
||
| mini_batch_gnn = Graphs.induced_subgraph(loader.graph, subgraph_node_list) # Create a subgraph of the nodes | ||
|
|
||
| # Continue iteration for the next batch | ||
| return mini_batch_gnn, state + batch_size | ||
| end | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,125 @@ | ||
| # Helper function to create a simple graph with node features using GNNGraph | ||
| function create_test_graph() | ||
| source = [1, 2, 3, 4] # Define source nodes of edges | ||
| target = [2, 3, 4, 5] # Define target nodes of edges | ||
| node_features = rand(Float32, 5, 5) # Create random node features (5 features for 5 nodes) | ||
|
|
||
| return GNNGraph(source, target, ndata = node_features) # Create a GNNGraph with edges and features | ||
| end | ||
|
|
||
| # Tests for NeighborLoader structure and its functionalities | ||
| @testset "NeighborLoader tests" begin | ||
|
|
||
| # 1. Basic functionality: Check neighbor sampling and subgraph creation | ||
| @testset "Basic functionality" begin | ||
| g = create_test_graph() | ||
|
|
||
| # Define NeighborLoader with 2 neighbors per layer, 2 layers, batch size 2 | ||
| loader = NeighborLoader(g; num_neighbors=[2, 2], input_nodes=[1, 2], num_layers=2, batch_size=2) | ||
|
|
||
| mini_batch_gnn, next_state = iterate(loader) | ||
|
|
||
| # Test if the mini-batch graph is not empty | ||
| @test !isempty(mini_batch_gnn.graph) | ||
|
|
||
| num_sampled_nodes = mini_batch_gnn.num_nodes | ||
| println("Number of nodes in mini-batch: ", num_sampled_nodes) | ||
|
|
||
| @test num_sampled_nodes == 2 | ||
|
|
||
| # Test if there are edges in the subgraph | ||
| @test mini_batch_gnn.num_edges > 0 | ||
| end | ||
|
|
||
| # 2. Edge case: Single node with no neighbors | ||
| @testset "Single node with no neighbors" begin | ||
| g = SimpleDiGraph(1) # A graph with a single node and no edges | ||
| node_features = rand(Float32, 5, 1) | ||
| graph = GNNGraph(g, ndata = node_features) | ||
|
|
||
| loader = NeighborLoader(graph; num_neighbors=[2], input_nodes=[1], num_layers=1) | ||
|
|
||
| mini_batch_gnn, next_state = iterate(loader) | ||
|
|
||
| # Test if the mini-batch graph contains only one node | ||
| @test size(mini_batch_gnn.x, 2) == 1 | ||
| end | ||
|
|
||
| # 3. Edge case: A node with no outgoing edges (isolated node) | ||
| @testset "Node with no outgoing edges" begin | ||
| g = SimpleDiGraph(2) # Graph with 2 nodes, no edges | ||
| node_features = rand(Float32, 5, 2) | ||
| graph = GNNGraph(g, ndata = node_features) | ||
|
|
||
| loader = NeighborLoader(graph; num_neighbors=[1], input_nodes=[1, 2], num_layers=1) | ||
|
|
||
| mini_batch_gnn, next_state = iterate(loader) | ||
|
|
||
| # Test if the mini-batch graph contains the input nodes only (as no neighbors can be sampled) | ||
| @test size(mini_batch_gnn.x, 2) == 2 # Only two isolated nodes | ||
| end | ||
|
|
||
| # 4. Edge case: A fully connected graph | ||
| @testset "Fully connected graph" begin | ||
| g = SimpleDiGraph(3) | ||
| add_edge!(g, 1, 2) | ||
| add_edge!(g, 2, 3) | ||
| add_edge!(g, 3, 1) | ||
| node_features = rand(Float32, 5, 3) | ||
| graph = GNNGraph(g, ndata = node_features) | ||
|
|
||
| loader = NeighborLoader(graph; num_neighbors=[2, 2], input_nodes=[1], num_layers=2) | ||
|
|
||
| mini_batch_gnn, next_state = iterate(loader) | ||
|
|
||
| # Test if all nodes are included in the mini-batch since it's fully connected | ||
| @test size(mini_batch_gnn.x, 2) == 3 # All nodes should be included | ||
| end | ||
|
|
||
| # 5. Edge case: More layers than the number of neighbors | ||
| @testset "More layers than available neighbors" begin | ||
| g = SimpleDiGraph(3) | ||
| add_edge!(g, 1, 2) | ||
| add_edge!(g, 2, 3) | ||
| node_features = rand(Float32, 5, 3) | ||
| graph = GNNGraph(g, ndata = node_features) | ||
|
|
||
| # Test with 3 layers but only enough connections for 2 layers | ||
| loader = NeighborLoader(graph; num_neighbors=[1, 1, 1], input_nodes=[1], num_layers=3) | ||
|
|
||
| mini_batch_gnn, next_state = iterate(loader) | ||
|
|
||
| # Test if the mini-batch graph contains all available nodes | ||
| @test size(mini_batch_gnn.x, 2) == 1 | ||
| end | ||
|
|
||
| # 6. Edge case: Large batch size greater than the number of input nodes | ||
| @testset "Large batch size" begin | ||
| g = create_test_graph() | ||
|
|
||
| # Define NeighborLoader with a larger batch size than input nodes | ||
| loader = NeighborLoader(g; num_neighbors=[2], input_nodes=[1, 2], num_layers=1, batch_size=10) | ||
|
|
||
| mini_batch_gnn, next_state = iterate(loader) | ||
|
|
||
| # Test if the mini-batch graph is not empty | ||
| @test !isempty(mini_batch_gnn.graph) | ||
|
|
||
| # Test if the correct number of nodes are sampled | ||
| @test size(mini_batch_gnn.x, 2) == length(unique([1, 2])) # Nodes [1, 2] are expected | ||
| end | ||
|
|
||
| # 7. Edge case: No neighbors sampled (num_neighbors = [0]) and 1 layer | ||
| @testset "No neighbors sampled" begin | ||
| g = create_test_graph() | ||
|
|
||
| # Define NeighborLoader with 0 neighbors per layer, 1 layer, batch size 2 | ||
| loader = NeighborLoader(g; num_neighbors=[0], input_nodes=[1, 2], num_layers=1, batch_size=2) | ||
|
|
||
| mini_batch_gnn, next_state = iterate(loader) | ||
|
|
||
| # Test if the mini-batch graph contains only the input nodes | ||
| @test size(mini_batch_gnn.x, 2) == 2 # No neighbors should be sampled, only nodes 1 and 2 should be in the graph | ||
| end | ||
|
|
||
| end |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.