@@ -5,224 +5,3 @@ function ngradient(f, x...)
55 fdm = central_fdm (5 , 1 )
66 return FiniteDifferences. grad (fdm, f, x... )
77end
8-
9- const rule_config = Zygote. ZygoteRuleConfig ()
10-
11- # Using this until https://github.com/JuliaDiff/FiniteDifferences.jl/issues/188 is fixed
12- function FiniteDifferences. to_vec (x:: Integer )
13- Integer_from_vec (v) = x
14- return Int[x], Integer_from_vec
15- end
16-
17- # Test that forward pass on cpu and gpu are the same.
18- # Tests also gradient on cpu and gpu comparing with
19- # finite difference methods.
20- # Test gradients with respects to layer weights and to input.
21- # If `g` has edge features, it is assumed that the layer can
22- # use them in the forward pass as `l(g, x, e)`.
23- # Test also gradient with respect to `e`.
24- function test_layer (l, g:: GNNGraph ; atol = 1e-5 , rtol = 1e-5 ,
25- exclude_grad_fields = [],
26- verbose = false ,
27- test_gpu = TEST_GPU,
28- outsize = nothing ,
29- outtype = :node )
30-
31- # TODO these give errors, probably some bugs in ChainRulesTestUtils
32- # test_rrule(rule_config, x -> l(g, x), x; rrule_f=rrule_via_ad, check_inferred=false)
33- # test_rrule(rule_config, l -> l(g, x), l; rrule_f=rrule_via_ad, check_inferred=false)
34-
35- isnothing (node_features (g)) && error (" Plese add node data to the input graph" )
36- fdm = central_fdm (5 , 1 )
37-
38- x = node_features (g)
39- e = edge_features (g)
40- use_edge_feat = ! isnothing (e)
41-
42- x64, e64, l64, g64 = to64 .([x, e, l, g]) # needed for accurate FiniteDifferences' grad
43- xgpu, egpu, lgpu, ggpu = gpu .([x, e, l, g])
44-
45- f (l, g:: GNNGraph ) = l (g)
46- f (l, g:: GNNGraph , x, e) = use_edge_feat ? l (g, x, e) : l (g, x)
47-
48- loss (l, g:: GNNGraph ) =
49- if outtype == :node
50- sum (node_features (f (l, g)))
51- elseif outtype == :edge
52- sum (edge_features (f (l, g)))
53- elseif outtype == :graph
54- sum (graph_features (f (l, g)))
55- elseif outtype == :node_edge
56- gnew = f (l, g)
57- sum (node_features (gnew)) + sum (edge_features (gnew))
58- end
59-
60- function loss (l, g:: GNNGraph , x, e)
61- y = f (l, g, x, e)
62- if outtype == :node_edge
63- return sum (y[1 ]) + sum (y[2 ])
64- else
65- return sum (y)
66- end
67- end
68-
69- # TEST OUTPUT
70- y = f (l, g, x, e)
71- if outtype == :node_edge
72- @assert y isa Tuple
73- @test eltype (y[1 ]) == eltype (x)
74- @test eltype (y[2 ]) == eltype (e)
75- @test all (isfinite, y[1 ])
76- @test all (isfinite, y[2 ])
77- if ! isnothing (outsize)
78- @test size (y[1 ]) == outsize[1 ]
79- @test size (y[2 ]) == outsize[2 ]
80- end
81- else
82- @test eltype (y) == eltype (x)
83- @test all (isfinite, y)
84- if ! isnothing (outsize)
85- @test size (y) == outsize
86- end
87- end
88-
89- # test same output on different graph formats
90- gcoo = GNNGraph (g, graph_type = :coo )
91- ycoo = f (l, gcoo, x, e)
92- if outtype == :node_edge
93- @test ycoo[1 ] ≈ y[1 ]
94- @test ycoo[2 ] ≈ y[2 ]
95- else
96- @test ycoo ≈ y
97- end
98-
99- g′ = f (l, g)
100- if outtype == :node
101- @test g′. ndata. x ≈ y
102- elseif outtype == :edge
103- @test g′. edata. e ≈ y
104- elseif outtype == :graph
105- @test g′. gdata. u ≈ y
106- elseif outtype == :node_edge
107- @test g′. ndata. x ≈ y[1 ]
108- @test g′. edata. e ≈ y[2 ]
109- else
110- @error " wrong outtype $outtype "
111- end
112- if test_gpu
113- ygpu = f (lgpu, ggpu, xgpu, egpu)
114- if outtype == :node_edge
115- @test ygpu[1 ] isa CuArray
116- @test eltype (ygpu[1 ]) == eltype (xgpu)
117- @test Array (ygpu[1 ]) ≈ y[1 ]
118- @test ygpu[2 ] isa CuArray
119- @test eltype (ygpu[2 ]) == eltype (xgpu)
120- @test Array (ygpu[2 ]) ≈ y[2 ]
121- else
122- @test ygpu isa CuArray
123- @test eltype (ygpu) == eltype (xgpu)
124- @test Array (ygpu) ≈ y
125- end
126- end
127-
128- # TEST x INPUT GRADIENT
129- x̄ = gradient (x -> loss (l, g, x, e), x)[1 ]
130- x̄_fd = FiniteDifferences. grad (fdm, x64 -> loss (l64, g64, x64, e64), x64)[1 ]
131- @test eltype (x̄) == eltype (x)
132- @test x̄≈ x̄_fd atol= atol rtol= rtol
133-
134- if test_gpu
135- x̄gpu = gradient (xgpu -> loss (lgpu, ggpu, xgpu, egpu), xgpu)[1 ]
136- @test x̄gpu isa CuArray
137- @test eltype (x̄gpu) == eltype (x)
138- @test Array (x̄gpu)≈ x̄ atol= atol rtol= rtol
139- end
140-
141- # TEST e INPUT GRADIENT
142- if e != = nothing
143- verbose && println (" Test e gradient cpu" )
144- ē = gradient (e -> loss (l, g, x, e), e)[1 ]
145- ē_fd = FiniteDifferences. grad (fdm, e64 -> loss (l64, g64, x64, e64), e64)[1 ]
146- @test eltype (ē) == eltype (e)
147- @test ē≈ ē_fd atol= atol rtol= rtol
148-
149- if test_gpu
150- verbose && println (" Test e gradient gpu" )
151- ēgpu = gradient (egpu -> loss (lgpu, ggpu, xgpu, egpu), egpu)[1 ]
152- @test ēgpu isa CuArray
153- @test eltype (ēgpu) == eltype (ē)
154- @test Array (ēgpu)≈ ē atol= atol rtol= rtol
155- end
156- end
157-
158- # TEST LAYER GRADIENT - l(g, x, e)
159- l̄ = gradient (l -> loss (l, g, x, e), l)[1 ]
160- l̄_fd = FiniteDifferences. grad (fdm, l64 -> loss (l64, g64, x64, e64), l64)[1 ]
161- test_approx_structs (l, l̄, l̄_fd; atol, rtol, exclude_grad_fields, verbose)
162-
163- if test_gpu
164- l̄gpu = gradient (lgpu -> loss (lgpu, ggpu, xgpu, egpu), lgpu)[1 ]
165- test_approx_structs (lgpu, l̄gpu, l̄; atol, rtol, exclude_grad_fields, verbose)
166- end
167-
168- # TEST LAYER GRADIENT - l(g)
169- l̄ = gradient (l -> loss (l, g), l)[1 ]
170- test_approx_structs (l, l̄, l̄_fd; atol, rtol, exclude_grad_fields, verbose)
171-
172- return true
173- end
174-
175- function test_approx_structs (l, l̄, l̄fd; atol = 1e-5 , rtol = 1e-5 ,
176- exclude_grad_fields = [],
177- verbose = false )
178- l̄ = l̄ isa Base. RefValue ? l̄[] : l̄ # Zygote wraps gradient of mutables in RefValue
179- l̄fd = l̄fd isa Base. RefValue ? l̄fd[] : l̄fd # Zygote wraps gradient of mutables in RefValue
180-
181- for f in fieldnames (typeof (l))
182- f ∈ exclude_grad_fields && continue
183- verbose && println (" Test gradient of field $f ..." )
184- x, g, gfd = getfield (l, f), getfield (l̄, f), getfield (l̄fd, f)
185- test_approx_structs (x, g, gfd; atol, rtol, exclude_grad_fields, verbose)
186- verbose && println (" ... field $f done!" )
187- end
188- return true
189- end
190-
191- function test_approx_structs (x, g:: Nothing , gfd; atol, rtol, kws... )
192- # finite diff gradients has to be zero if present
193- @test ! (gfd isa AbstractArray) || isapprox (gfd, fill! (similar (gfd), 0 ); atol, rtol)
194- end
195-
196- function test_approx_structs (x:: Union{AbstractArray, Number} ,
197- g:: Union{AbstractArray, Number} , gfd; atol, rtol, kws... )
198- @test eltype (g) == eltype (x)
199- if x isa CuArray
200- @test g isa CuArray
201- g = Array (g)
202- end
203- @test g≈ gfd atol= atol rtol= rtol
204- end
205-
206- """
207- to32(m)
208-
209- Convert the `eltype` of model's float parameters to `Float32`.
210- Preserves integer arrays.
211- """
212- to32 (m) = _paramtype (Float32, m)
213-
214- """
215- to64(m)
216-
217- Convert the `eltype` of model's float parameters to `Float64`.
218- Preserves integer arrays.
219- """
220- to64 (m) = _paramtype (Float64, m)
221-
222- struct GNNEltypeAdaptor{T} end
223-
224- Adapt. adapt_storage (:: GNNEltypeAdaptor{T} , x:: AbstractArray{<:AbstractFloat} ) where T = convert (AbstractArray{T}, x)
225- Adapt. adapt_storage (:: GNNEltypeAdaptor{T} , x:: AbstractArray{<:Integer} ) where T = x
226- Adapt. adapt_storage (:: GNNEltypeAdaptor{T} , x:: AbstractArray{<:Number} ) where T = convert (AbstractArray{T}, x)
227-
228- _paramtype (:: Type{T} , m) where T = fmap (adapt (GNNEltypeAdaptor {T} ()), m)
0 commit comments