Skip to content

Commit 6c17e20

Browse files
committed
Merge branch 'master' into ar/move-neighborloader
2 parents 552c8f2 + 530457c commit 6c17e20

File tree

10 files changed

+568
-598
lines changed

10 files changed

+568
-598
lines changed

GNNGraphs/test/samplers.jl

Lines changed: 91 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -1,125 +1,126 @@
1-
# Helper function to create a simple graph with node features using GNNGraph
2-
function create_test_graph()
3-
source = [1, 2, 3, 4] # Define source nodes of edges
4-
target = [2, 3, 4, 5] # Define target nodes of edges
5-
node_features = rand(Float32, 5, 5) # Create random node features (5 features for 5 nodes)
1+
#TODO reactivate test
2+
# @testitem "NeighborLoader" setup=[TestModule] begin
3+
# using .TestModule
4+
# # Helper function to create a simple graph with node features using GNNGraph
5+
# function create_test_graph()
6+
# source = [1, 2, 3, 4] # Define source nodes of edges
7+
# target = [2, 3, 4, 5] # Define target nodes of edges
8+
# node_features = rand(Float32, 5, 5) # Create random node features (5 features for 5 nodes)
69

7-
return GNNGraph(source, target, ndata = node_features) # Create a GNNGraph with edges and features
8-
end
10+
# return GNNGraph(source, target, ndata = node_features) # Create a GNNGraph with edges and features
11+
# end
912

10-
# Tests for NeighborLoader structure and its functionalities
11-
@testset "NeighborLoader tests" begin
1213

13-
# 1. Basic functionality: Check neighbor sampling and subgraph creation
14-
@testset "Basic functionality" begin
15-
g = create_test_graph()
14+
# # 1. Basic functionality: Check neighbor sampling and subgraph creation
15+
# @testset "Basic functionality" begin
16+
# g = create_test_graph()
1617

17-
# Define NeighborLoader with 2 neighbors per layer, 2 layers, batch size 2
18-
loader = NeighborLoader(g; num_neighbors=[2, 2], input_nodes=[1, 2], num_layers=2, batch_size=2)
18+
# # Define NeighborLoader with 2 neighbors per layer, 2 layers, batch size 2
19+
# loader = NeighborLoader(g; num_neighbors=[2, 2], input_nodes=[1, 2], num_layers=2, batch_size=2)
1920

20-
mini_batch_gnn, next_state = iterate(loader)
21+
# mini_batch_gnn, next_state = iterate(loader)
2122

22-
# Test if the mini-batch graph is not empty
23-
@test !isempty(mini_batch_gnn.graph)
23+
# # Test if the mini-batch graph is not empty
24+
# @test !isempty(mini_batch_gnn.graph)
2425

25-
num_sampled_nodes = mini_batch_gnn.num_nodes
26-
println("Number of nodes in mini-batch: ", num_sampled_nodes)
26+
# num_sampled_nodes = mini_batch_gnn.num_nodes
27+
# println("Number of nodes in mini-batch: ", num_sampled_nodes)
2728

28-
@test num_sampled_nodes == 2
29+
# @test num_sampled_nodes == 2
2930

30-
# Test if there are edges in the subgraph
31-
@test mini_batch_gnn.num_edges > 0
32-
end
31+
# # Test if there are edges in the subgraph
32+
# @test mini_batch_gnn.num_edges > 0
33+
# end
3334

34-
# 2. Edge case: Single node with no neighbors
35-
@testset "Single node with no neighbors" begin
36-
g = SimpleDiGraph(1) # A graph with a single node and no edges
37-
node_features = rand(Float32, 5, 1)
38-
graph = GNNGraph(g, ndata = node_features)
35+
# # 2. Edge case: Single node with no neighbors
36+
# @testset "Single node with no neighbors" begin
37+
# g = SimpleDiGraph(1) # A graph with a single node and no edges
38+
# node_features = rand(Float32, 5, 1)
39+
# graph = GNNGraph(g, ndata = node_features)
3940

40-
loader = NeighborLoader(graph; num_neighbors=[2], input_nodes=[1], num_layers=1)
41+
# loader = NeighborLoader(graph; num_neighbors=[2], input_nodes=[1], num_layers=1)
4142

42-
mini_batch_gnn, next_state = iterate(loader)
43+
# mini_batch_gnn, next_state = iterate(loader)
4344

44-
# Test if the mini-batch graph contains only one node
45-
@test size(mini_batch_gnn.x, 2) == 1
46-
end
45+
# # Test if the mini-batch graph contains only one node
46+
# @test size(mini_batch_gnn.x, 2) == 1
47+
# end
4748

