1+ # Helper function to create a simple graph with node features using GNNGraph
2+ function create_test_graph ()
3+ source = [1 , 2 , 3 , 4 ] # Define source nodes of edges
4+ target = [2 , 3 , 4 , 5 ] # Define target nodes of edges
5+ node_features = rand (Float32, 5 , 5 ) # Create random node features (5 features for 5 nodes)
6+
7+ return GNNGraph (source, target, ndata = node_features) # Create a GNNGraph with edges and features
8+ end
9+
10+ # Tests for NeighborLoader structure and its functionalities
11+ @testset " NeighborLoader tests" begin
12+
13+ # 1. Basic functionality: Check neighbor sampling and subgraph creation
14+ @testset " Basic functionality" begin
15+ g = create_test_graph ()
16+
17+ # Define NeighborLoader with 2 neighbors per layer, 2 layers, batch size 2
18+ loader = NeighborLoader (g; num_neighbors= [2 , 2 ], input_nodes= [1 , 2 ], num_layers= 2 , batch_size= 2 )
19+
20+ mini_batch_gnn, next_state = iterate (loader)
21+
22+ # Test if the mini-batch graph is not empty
23+ @test ! isempty (mini_batch_gnn. graph)
24+
25+ num_sampled_nodes = mini_batch_gnn. num_nodes
26+ println (" Number of nodes in mini-batch: " , num_sampled_nodes)
27+
28+ @test num_sampled_nodes == 2
29+
30+ # Test if there are edges in the subgraph
31+ @test mini_batch_gnn. num_edges > 0
32+ end
33+
34+ # 2. Edge case: Single node with no neighbors
35+ @testset " Single node with no neighbors" begin
36+ g = SimpleDiGraph (1 ) # A graph with a single node and no edges
37+ node_features = rand (Float32, 5 , 1 )
38+ graph = GNNGraph (g, ndata = node_features)
39+
40+ loader = NeighborLoader (graph; num_neighbors= [2 ], input_nodes= [1 ], num_layers= 1 )
41+
42+ mini_batch_gnn, next_state = iterate (loader)
43+
44+ # Test if the mini-batch graph contains only one node
45+ @test size (mini_batch_gnn. x, 2 ) == 1
46+ end
47+
48+ # 3. Edge case: A node with no outgoing edges (isolated node)
49+ @testset " Node with no outgoing edges" begin
50+ g = SimpleDiGraph (2 ) # Graph with 2 nodes, no edges
51+ node_features = rand (Float32, 5 , 2 )
52+ graph = GNNGraph (g, ndata = node_features)
53+
54+ loader = NeighborLoader (graph; num_neighbors= [1 ], input_nodes= [1 , 2 ], num_layers= 1 )
55+
56+ mini_batch_gnn, next_state = iterate (loader)
57+
58+ # Test if the mini-batch graph contains the input nodes only (as no neighbors can be sampled)
59+ @test size (mini_batch_gnn. x, 2 ) == 2 # Only two isolated nodes
60+ end
61+
62+ # 4. Edge case: A fully connected graph
63+ @testset " Fully connected graph" begin
64+ g = SimpleDiGraph (3 )
65+ add_edge! (g, 1 , 2 )
66+ add_edge! (g, 2 , 3 )
67+ add_edge! (g, 3 , 1 )
68+ node_features = rand (Float32, 5 , 3 )
69+ graph = GNNGraph (g, ndata = node_features)
70+
71+ loader = NeighborLoader (graph; num_neighbors= [2 , 2 ], input_nodes= [1 ], num_layers= 2 )
72+
73+ mini_batch_gnn, next_state = iterate (loader)
74+
75+ # Test if all nodes are included in the mini-batch since it's fully connected
76+ @test size (mini_batch_gnn. x, 2 ) == 3 # All nodes should be included
77+ end
78+
79+ # 5. Edge case: More layers than the number of neighbors
80+ @testset " More layers than available neighbors" begin
81+ g = SimpleDiGraph (3 )
82+ add_edge! (g, 1 , 2 )
83+ add_edge! (g, 2 , 3 )
84+ node_features = rand (Float32, 5 , 3 )
85+ graph = GNNGraph (g, ndata = node_features)
86+
87+ # Test with 3 layers but only enough connections for 2 layers
88+ loader = NeighborLoader (graph; num_neighbors= [1 , 1 , 1 ], input_nodes= [1 ], num_layers= 3 )
89+
90+ mini_batch_gnn, next_state = iterate (loader)
91+
92+ # Test if the mini-batch graph contains all available nodes
93+ @test size (mini_batch_gnn. x, 2 ) == 1
94+ end
95+
96+ # 6. Edge case: Large batch size greater than the number of input nodes
97+ @testset " Large batch size" begin
98+ g = create_test_graph ()
99+
100+ # Define NeighborLoader with a larger batch size than input nodes
101+ loader = NeighborLoader (g; num_neighbors= [2 ], input_nodes= [1 , 2 ], num_layers= 1 , batch_size= 10 )
102+
103+ mini_batch_gnn, next_state = iterate (loader)
104+
105+ # Test if the mini-batch graph is not empty
106+ @test ! isempty (mini_batch_gnn. graph)
107+
108+ # Test if the correct number of nodes are sampled
109+ @test size (mini_batch_gnn. x, 2 ) == length (unique ([1 , 2 ])) # Nodes [1, 2] are expected
110+ end
111+
112+ # 7. Edge case: No neighbors sampled (num_neighbors = [0]) and 1 layer
113+ @testset " No neighbors sampled" begin
114+ g = create_test_graph ()
115+
116+ # Define NeighborLoader with 0 neighbors per layer, 1 layer, batch size 2
117+ loader = NeighborLoader (g; num_neighbors= [0 ], input_nodes= [1 , 2 ], num_layers= 1 , batch_size= 2 )
118+
119+ mini_batch_gnn, next_state = iterate (loader)
120+
121+ # Test if the mini-batch graph contains only the input nodes
122+ @test size (mini_batch_gnn. x, 2 ) == 2 # No neighbors should be sampled, only nodes 1 and 2 should be in the graph
123+ end
124+
125+ end
0 commit comments