@@ -5,224 +5,3 @@ function ngradient(f, x...)
5
5
fdm = central_fdm (5 , 1 )
6
6
return FiniteDifferences. grad (fdm, f, x... )
7
7
end
8
-
9
- const rule_config = Zygote. ZygoteRuleConfig ()
10
-
11
- # Using this until https://github.com/JuliaDiff/FiniteDifferences.jl/issues/188 is fixed
12
- function FiniteDifferences. to_vec (x:: Integer )
13
- Integer_from_vec (v) = x
14
- return Int[x], Integer_from_vec
15
- end
16
-
17
- # Test that forward pass on cpu and gpu are the same.
18
- # Tests also gradient on cpu and gpu comparing with
19
- # finite difference methods.
20
- # Test gradients with respects to layer weights and to input.
21
- # If `g` has edge features, it is assumed that the layer can
22
- # use them in the forward pass as `l(g, x, e)`.
23
- # Test also gradient with respect to `e`.
24
- function test_layer (l, g:: GNNGraph ; atol = 1e-5 , rtol = 1e-5 ,
25
- exclude_grad_fields = [],
26
- verbose = false ,
27
- test_gpu = TEST_GPU,
28
- outsize = nothing ,
29
- outtype = :node )
30
-
31
- # TODO these give errors, probably some bugs in ChainRulesTestUtils
32
- # test_rrule(rule_config, x -> l(g, x), x; rrule_f=rrule_via_ad, check_inferred=false)
33
- # test_rrule(rule_config, l -> l(g, x), l; rrule_f=rrule_via_ad, check_inferred=false)
34
-
35
- isnothing (node_features (g)) && error (" Plese add node data to the input graph" )
36
- fdm = central_fdm (5 , 1 )
37
-
38
- x = node_features (g)
39
- e = edge_features (g)
40
- use_edge_feat = ! isnothing (e)
41
-
42
- x64, e64, l64, g64 = to64 .([x, e, l, g]) # needed for accurate FiniteDifferences' grad
43
- xgpu, egpu, lgpu, ggpu = gpu .([x, e, l, g])
44
-
45
- f (l, g:: GNNGraph ) = l (g)
46
- f (l, g:: GNNGraph , x, e) = use_edge_feat ? l (g, x, e) : l (g, x)
47
-
48
- loss (l, g:: GNNGraph ) =
49
- if outtype == :node
50
- sum (node_features (f (l, g)))
51
- elseif outtype == :edge
52
- sum (edge_features (f (l, g)))
53
- elseif outtype == :graph
54
- sum (graph_features (f (l, g)))
55
- elseif outtype == :node_edge
56
- gnew = f (l, g)
57
- sum (node_features (gnew)) + sum (edge_features (gnew))
58
- end
59
-
60
- function loss (l, g:: GNNGraph , x, e)
61
- y = f (l, g, x, e)
62
- if outtype == :node_edge
63
- return sum (y[1 ]) + sum (y[2 ])
64
- else
65
- return sum (y)
66
- end
67
- end
68
-
69
- # TEST OUTPUT
70
- y = f (l, g, x, e)
71
- if outtype == :node_edge
72
- @assert y isa Tuple
73
- @test eltype (y[1 ]) == eltype (x)
74
- @test eltype (y[2 ]) == eltype (e)
75
- @test all (isfinite, y[1 ])
76
- @test all (isfinite, y[2 ])
77
- if ! isnothing (outsize)
78
- @test size (y[1 ]) == outsize[1 ]
79
- @test size (y[2 ]) == outsize[2 ]
80
- end
81
- else
82
- @test eltype (y) == eltype (x)
83
- @test all (isfinite, y)
84
- if ! isnothing (outsize)
85
- @test size (y) == outsize
86
- end
87
- end
88
-
89
- # test same output on different graph formats
90
- gcoo = GNNGraph (g, graph_type = :coo )
91
- ycoo = f (l, gcoo, x, e)
92
- if outtype == :node_edge
93
- @test ycoo[1 ] ≈ y[1 ]
94
- @test ycoo[2 ] ≈ y[2 ]
95
- else
96
- @test ycoo ≈ y
97
- end
98
-
99
- g′ = f (l, g)
100
- if outtype == :node
101
- @test g′. ndata. x ≈ y
102
- elseif outtype == :edge
103
- @test g′. edata. e ≈ y
104
- elseif outtype == :graph
105
- @test g′. gdata. u ≈ y
106
- elseif outtype == :node_edge
107
- @test g′. ndata. x ≈ y[1 ]
108
- @test g′. edata. e ≈ y[2 ]
109
- else
110
- @error " wrong outtype $outtype "
111
- end
112
- if test_gpu
113
- ygpu = f (lgpu, ggpu, xgpu, egpu)
114
- if outtype == :node_edge
115
- @test ygpu[1 ] isa CuArray
116
- @test eltype (ygpu[1 ]) == eltype (xgpu)
117
- @test Array (ygpu[1 ]) ≈ y[1 ]
118
- @test ygpu[2 ] isa CuArray
119
- @test eltype (ygpu[2 ]) == eltype (xgpu)
120
- @test Array (ygpu[2 ]) ≈ y[2 ]
121
- else
122
- @test ygpu isa CuArray
123
- @test eltype (ygpu) == eltype (xgpu)
124
- @test Array (ygpu) ≈ y
125
- end
126
- end
127
-
128
- # TEST x INPUT GRADIENT
129
- x̄ = gradient (x -> loss (l, g, x, e), x)[1 ]
130
- x̄_fd = FiniteDifferences. grad (fdm, x64 -> loss (l64, g64, x64, e64), x64)[1 ]
131
- @test eltype (x̄) == eltype (x)
132
- @test x̄≈ x̄_fd atol= atol rtol= rtol
133
-
134
- if test_gpu
135
- x̄gpu = gradient (xgpu -> loss (lgpu, ggpu, xgpu, egpu), xgpu)[1 ]
136
- @test x̄gpu isa CuArray
137
- @test eltype (x̄gpu) == eltype (x)
138
- @test Array (x̄gpu)≈ x̄ atol= atol rtol= rtol
139
- end
140
-
141
- # TEST e INPUT GRADIENT
142
- if e != = nothing
143
- verbose && println (" Test e gradient cpu" )
144
- ē = gradient (e -> loss (l, g, x, e), e)[1 ]
145
- ē_fd = FiniteDifferences. grad (fdm, e64 -> loss (l64, g64, x64, e64), e64)[1 ]
146
- @test eltype (ē) == eltype (e)
147
- @test ē≈ ē_fd atol= atol rtol= rtol
148
-
149
- if test_gpu
150
- verbose && println (" Test e gradient gpu" )
151
- ēgpu = gradient (egpu -> loss (lgpu, ggpu, xgpu, egpu), egpu)[1 ]
152
- @test ēgpu isa CuArray
153
- @test eltype (ēgpu) == eltype (ē)
154
- @test Array (ēgpu)≈ ē atol= atol rtol= rtol
155
- end
156
- end
157
-
158
- # TEST LAYER GRADIENT - l(g, x, e)
159
- l̄ = gradient (l -> loss (l, g, x, e), l)[1 ]
160
- l̄_fd = FiniteDifferences. grad (fdm, l64 -> loss (l64, g64, x64, e64), l64)[1 ]
161
- test_approx_structs (l, l̄, l̄_fd; atol, rtol, exclude_grad_fields, verbose)
162
-
163
- if test_gpu
164
- l̄gpu = gradient (lgpu -> loss (lgpu, ggpu, xgpu, egpu), lgpu)[1 ]
165
- test_approx_structs (lgpu, l̄gpu, l̄; atol, rtol, exclude_grad_fields, verbose)
166
- end
167
-
168
- # TEST LAYER GRADIENT - l(g)
169
- l̄ = gradient (l -> loss (l, g), l)[1 ]
170
- test_approx_structs (l, l̄, l̄_fd; atol, rtol, exclude_grad_fields, verbose)
171
-
172
- return true
173
- end
174
-
175
- function test_approx_structs (l, l̄, l̄fd; atol = 1e-5 , rtol = 1e-5 ,
176
- exclude_grad_fields = [],
177
- verbose = false )
178
- l̄ = l̄ isa Base. RefValue ? l̄[] : l̄ # Zygote wraps gradient of mutables in RefValue
179
- l̄fd = l̄fd isa Base. RefValue ? l̄fd[] : l̄fd # Zygote wraps gradient of mutables in RefValue
180
-
181
- for f in fieldnames (typeof (l))
182
- f ∈ exclude_grad_fields && continue
183
- verbose && println (" Test gradient of field $f ..." )
184
- x, g, gfd = getfield (l, f), getfield (l̄, f), getfield (l̄fd, f)
185
- test_approx_structs (x, g, gfd; atol, rtol, exclude_grad_fields, verbose)
186
- verbose && println (" ... field $f done!" )
187
- end
188
- return true
189
- end
190
-
191
- function test_approx_structs (x, g:: Nothing , gfd; atol, rtol, kws... )
192
- # finite diff gradients has to be zero if present
193
- @test ! (gfd isa AbstractArray) || isapprox (gfd, fill! (similar (gfd), 0 ); atol, rtol)
194
- end
195
-
196
- function test_approx_structs (x:: Union{AbstractArray, Number} ,
197
- g:: Union{AbstractArray, Number} , gfd; atol, rtol, kws... )
198
- @test eltype (g) == eltype (x)
199
- if x isa CuArray
200
- @test g isa CuArray
201
- g = Array (g)
202
- end
203
- @test g≈ gfd atol= atol rtol= rtol
204
- end
205
-
206
- """
207
- to32(m)
208
-
209
- Convert the `eltype` of model's float parameters to `Float32`.
210
- Preserves integer arrays.
211
- """
212
- to32 (m) = _paramtype (Float32, m)
213
-
214
- """
215
- to64(m)
216
-
217
- Convert the `eltype` of model's float parameters to `Float64`.
218
- Preserves integer arrays.
219
- """
220
- to64 (m) = _paramtype (Float64, m)
221
-
222
- struct GNNEltypeAdaptor{T} end
223
-
224
- Adapt. adapt_storage (:: GNNEltypeAdaptor{T} , x:: AbstractArray{<:AbstractFloat} ) where T = convert (AbstractArray{T}, x)
225
- Adapt. adapt_storage (:: GNNEltypeAdaptor{T} , x:: AbstractArray{<:Integer} ) where T = x
226
- Adapt. adapt_storage (:: GNNEltypeAdaptor{T} , x:: AbstractArray{<:Number} ) where T = convert (AbstractArray{T}, x)
227
-
228
- _paramtype (:: Type{T} , m) where T = fmap (adapt (GNNEltypeAdaptor {T} ()), m)
0 commit comments