1- RTOL_LOW =  1e-2 
2- RTOL_HIGH =  1e-5 
3- ATOL_LOW =  1e-3 
1+ @testsnippet  TolSnippet begin 
2+     RTOL_LOW =  1e-2 
3+     RTOL_HIGH =  1e-5 
4+     ATOL_LOW =  1e-3 
5+ end 
46
5- @testset  " GCNConv" begin 
7+ @testitem  " GCNConv" = [TolSnippet, TestModule] begin 
8+     using  . TestModule
69    l =  GCNConv (D_IN =>  D_OUT)
710    for  g in  TEST_GRAPHS
811        test_layer (l, g, rtol =  RTOL_HIGH, outsize =  (D_OUT, g. num_nodes))
@@ -16,7 +19,7 @@ ATOL_LOW = 1e-3
1619    l =  GCNConv (D_IN =>  D_OUT, add_self_loops =  false )
1720    test_layer (l, TEST_GRAPHS[1 ], rtol =  RTOL_HIGH, outsize =  (D_OUT, TEST_GRAPHS[1 ]. num_nodes))
1821
19-     @testset  " edge weights & custom normalization"   begin 
22+     @testset  " edge weights & custom normalization  $GRAPH_T "   for  GRAPH_T  in  GRAPH_TYPES 
2023        s =  [2 , 3 , 1 , 3 , 1 , 2 ]
2124        t =  [1 , 1 , 2 , 2 , 3 , 3 ]
2225        w =  Float32[1 , 2 , 3 , 4 , 5 , 6 ]
@@ -41,7 +44,7 @@ ATOL_LOW = 1e-3
4144    end 
4245
4346    @testset  " conv_weight" begin 
44-           l =  GraphNeuralNetworks. GCNConv (D_IN =>  D_OUT)
47+         l =  GraphNeuralNetworks. GCNConv (D_IN =>  D_OUT)
4548        w =  zeros (Float32, D_OUT, D_IN)
4649        g1 =  GNNGraph (TEST_GRAPHS[1 ], ndata =  ones (Float32, D_IN, 4 ))
4750        @test  l (g1, g1. ndata. x, conv_weight =  w) ==  zeros (Float32, D_OUT, 4 )
@@ -51,16 +54,16 @@ ATOL_LOW = 1e-3
5154    end 
5255end 
5356
54- @testset  " ChebConv" begin 
57+ @testitem  " ChebConv" = [TolSnippet, TestModule] begin 
58+     using  . TestModule
5559    k =  2 
5660    l =  ChebConv (D_IN =>  D_OUT, k)
5761    @test  size (l. weight) ==  (D_OUT, D_IN, k)
5862    @test  size (l. bias) ==  (D_OUT,)
5963    @test  l. k ==  k
6064    for  g in  TEST_GRAPHS
6165        g =  add_self_loops (g)
62-         test_layer (l, g, rtol =  RTOL_HIGH, test_gpu =  TEST_GPU,
63-                     outsize =  (D_OUT, g. num_nodes))
66+         test_layer (l, g, rtol =  RTOL_HIGH, outsize =  (D_OUT, g. num_nodes))
6467    end 
6568
6669    @testset  " bias=false" begin 
6972    end 
7073end 
7174
72- @testset  " GraphConv" begin 
75+ @testitem  " GraphConv" = [TolSnippet, TestModule] begin 
76+     using  . TestModule
7377    l =  GraphConv (D_IN =>  D_OUT)
7478    for  g in  TEST_GRAPHS
7579        test_layer (l, g, rtol =  RTOL_HIGH, outsize =  (D_OUT, g. num_nodes))
8690    end 
8791end 
8892
89- @testset  " GATConv" begin 
93+ @testitem  " GATConv" = [TolSnippet, TestModule] begin 
94+     using  . TestModule
9095    for  heads in  (1 , 2 ), concat in  (true , false )
9196        l =  GATConv (D_IN =>  D_OUT; heads, concat, dropout= 0 )
9297        for  g in  TEST_GRAPHS
116121    end 
117122end 
118123
119- @testset  " GATv2Conv" begin 
124+ @testitem  " GATv2Conv" = [TolSnippet, TestModule] begin 
125+     using  . TestModule
120126    for  heads in  (1 , 2 ), concat in  (true , false )
121127        l =  GATv2Conv (D_IN =>  D_OUT, tanh; heads, concat, dropout= 0 )
122128        for  g in  TEST_GRAPHS
146152    end 
147153end 
148154
149- @testset  " GatedGraphConv" begin 
155+ @testitem  " GatedGraphConv" = [TolSnippet, TestModule] begin 
156+     using  . TestModule
150157    num_layers =  3 
151158    l =  GatedGraphConv (D_OUT, num_layers)
152159    @test  size (l. weight) ==  (D_OUT, D_OUT, num_layers)
@@ -156,14 +163,16 @@ end
156163    end 
157164end 
158165
159- @testset  " EdgeConv" begin 
166+ @testitem  " EdgeConv" = [TolSnippet, TestModule] begin 
167+     using  . TestModule
160168    l =  EdgeConv (Dense (2  *  D_IN, D_OUT), aggr =  + )
161169    for  g in  TEST_GRAPHS
162170        test_layer (l, g, rtol =  RTOL_HIGH, outsize =  (D_OUT, g. num_nodes))
163171    end 
164172end 
165173
166- @testset  " GINConv" begin 
174+ @testitem  " GINConv" = [TolSnippet, TestModule] begin 
175+     using  . TestModule
167176    nn =  Dense (D_IN, D_OUT)
168177
169178    l =  GINConv (nn, 0.01f0 , aggr =  mean)
174183    @test  ! in (:eps , Flux. trainable (l))
175184end 
176185
177- @testset  " NNConv" begin 
186+ @testitem  " NNConv" = [TolSnippet, TestModule] begin 
187+     using  . TestModule
178188    edim =  10 
179189    nn =  Dense (edim, D_OUT *  D_IN)
180190
185195    end 
186196end 
187197
188- @testset  " SAGEConv" begin 
198+ @testitem  " SAGEConv" = [TolSnippet, TestModule] begin 
199+     using  . TestModule
189200    l =  SAGEConv (D_IN =>  D_OUT)
190201    @test  l. aggr ==  mean
191202
@@ -195,14 +206,17 @@ end
195206    end 
196207end 
197208
198- @testset  " ResGatedGraphConv" begin 
209+ @testitem  " ResGatedGraphConv" = [TolSnippet, TestModule] begin 
210+     using  . TestModule
199211    l =  ResGatedGraphConv (D_IN =>  D_OUT, tanh, bias =  true )
200212    for  g in  TEST_GRAPHS
201213        test_layer (l, g, rtol =  RTOL_HIGH, outsize =  (D_OUT, g. num_nodes))
202214    end 
203215end 
204216
205- @testset  " CGConv" begin 
217+ @testitem  " CGConv" = [TolSnippet, TestModule] begin 
218+     using  . TestModule
219+ 
206220    edim =  10 
207221    l =  CGConv ((D_IN, edim) =>  D_OUT, tanh, residual =  false , bias =  true )
208222    for  g in  TEST_GRAPHS
217231    @test  l1 (g1, g1. ndata. x, nothing ) ==  l1 (g1). ndata. x
218232end 
219233
220- @testset  " AGNNConv" begin 
234+ @testitem  " AGNNConv" = [TolSnippet, TestModule] begin 
235+     using  . TestModule
221236    l =  AGNNConv (trainable= false , add_self_loops= false )
222237    @test  l. β ==  [1.0f0 ]
223238    @test  l. add_self_loops ==  false 
234249    end 
235250end 
236251
237- @testset  " MEGNetConv" begin 
252+ @testitem  " MEGNetConv" = [TolSnippet, TestModule] begin 
253+     using  . TestModule
238254    l =  MEGNetConv (D_IN =>  D_OUT, aggr =  + )
239255    for  g in  TEST_GRAPHS
240256        g =  GNNGraph (g, edata =  rand (Float32, D_IN, g. num_edges))
244260    end 
245261end 
246262
247- @testset  " GMMConv" begin 
263+ @testitem  " GMMConv" = [TolSnippet, TestModule] begin 
264+     using  . TestModule
248265    ein_channel =  10 
249266    K =  5 
250267    l =  GMMConv ((D_IN, ein_channel) =>  D_OUT, K =  K)
254271    end 
255272end 
256273
257- @testset  " SGConv" begin 
274+ @testitem  " SGConv" = [TolSnippet, TestModule] begin 
275+     using  . TestModule
258276    K =  [1 , 2 , 3 ] #  for different number of hops       
259277    for  k in  K
260278        l =  SGConv (D_IN =>  D_OUT, k, add_self_loops =  true )
269287    end 
270288end 
271289
272- @testset  " TAGConv" begin 
290+ @testitem  " TAGConv" = [TolSnippet, TestModule] begin 
291+     using  . TestModule
273292    K =  [1 , 2 , 3 ]
274293    for  k in  K
275294        l =  TAGConv (D_IN =>  D_OUT, k, add_self_loops =  true )
@@ -284,20 +303,25 @@ end
284303    end 
285304end 
286305
287- @testset  " EGNNConv" begin 
288-     hin =  5 
289-     hout =  5 
290-     hidden =  5 
291-     l =  EGNNConv (hin =>  hout, hidden)
292-     g =  rand_graph (10 , 20 , graph_type =  GRAPH_T)
293-     x =  rand (Float32, D_IN, g. num_nodes)
294-     h =  randn (Float32, hin, g. num_nodes)
295-     hnew, xnew =  l (g, h, x)
296-     @test  size (hnew) ==  (hout, g. num_nodes)
297-     @test  size (xnew) ==  (D_IN, g. num_nodes)
306+ @testitem  " EGNNConv" = [TolSnippet, TestModule] begin 
307+     using  . TestModule
308+     # TODO  test gradient
309+     @testset  " EGNNConv $GRAPH_T " for  GRAPH_T in  GRAPH_TYPES
310+         hin =  5 
311+         hout =  5 
312+         hidden =  5 
313+         l =  EGNNConv (hin =>  hout, hidden)
314+         g =  rand_graph (10 , 20 , graph_type =  GRAPH_T)
315+         x =  rand (Float32, D_IN, g. num_nodes)
316+         h =  randn (Float32, hin, g. num_nodes)
317+         hnew, xnew =  l (g, h, x)
318+         @test  size (hnew) ==  (hout, g. num_nodes)
319+         @test  size (xnew) ==  (D_IN, g. num_nodes)
320+     end 
298321end 
299322
300- @testset  " TransformerConv" begin 
323+ @testitem  " TransformerConv" = [TolSnippet, TestModule] begin 
324+     using  . TestModule
301325    ein =  2 
302326    heads =  3 
303327    #  used like in Kool et al., 2019
306330                        batch_norm =  false )
307331    #  batch_norm=false here for tests to pass; true in paper
308332    for  g in  TEST_GRAPHS
309-         g =  GNNGraph (g, ndata =  rand (Float32, D_IN *  heads, g. num_nodes), graph_type  =  GRAPH_T )
333+         g =  GNNGraph (g, ndata =  rand (Float32, D_IN *  heads, g. num_nodes))
310334        test_layer (l, g, rtol =  RTOL_LOW,
311335                    exclude_grad_fields =  [:negative_slope ],
312336                    outsize =  (D_IN *  heads, g. num_nodes))
331355    end 
332356end 
333357
334- @testset  " DConv" begin 
358+ @testitem  " DConv" = [TolSnippet, TestModule] begin 
359+     using  . TestModule
335360    K =  [1 , 2 , 3 ] #  for different number of hops       
336361    for  k in  K
337362        l =  DConv (D_IN =>  D_OUT, k)
0 commit comments