@@ -27,8 +27,10 @@ struct NeighborLoader
2727    neighbors_cache:: Dict{Int, Vector{Int}}   #  Cache neighbors to avoid recomputation
2828end 
2929
30- function  NeighborLoader (graph:: GNNGraph ; num_neighbors:: Vector{Int} , input_nodes:: Vector{Int} , num_layers:: Int , batch_size:: Union{Int, Nothing} = nothing )
31-     return  NeighborLoader (graph, num_neighbors, input_nodes, num_layers, batch_size ===  nothing  ?  length (input_nodes) :  batch_size, Dict {Int, Vector{Int}} ())
30+ function  NeighborLoader (graph:: GNNGraph ; num_neighbors:: Vector{Int} , input_nodes:: Vector{Int} = nothing , 
31+                         num_layers:: Int , batch_size:: Union{Int, Nothing} = nothing )
32+     return  NeighborLoader (graph, num_neighbors, input_nodes ===  nothing  ?  collect (1 : graph. num_nodes) :  input_nodes, num_layers, 
33+                             batch_size ===  nothing  ?  length (input_nodes) :  batch_size, Dict {Int, Vector{Int}} ())
3234end 
3335
3436#  Function to get cached neighbors or compute them
@@ -42,20 +44,6 @@ function get_neighbors(loader::NeighborLoader, node::Int)
4244    end 
4345end 
4446
45- """ 
46-     sample_nbrs(loader::NeighborLoader, node::Int, layer::Int) 
47- 
48- Samples a specified number of neighbors for the given `node` at a particular `layer` of the GNN.  
49-     The number of neighbors sampled is defined in `loader.num_neighbors`. 
50- 
51- # Arguments: 
52- - `loader::NeighborLoader`: The `NeighborLoader` instance. 
53- - `node::Int`: The node to sample neighbors for. 
54- - `layer::Int`: The current GNN layer (used to determine how many neighbors to sample). 
55- 
56- # Returns: 
57- A vector of sampled neighbor node indices. 
58- """ 
5947#  Function to sample neighbors for a given node at a specific layer
6048function  sample_nbrs (loader:: NeighborLoader , node:: Int , layer:: Int )
6149    neighbors =  get_neighbors (loader, node)
0 commit comments