Skip to content

Commit a4b6e15

Browse files
committed
refactor: allow input_nodes to be nothing
1 parent 6f26713 commit a4b6e15

File tree

1 file changed

+4
-16
lines changed

1 file changed

+4
-16
lines changed

GraphNeuralNetworks/src/samplers.jl

Lines changed: 4 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,10 @@ struct NeighborLoader
2727
neighbors_cache::Dict{Int, Vector{Int}} # Cache neighbors to avoid recomputation
2828
end
2929

30-
function NeighborLoader(graph::GNNGraph; num_neighbors::Vector{Int}, input_nodes::Vector{Int}, num_layers::Int, batch_size::Union{Int, Nothing}=nothing)
31-
return NeighborLoader(graph, num_neighbors, input_nodes, num_layers, batch_size === nothing ? length(input_nodes) : batch_size, Dict{Int, Vector{Int}}())
30+
function NeighborLoader(graph::GNNGraph; num_neighbors::Vector{Int}, input_nodes::Vector{Int}=nothing,
31+
num_layers::Int, batch_size::Union{Int, Nothing}=nothing)
32+
return NeighborLoader(graph, num_neighbors, input_nodes === nothing ? collect(1:graph.num_nodes) : input_nodes, num_layers,
33+
batch_size === nothing ? length(input_nodes) : batch_size, Dict{Int, Vector{Int}}())
3234
end
3335

3436
# Function to get cached neighbors or compute them
@@ -42,20 +44,6 @@ function get_neighbors(loader::NeighborLoader, node::Int)
4244
end
4345
end
4446

45-
"""
46-
sample_nbrs(loader::NeighborLoader, node::Int, layer::Int)
47-
48-
Samples a specified number of neighbors for the given `node` at a particular `layer` of the GNN.
49-
The number of neighbors sampled is defined in `loader.num_neighbors`.
50-
51-
# Arguments:
52-
- `loader::NeighborLoader`: The `NeighborLoader` instance.
53-
- `node::Int`: The node to sample neighbors for.
54-
- `layer::Int`: The current GNN layer (used to determine how many neighbors to sample).
55-
56-
# Returns:
57-
A vector of sampled neighbor node indices.
58-
"""
5947
# Function to sample neighbors for a given node at a specific layer
6048
function sample_nbrs(loader::NeighborLoader, node::Int, layer::Int)
6149
neighbors = get_neighbors(loader, node)

0 commit comments

Comments
 (0)