@@ -2,7 +2,17 @@ using ChainRulesTestUtils, FiniteDifferences, Zygote
2
2
3
3
const rule_config = Zygote. ZygoteRuleConfig ()
4
4
5
- function gradtest (l, g:: GNNGraph ; atol= 1e-9 , rtol= 1e-5 )
5
+ # Using this until https://github.com/JuliaDiff/FiniteDifferences.jl/issues/188
6
+ # is fixed
7
+ function FiniteDifferences. to_vec (x:: Integer )
8
+ Integer_from_vec (v) = x
9
+ return Int[x], Integer_from_vec
10
+ end
11
+
12
+ function gradtest (l, g:: GNNGraph ; atol= 1e-7 , rtol= 1e-5 ,
13
+ exclude_grad_fields= [],
14
+ broken_grad_fields= []
15
+ )
6
16
# TODO these give errors, probably some bugs in ChainRulesTestUtils
7
17
# test_rrule(rule_config, x -> l(g, x), x; rrule_f=rrule_via_ad, check_inferred=false)
8
18
# test_rrule(rule_config, l -> l(g, x), l; rrule_f=rrule_via_ad, check_inferred=false)
@@ -27,27 +37,35 @@ function gradtest(l, g::GNNGraph; atol=1e-9, rtol=1e-5)
27
37
# TEST LAYER GRADIENT - l(g, x)
28
38
l̄ = gradient (l -> sum (l (g, x)), l)[1 ]
29
39
l̄_fd = FiniteDifferences. grad (fdm, l -> sum (l (g, x)), l)[1 ]
30
- test_approx_structs (l, l̄, l̄_fd; atol, rtol)
31
-
40
+ test_approx_structs (l, l̄, l̄_fd; atol, rtol, broken_grad_fields, exclude_grad_fields)
32
41
# TEST LAYER GRADIENT - l(g)
33
42
l̄ = gradient (l -> sum (l (g). ndata. x), l)[1 ]
34
43
l̄_fd = FiniteDifferences. grad (fdm, l -> sum (l (g). ndata. x), l)[1 ]
35
- test_approx_structs (l, l̄, l̄_fd; atol, rtol)
44
+ test_approx_structs (l, l̄, l̄_fd; atol, rtol, broken_grad_fields, exclude_grad_fields )
36
45
end
37
46
38
- function test_approx_structs (l, l̄, l̄_fd; atol= 1e-9 , rtol= 1e-5 )
47
+ function test_approx_structs (l, l̄, l̄_fd; atol= 1e-5 , rtol= 1e-5 ,
48
+ broken_grad_fields= [],
49
+ exclude_grad_fields= [])
39
50
for f in fieldnames (typeof (l))
51
+ f ∈ exclude_grad_fields && continue
40
52
f̄, f̄_fd = getfield (l̄, f), getfield (l̄_fd, f)
41
53
if isnothing (f̄)
42
- @show f̄ f̄_fd
54
+ # @show f f̄_fd
43
55
@test ! (f̄_fd isa AbstractArray) || isapprox (f̄_fd, fill! (similar (f̄_fd), 0 ); atol= atol, rtol= rtol)
44
56
elseif f̄ isa Union{AbstractArray, Number}
45
57
@test eltype (f̄) == eltype (getfield (l, f))
46
- @test f̄ ≈ f̄_fd atol= atol rtol= rtol
58
+ if f ∈ broken_grad_fields
59
+ @test_broken f̄ ≈ f̄_fd atol= atol rtol= rtol
60
+ else
61
+ # @show f getfield(l, f) f̄ f̄_fd broken_grad_fields
62
+ @test f̄ ≈ f̄_fd atol= atol rtol= rtol
63
+ end
47
64
else
48
- test_approx_structs (getfield (l, f), f̄, f̄_fd)
65
+ test_approx_structs (getfield (l, f), f̄, f̄_fd; broken_grad_fields )
49
66
end
50
67
end
68
+ return true
51
69
end
52
70
53
71
0 commit comments