11@testmodule  TestModuleGNNlib begin 
22
3+ using  Pkg
4+ 
5+ # ## GPU backends settings ############
6+ #  tried to put this in __init__ but is not executed for some reason
7+ 
8+ # # Uncomment below to change the default test settings
9+ ENV [" GNN_TEST_CUDA" =  " true" 
10+ #  ENV["GNN_TEST_AMDGPU"] = "true"
11+ #  ENV["GNN_TEST_Metal"] = "true"
12+ 
13+ to_test (backend) =  get (ENV , " GNN_TEST_$(backend) " " false" ==  " true" 
14+ has_dependecies (pkgs) =  all (pkg ->  haskey (Pkg. project (). dependencies, pkg), pkgs)
15+ deps_dict =  Dict (:CUDA  =>  [" CUDA" " cuDNN" :AMDGPU  =>  [" AMDGPU" :Metal  =>  [" Metal" 
16+ 
17+ for  (backend, deps) in  deps_dict
18+     if  to_test (backend)
19+         if  ! has_dependecies (deps)
20+             Pkg. add (deps)
21+         end 
22+         @eval  using  $ backend
23+         @eval  $ backend. allowscalar (false )
24+     end 
25+ end 
26+ # #####################################
27+ 
328import  Reexport:  @reexport 
429
530@reexport  using  GNNlib
@@ -8,5 +33,144 @@ import Reexport: @reexport
833@reexport  using  MLUtils
934@reexport  using  SparseArrays
1035@reexport  using  Test, Random, Statistics
36+ @reexport  using  MLDataDevices
37+ using  Functors:  fmapstructure_with_path
38+ using  Graphs
39+ using  ChainRulesTestUtils, FiniteDifferences
40+ using  Zygote
41+ using  SparseArrays
42+ 
43+ #  from this module
44+ export  D_IN, D_OUT, GRAPH_TYPES, TEST_GRAPHS,
45+        test_gradients, finitediff_withgradient, 
46+        check_equal_leaves
47+ 
48+ 
49+ const  D_IN =  3 
50+ const  D_OUT =  5 
51+ 
52+ function  finitediff_withgradient (f, x... )
53+     y =  f (x... )
54+     #  We set a range to avoid domain errors
55+     fdm =  FiniteDifferences. central_fdm (5 , 1 , max_range= 1e-2 )
56+     return  y, FiniteDifferences. grad (fdm, f, x... )
57+ end 
58+ 
59+ function  check_equal_leaves (a, b; rtol= 1e-4 , atol= 1e-4 )
60+     fmapstructure_with_path (a, b) do  kp, x, y
61+         if  x isa  AbstractArray
62+             #  @show kp
63+             @test  x ≈  y rtol= rtol atol= atol
64+         #  elseif x isa Number
65+         #      @show kp
66+         #      @test x ≈ y rtol=rtol atol=atol
67+         end 
68+     end 
69+ end 
70+ 
71+ function  test_gradients (
72+             f,
73+             graph:: GNNGraph , 
74+             xs... ;
75+             rtol= 1e-5 , atol= 1e-5 ,
76+             test_gpu =  false ,
77+             test_grad_f =  true ,
78+             test_grad_x =  true ,
79+             compare_finite_diff =  true ,
80+             loss =  (f, g, xs... ) ->  mean (f (g, xs... )),
81+             )
82+ 
83+     if  ! test_gpu &&  ! compare_finite_diff
84+         error (" You should either compare finite diff vs CPU AD \
85+                or CPU AD vs GPU AD."  )
86+     end 
87+ 
88+     # # Let's make sure first that the forward pass works.
89+     l =  loss (f, graph, xs... )
90+     @test  l isa  Number
91+     if  test_gpu
92+         gpu_dev =  gpu_device (force= true )
93+         cpu_dev =  cpu_device ()
94+         graph_gpu =  graph |>  gpu_dev
95+         xs_gpu =  xs |>  gpu_dev
96+         f_gpu =  f |>  gpu_dev
97+         l_gpu =  loss (f_gpu, graph_gpu, xs_gpu... )
98+         @test  l_gpu isa  Number
99+     end 
100+ 
101+     if  test_grad_x
102+         #  Zygote gradient with respect to input.
103+         y, g =  Zygote. withgradient ((xs... ) ->  loss (f, graph, xs... ), xs... )
104+         
105+         if  compare_finite_diff
106+             #  Cast to Float64 to avoid precision issues.
107+             f64 =  f |>  Flux. f64
108+             xs64 =  xs .| >  Flux. f64
109+             y_fd, g_fd =  finitediff_withgradient ((xs... ) ->  loss (f64, graph, xs... ), xs64... )
110+             @test  y ≈  y_fd rtol= rtol atol= atol
111+             check_equal_leaves (g, g_fd; rtol, atol)
112+         end 
113+ 
114+         if  test_gpu
115+             #  Zygote gradient with respect to input on GPU.
116+             y_gpu, g_gpu =  Zygote. withgradient ((xs... ) ->  loss (f_gpu, graph_gpu, xs... ), xs_gpu... )
117+             @test  get_device (g_gpu) ==  get_device (xs_gpu)
118+             @test  y_gpu ≈  y rtol= rtol atol= atol
119+             check_equal_leaves (g_gpu |>  cpu_dev, g; rtol, atol)
120+         end 
121+     end 
122+ 
123+     if  test_grad_f
124+         #  Zygote gradient with respect to f.
125+         y, g =  Zygote. withgradient (f ->  loss (f, graph, xs... ), f)
126+ 
127+         if  compare_finite_diff
128+             #  Cast to Float64 to avoid precision issues.
129+             f64 =  f |>  Flux. f64
130+             ps, re =  Flux. destructure (f64)
131+             y_fd, g_fd =  finitediff_withgradient (ps ->  loss (re (ps),graph, xs... ), ps)
132+             g_fd =  (re (g_fd[1 ]),)
133+             @test  y ≈  y_fd rtol= rtol atol= atol
134+             check_equal_leaves (g, g_fd; rtol, atol)
135+         end 
136+ 
137+         if  test_gpu
138+             #  Zygote gradient with respect to f on GPU.
139+             y_gpu, g_gpu =  Zygote. withgradient (f ->  loss (f,graph_gpu, xs_gpu... ), f_gpu)
140+             #  @test get_device(g_gpu) == get_device(xs_gpu)
141+             @test  y_gpu ≈  y rtol= rtol atol= atol
142+             check_equal_leaves (g_gpu |>  cpu_dev, g; rtol, atol)
143+         end 
144+     end 
145+     return  true 
146+ end 
147+ 
148+ 
149+ function  generate_test_graphs (graph_type)
150+     adj1 =  [0  1  0  1 
151+             1  0  1  0 
152+             0  1  0  1 
153+             1  0  1  0 ]
154+ 
155+     g1 =  GNNGraph (adj1,
156+                     ndata =  rand (Float32, D_IN, 4 );
157+                     graph_type)
158+ 
159+     adj_single_vertex =  [0  0  0  1 
160+                             0  0  0  0 
161+                             0  0  0  1 
162+                             1  0  1  0 ]
163+ 
164+     g_single_vertex =  GNNGraph (adj_single_vertex,
165+                                 ndata =  rand (Float32, D_IN, 4 );
166+                                 graph_type)
167+ 
168+     return  (g1, g_single_vertex)
169+ end 
170+ 
171+ GRAPH_TYPES =  [:coo , :dense , :sparse ]
172+ TEST_GRAPHS =  [generate_test_graphs (:coo )... ,
173+                generate_test_graphs (:dense )... ,
174+                generate_test_graphs (:sparse )... ]
11175
12176end  #  module
0 commit comments