function ngradient(f, x...)
    fdm = central_fdm(5, 1)
    return FiniteDifferences.grad(fdm, f, x...)
end
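
# Illustrative self-check (an assumption, not part of the original test suite):
# for f(x) = sum(abs2, x) the finite-difference gradient returned by `ngradient`
# should match the analytic gradient 2x up to numerical error.
let x = [1.0, 2.0, 3.0]
    @assert ngradient(x -> sum(abs2, x), x)[1] ≈ 2 .* x
end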

const rule_config = Zygote.ZygoteRuleConfig()

# Using this until https://github.com/JuliaDiff/FiniteDifferences.jl/issues/188 is fixed
function FiniteDifferences.to_vec(x::Integer)
    Integer_from_vec(v) = x
    return Int[x], Integer_from_vec
end

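# Illustrative round trip for the `to_vec` workaround above (an assumption, not
# part of the original file): the returned closure reconstructs the original integer.
let vb = FiniteDifferences.to_vec(3)
    v, back = vb
    @assert v == [3] && back(v) == 3
end
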
# Test that the forward pass gives the same results on CPU and GPU.
# Also test the gradient on CPU and GPU, comparing against
# finite-difference methods.
# Gradients are tested with respect to the layer weights and to the input.
# If `g` has edge features, it is assumed that the layer can
# use them in the forward pass as `l(g, x, e)`.
# In that case the gradient with respect to `e` is tested as well.
function test_layer(l, g::GNNGraph; atol = 1e-5, rtol = 1e-5,
                    exclude_grad_fields = [],
                    verbose = false,
                    test_gpu = TEST_GPU,
                    outsize = nothing,
                    outtype = :node)

    # TODO these give errors, probably some bugs in ChainRulesTestUtils
    # test_rrule(rule_config, x -> l(g, x), x; rrule_f=rrule_via_ad, check_inferred=false)
    # test_rrule(rule_config, l -> l(g, x), l; rrule_f=rrule_via_ad, check_inferred=false)

    isnothing(node_features(g)) && error("Please add node data to the input graph")
    fdm = central_fdm(5, 1)

    x = node_features(g)
    e = edge_features(g)
    use_edge_feat = !isnothing(e)

    x64, e64, l64, g64 = to64.([x, e, l, g]) # needed for accurate FiniteDifferences' grad
    xgpu, egpu, lgpu, ggpu = gpu.([x, e, l, g])

    f(l, g::GNNGraph) = l(g)
    f(l, g::GNNGraph, x, e) = use_edge_feat ? l(g, x, e) : l(g, x)

    loss(l, g::GNNGraph) =
        if outtype == :node
            sum(node_features(f(l, g)))
        elseif outtype == :edge
            sum(edge_features(f(l, g)))
        elseif outtype == :graph
            sum(graph_features(f(l, g)))
        elseif outtype == :node_edge
            gnew = f(l, g)
            sum(node_features(gnew)) + sum(edge_features(gnew))
        end

    function loss(l, g::GNNGraph, x, e)
        y = f(l, g, x, e)
        if outtype == :node_edge
            return sum(y[1]) + sum(y[2])
        else
            return sum(y)
        end
    end

    # TEST OUTPUT
    y = f(l, g, x, e)
    if outtype == :node_edge
        @assert y isa Tuple
        @test eltype(y[1]) == eltype(x)
        @test eltype(y[2]) == eltype(e)
        @test all(isfinite, y[1])
        @test all(isfinite, y[2])
        if !isnothing(outsize)
            @test size(y[1]) == outsize[1]
            @test size(y[2]) == outsize[2]
        end
    else
        @test eltype(y) == eltype(x)
        @test all(isfinite, y)
        if !isnothing(outsize)
            @test size(y) == outsize
        end
    end

    # test same output on different graph formats
    gcoo = GNNGraph(g, graph_type = :coo)
    ycoo = f(l, gcoo, x, e)
    if outtype == :node_edge
        @test ycoo[1] ≈ y[1]
        @test ycoo[2] ≈ y[2]
    else
        @test ycoo ≈ y
    end

    g′ = f(l, g)
    if outtype == :node
        @test g′.ndata.x ≈ y
    elseif outtype == :edge
        @test g′.edata.e ≈ y
    elseif outtype == :graph
        @test g′.gdata.u ≈ y
    elseif outtype == :node_edge
        @test g′.ndata.x ≈ y[1]
        @test g′.edata.e ≈ y[2]
    else
        @error "wrong outtype $outtype"
    end
    if test_gpu
        ygpu = f(lgpu, ggpu, xgpu, egpu)
        if outtype == :node_edge
            @test ygpu[1] isa CuArray
            @test eltype(ygpu[1]) == eltype(xgpu)
            @test Array(ygpu[1]) ≈ y[1]
            @test ygpu[2] isa CuArray
            @test eltype(ygpu[2]) == eltype(xgpu)
            @test Array(ygpu[2]) ≈ y[2]
        else
            @test ygpu isa CuArray
            @test eltype(ygpu) == eltype(xgpu)
            @test Array(ygpu) ≈ y
        end
    end

    # TEST x INPUT GRADIENT
    x̄ = gradient(x -> loss(l, g, x, e), x)[1]
    x̄_fd = FiniteDifferences.grad(fdm, x64 -> loss(l64, g64, x64, e64), x64)[1]
    @test eltype(x̄) == eltype(x)
    @test x̄ ≈ x̄_fd atol=atol rtol=rtol

    if test_gpu
        x̄gpu = gradient(xgpu -> loss(lgpu, ggpu, xgpu, egpu), xgpu)[1]
        @test x̄gpu isa CuArray
        @test eltype(x̄gpu) == eltype(x)
        @test Array(x̄gpu) ≈ x̄ atol=atol rtol=rtol
    end

    # TEST e INPUT GRADIENT
    if e !== nothing
        verbose && println("Test e gradient cpu")
        ē = gradient(e -> loss(l, g, x, e), e)[1]
        ē_fd = FiniteDifferences.grad(fdm, e64 -> loss(l64, g64, x64, e64), e64)[1]
        @test eltype(ē) == eltype(e)
        @test ē ≈ ē_fd atol=atol rtol=rtol

        if test_gpu
            verbose && println("Test e gradient gpu")
            ēgpu = gradient(egpu -> loss(lgpu, ggpu, xgpu, egpu), egpu)[1]
            @test ēgpu isa CuArray
            @test eltype(ēgpu) == eltype(ē)
            @test Array(ēgpu) ≈ ē atol=atol rtol=rtol
        end
    end

    # TEST LAYER GRADIENT - l(g, x, e)
    l̄ = gradient(l -> loss(l, g, x, e), l)[1]
    l̄_fd = FiniteDifferences.grad(fdm, l64 -> loss(l64, g64, x64, e64), l64)[1]
    test_approx_structs(l, l̄, l̄_fd; atol, rtol, exclude_grad_fields, verbose)

    if test_gpu
        l̄gpu = gradient(lgpu -> loss(lgpu, ggpu, xgpu, egpu), lgpu)[1]
        test_approx_structs(lgpu, l̄gpu, l̄; atol, rtol, exclude_grad_fields, verbose)
    end

    # TEST LAYER GRADIENT - l(g)
    l̄ = gradient(l -> loss(l, g), l)[1]
    test_approx_structs(l, l̄, l̄_fd; atol, rtol, exclude_grad_fields, verbose)

    return true
end
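
# Hypothetical usage sketch (the layer, graph, and sizes below are assumptions,
# not taken from this file): test a GraphConv layer on a random graph with
# 3-dimensional Float32 node features, skipping the GPU checks.
#
#     g = rand_graph(10, 30, ndata = rand(Float32, 3, 10))
#     l = GraphConv(3 => 5)
#     test_layer(l, g; outsize = (5, 10), outtype = :node, test_gpu = false)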

function test_approx_structs(l, l̄, l̄fd; atol = 1e-5, rtol = 1e-5,
                             exclude_grad_fields = [],
                             verbose = false)
    # Zygote wraps gradients of mutable structs in RefValue
    l̄ = l̄ isa Base.RefValue ? l̄[] : l̄
    l̄fd = l̄fd isa Base.RefValue ? l̄fd[] : l̄fd

    for f in fieldnames(typeof(l))
        f ∈ exclude_grad_fields && continue
        verbose && println("Test gradient of field $f...")
        x, g, gfd = getfield(l, f), getfield(l̄, f), getfield(l̄fd, f)
        test_approx_structs(x, g, gfd; atol, rtol, exclude_grad_fields, verbose)
        verbose && println("... field $f done!")
    end
    return true
end

function test_approx_structs(x, g::Nothing, gfd; atol, rtol, kws...)
    # if the AD gradient is `nothing`, the finite-difference gradient must be zero (when present)
    @test !(gfd isa AbstractArray) || isapprox(gfd, fill!(similar(gfd), 0); atol, rtol)
end

function test_approx_structs(x::Union{AbstractArray, Number},
                             g::Union{AbstractArray, Number}, gfd; atol, rtol, kws...)
    @test eltype(g) == eltype(x)
    if x isa CuArray
        @test g isa CuArray
        g = Array(g)
    end
    @test g ≈ gfd atol=atol rtol=rtol
end

"""
    to32(m)

Convert the `eltype` of a model's float parameters to `Float32`.
Integer arrays are preserved.
"""
to32(m) = _paramtype(Float32, m)

"""
    to64(m)

Convert the `eltype` of a model's float parameters to `Float64`.
Integer arrays are preserved.
"""
to64(m) = _paramtype(Float64, m)

struct GNNEltypeAdaptor{T} end

Adapt.adapt_storage(::GNNEltypeAdaptor{T}, x::AbstractArray{<:AbstractFloat}) where T = convert(AbstractArray{T}, x)
Adapt.adapt_storage(::GNNEltypeAdaptor{T}, x::AbstractArray{<:Integer}) where T = x
Adapt.adapt_storage(::GNNEltypeAdaptor{T}, x::AbstractArray{<:Number}) where T = convert(AbstractArray{T}, x)

_paramtype(::Type{T}, m) where T = fmap(adapt(GNNEltypeAdaptor{T}()), m)
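
# Minimal sketch of the eltype conversion (hypothetical model, not from this file):
# float arrays are converted to the requested precision, integer arrays are left untouched.
let m = (w = rand(Float32, 2, 2), idx = [1, 2, 3])
    m64 = to64(m)
    @assert eltype(m64.w) == Float64   # float parameters converted
    @assert eltype(m64.idx) == Int     # integer array preserved
end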