48-
# 3. Edge case: A node with no outgoing edges (isolated node)
49-
@testset "Node with no outgoing edges" begin
50-
g = SimpleDiGraph(2) # Graph with 2 nodes, no edges
51-
node_features = rand(Float32, 5, 2)
52-
graph = GNNGraph(g, ndata = node_features)
49+
# # 3. Edge case: A node with no outgoing edges (isolated node)
50+
# @testset "Node with no outgoing edges" begin
51+
# g = SimpleDiGraph(2) # Graph with 2 nodes, no edges
52+
# node_features = rand(Float32, 5, 2)
53+
# graph = GNNGraph(g, ndata = node_features)
5354

54-
loader = NeighborLoader(graph; num_neighbors=[1], input_nodes=[1, 2], num_layers=1)
55+
# loader = NeighborLoader(graph; num_neighbors=[1], input_nodes=[1, 2], num_layers=1)
5556

56-
mini_batch_gnn, next_state = iterate(loader)
57+
# mini_batch_gnn, next_state = iterate(loader)
5758

58-
# Test if the mini-batch graph contains the input nodes only (as no neighbors can be sampled)
59-
@test size(mini_batch_gnn.x, 2) == 2 # Only two isolated nodes
60-
end
59+
# # Test if the mini-batch graph contains the input nodes only (as no neighbors can be sampled)
60+
# @test size(mini_batch_gnn.x, 2) == 2 # Only two isolated nodes
61+
# end
6162

62-
# 4. Edge case: A fully connected graph
63-
@testset "Fully connected graph" begin
64-
g = SimpleDiGraph(3)
65-
add_edge!(g, 1, 2)
66-
add_edge!(g, 2, 3)
67-
add_edge!(g, 3, 1)
68-
node_features = rand(Float32, 5, 3)
69-
graph = GNNGraph(g, ndata = node_features)
63+
# # 4. Edge case: A fully connected graph
64+
# @testset "Fully connected graph" begin
65+
# g = SimpleDiGraph(3)
66+
# add_edge!(g, 1, 2)
67+
# add_edge!(g, 2, 3)
68+
# add_edge!(g, 3, 1)
69+
# node_features = rand(Float32, 5, 3)
70+
# graph = GNNGraph(g, ndata = node_features)
7071

71-
loader = NeighborLoader(graph; num_neighbors=[2, 2], input_nodes=[1], num_layers=2)
72+
# loader = NeighborLoader(graph; num_neighbors=[2, 2], input_nodes=[1], num_layers=2)
7273

73-
mini_batch_gnn, next_state = iterate(loader)
74+
# mini_batch_gnn, next_state = iterate(loader)
7475

75-
# Test if all nodes are included in the mini-batch since it's fully connected
76-
@test size(mini_batch_gnn.x, 2) == 3 # All nodes should be included
77-
end
76+
# # Test if all nodes are included in the mini-batch since it's fully connected
77+
# @test size(mini_batch_gnn.x, 2) == 3 # All nodes should be included
78+
# end
7879

79-
# 5. Edge case: More layers than the number of neighbors
80-
@testset "More layers than available neighbors" begin
81-
g = SimpleDiGraph(3)
82-
add_edge!(g, 1, 2)
83-
add_edge!(g, 2, 3)
84-
node_features = rand(Float32, 5, 3)
85-
graph = GNNGraph(g, ndata = node_features)
80+
# # 5. Edge case: More layers than the number of neighbors
81+
# @testset "More layers than available neighbors" begin
82+
# g = SimpleDiGraph(3)
83+
# add_edge!(g, 1, 2)
84+
# add_edge!(g, 2, 3)
85+
# node_features = rand(Float32, 5, 3)
86+
# graph = GNNGraph(g, ndata = node_features)
8687

87-
# Test with 3 layers but only enough connections for 2 layers
88-
loader = NeighborLoader(graph; num_neighbors=[1, 1, 1], input_nodes=[1], num_layers=3)
88+
# # Test with 3 layers but only enough connections for 2 layers
89+
# loader = NeighborLoader(graph; num_neighbors=[1, 1, 1], input_nodes=[1], num_layers=3)
8990

90-
mini_batch_gnn, next_state = iterate(loader)
91+
# mini_batch_gnn, next_state = iterate(loader)
9192

92-
# Test if the mini-batch graph contains all available nodes
93-
@test size(mini_batch_gnn.x, 2) == 1
94-
end
93+
# # Test if the mini-batch graph contains all available nodes
94+
# @test size(mini_batch_gnn.x, 2) == 1
95+
# end
9596

96-
# 6. Edge case: Large batch size greater than the number of input nodes
97-
@testset "Large batch size" begin
98-
g = create_test_graph()
97+
# # 6. Edge case: Large batch size greater than the number of input nodes
98+
# @testset "Large batch size" begin
99+
# g = create_test_graph()
99100

100-
# Define NeighborLoader with a larger batch size than input nodes
101-
loader = NeighborLoader(g; num_neighbors=[2], input_nodes=[1, 2], num_layers=1, batch_size=10)
101+
# # Define NeighborLoader with a larger batch size than input nodes
102+
# loader = NeighborLoader(g; num_neighbors=[2], input_nodes=[1, 2], num_layers=1, batch_size=10)
102103

103-
mini_batch_gnn, next_state = iterate(loader)
104+
# mini_batch_gnn, next_state = iterate(loader)
104105

105-
# Test if the mini-batch graph is not empty
106-
@test !isempty(mini_batch_gnn.graph)
106+
# # Test if the mini-batch graph is not empty
107+
# @test !isempty(mini_batch_gnn.graph)
107108

108-
# Test if the correct number of nodes are sampled
109-
@test size(mini_batch_gnn.x, 2) == length(unique([1, 2])) # Nodes [1, 2] are expected
110-
end
109+
# # Test if the correct number of nodes are sampled
110+
# @test size(mini_batch_gnn.x, 2) == length(unique([1, 2])) # Nodes [1, 2] are expected
111+
# end
111112

112-
# 7. Edge case: No neighbors sampled (num_neighbors = [0]) and 1 layer
113-
@testset "No neighbors sampled" begin
114-
g = create_test_graph()
113+
# # 7. Edge case: No neighbors sampled (num_neighbors = [0]) and 1 layer
114+
# @testset "No neighbors sampled" begin
115+
# g = create_test_graph()
115116

116-
# Define NeighborLoader with 0 neighbors per layer, 1 layer, batch size 2
117-
loader = NeighborLoader(g; num_neighbors=[0], input_nodes=[1, 2], num_layers=1, batch_size=2)
117+
# # Define NeighborLoader with 0 neighbors per layer, 1 layer, batch size 2
118+
# loader = NeighborLoader(g; num_neighbors=[0], input_nodes=[1, 2], num_layers=1, batch_size=2)
118119

119-
mini_batch_gnn, next_state = iterate(loader)
120+
# mini_batch_gnn, next_state = iterate(loader)
120121

121-
# Test if the mini-batch graph contains only the input nodes
122-
@test size(mini_batch_gnn.x, 2) == 2 # No neighbors should be sampled, only nodes 1 and 2 should be in the graph
123-
end
122+
# # Test if the mini-batch graph contains only the input nodes
123+
# @test size(mini_batch_gnn.x, 2) == 2 # No neighbors should be sampled, only nodes 1 and 2 should be in the graph
124+
# end
124125

125-
end
126+
# end

GraphNeuralNetworks/Project.toml

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,24 +21,22 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
2121
[weakdeps]
2222
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
2323

24-
# [extensions]
25-
# GraphNeuralNetworksCUDAExt = "CUDA"
26-
2724
[compat]
2825
CUDA = "4, 5"
2926
ChainRulesCore = "1"
3027
Flux = "0.14"
3128
Functors = "0.4.1"
32-
Graphs = "1.12"
3329
GNNGraphs = "1.0"
3430
GNNlib = "0.2"
31+
Graphs = "1.12"
3532
LinearAlgebra = "1"
3633
MLUtils = "0.4"
3734
MacroTools = "0.5"
3835
NNlib = "0.9"
3936
Random = "1"
4037
Reexport = "1"
4138
Statistics = "1"
39+
TestItemRunner = "1.0.5"
4240
cuDNN = "1"
4341
julia = "1.10"
4442

@@ -53,8 +51,10 @@ InlineStrings = "842dd82b-1e85-43dc-bf29-5d0ee9dffc48"
5351
MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
5452
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
5553
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
54+
TestItemRunner = "f8b46487-2199-4994-9208-9a1283c18c0a"
5655
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
5756
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"
5857

5958
[targets]
60-
test = ["Test", "MLDatasets", "Adapt", "DataFrames", "InlineStrings", "SparseArrays", "Graphs", "Zygote", "FiniteDifferences", "ChainRulesTestUtils", "CUDA", "cuDNN"]
59+
test = ["Test", "TestItemRunner", "MLDatasets", "Adapt", "DataFrames", "InlineStrings",
60+
"SparseArrays", "Graphs", "Zygote", "FiniteDifferences", "ChainRulesTestUtils", "CUDA", "cuDNN"]

0 commit comments

Comments
 (0)