1
- using ChainRulesTestUtils, FiniteDifferences, Zygote, Adapt
1
+ using ChainRulesTestUtils, FiniteDifferences, Zygote, Adapt, CUDA
2
+ CUDA. allowscalar (false )
3
+
4
+ # global GRAPH_T = :coo
5
+ # global TEST_GPU = true
2
6
3
7
const rule_config = Zygote. ZygoteRuleConfig ()
4
8
5
- # Using this until https://github.com/JuliaDiff/FiniteDifferences.jl/issues/188
6
- # is fixed
9
+ # Using this until https://github.com/JuliaDiff/FiniteDifferences.jl/issues/188 is fixed
7
10
function FiniteDifferences. to_vec (x:: Integer )
8
11
Integer_from_vec (v) = x
9
12
return Int[x], Integer_from_vec
12
15
function gradtest (l, g:: GNNGraph ; atol= 1e-7 , rtol= 1e-5 ,
13
16
exclude_grad_fields= [],
14
17
broken_grad_fields= [],
15
- verbose = false
18
+ verbose = false ,
19
+ test_gpu = TEST_GPU,
16
20
)
21
+
17
22
# TODO these give errors, probably some bugs in ChainRulesTestUtils
18
23
# test_rrule(rule_config, x -> l(g, x), x; rrule_f=rrule_via_ad, check_inferred=false)
19
24
# test_rrule(rule_config, l -> l(g, x), l; rrule_f=rrule_via_ad, check_inferred=false)
@@ -24,75 +29,120 @@ function gradtest(l, g::GNNGraph; atol=1e-7, rtol=1e-5,
24
29
x = node_features (g)
25
30
e = edge_features (g)
26
31
32
+ x64, e64, l64, g64 = to64 .([x, e, l, g])
33
+ xgpu, egpu, lgpu, ggpu = gpu .([x, e, l, g])
34
+
27
35
f (l, g) = l (g)
28
- f (l, g, x) = isnothing (e) ? l (g, x) : l (g, x, e)
36
+ f (l, g, x:: AbstractArray{Float32} ) = isnothing (e) ? l (g, x) : l (g, x, e)
37
+ f (l, g, x:: AbstractArray{Float64} ) = isnothing (e64) ? l (g, x) : l (g, x, e64)
38
+ f (l, g, x:: CuArray ) = isnothing (e64) ? l (g, x) : l (g, x, egpu)
29
39
30
40
loss (l, g) = sum (node_features (f (l, g)))
31
41
loss (l, g, x) = sum (f (l, g, x))
32
42
loss (l, g, x, e) = sum (l (g, x, e))
33
43
34
- x64, e64, l64, g64 = to64 .([x, e, l, g])
44
+
35
45
# TEST OUTPUT
36
46
y = f (l, g, x)
37
47
@test eltype (y) == eltype (x)
38
48
39
49
g′ = f (l, g)
40
50
@test g′. ndata. x ≈ y
41
51
42
- # TEST X INPUT GRADIENT
52
+ if test_gpu
53
+ ygpu = f (lgpu, ggpu, xgpu)
54
+ @test ygpu isa CuArray
55
+ @test eltype (ygpu) == eltype (xgpu)
56
+ @test Array (ygpu) ≈ y
57
+ end
58
+
59
+
60
+ # TEST x INPUT GRADIENT
43
61
x̄ = gradient (x -> loss (l, g, x), x)[1 ]
44
62
x̄_fd = FiniteDifferences. grad (fdm, x64 -> loss (l64, g64, x64), x64)[1 ]
63
+ @test eltype (x̄) == eltype (x)
45
64
@test x̄ ≈ x̄_fd atol= atol rtol= rtol
46
65
66
+ if test_gpu
67
+ x̄gpu = gradient (xgpu -> loss (lgpu, ggpu, xgpu), xgpu)[1 ]
68
+ @test x̄gpu isa CuArray
69
+ @test eltype (x̄gpu) == eltype (x)
70
+ @test Array (x̄gpu) ≈ x̄ atol= atol rtol= rtol
71
+ end
72
+
73
+
74
+ # TEST e INPUT GRADIENT
47
75
if e != = nothing
48
- # TEST E INPUT GRADIENT
49
76
ē = gradient (e -> loss (l, g, x, e), e)[1 ]
50
77
ē_fd = FiniteDifferences. grad (fdm, e64 -> loss (l64, g64, x64, e64), e64)[1 ]
78
+ @test eltype (ē) == eltype (e)
51
79
@test ē ≈ ē_fd atol= atol rtol= rtol
80
+
81
+ if test_gpu
82
+ ēgpu = gradient (egpu -> loss (lgpu, ggpu, xgpu, egpu), egpu)[1 ]
83
+ @test ēgpu isa CuArray
84
+ @test eltype (ēgpu) == eltype (ē)
85
+ @test Array (ēgpu) ≈ ē atol= atol rtol= rtol
86
+ end
52
87
end
53
88
89
+
54
90
# TEST LAYER GRADIENT - l(g, x)
55
91
l̄ = gradient (l -> loss (l, g, x), l)[1 ]
56
92
l̄_fd = FiniteDifferences. grad (fdm, l64 -> loss (l64, g64, x64), l64)[1 ]
57
93
test_approx_structs (l, l̄, l̄_fd; atol, rtol, broken_grad_fields, exclude_grad_fields, verbose)
94
+
95
+ if test_gpu
96
+ l̄gpu = gradient (lgpu -> loss (lgpu, ggpu, xgpu), lgpu)[1 ]
97
+ test_approx_structs (lgpu, l̄gpu, l̄; atol, rtol, broken_grad_fields, exclude_grad_fields, verbose)
98
+ end
99
+
58
100
# TEST LAYER GRADIENT - l(g)
59
101
l̄ = gradient (l -> loss (l, g), l)[1 ]
60
102
l̄_fd = FiniteDifferences. grad (fdm, l64 -> loss (l64, g64), l64)[1 ]
61
103
test_approx_structs (l, l̄, l̄_fd; atol, rtol, broken_grad_fields, exclude_grad_fields, verbose)
104
+
105
+ return true
62
106
end
63
107
64
"""
    test_approx_structs(l, l̄, l̄2; atol=1e-5, rtol=1e-5,
                        broken_grad_fields=[], exclude_grad_fields=[],
                        verbose=false)

Compare, field by field, the gradient `l̄` of `l` against a reference gradient `l̄2`.

For every field `f` of `l` that is not listed in `exclude_grad_fields`:

- if `getfield(l̄, f)` is `nothing`, the reference `getfield(l̄2, f)` must be
  approximately zero whenever it is an array;
- if it is an array or a number, its `eltype` must equal the `eltype` of the field
  of `l`, and it must be approximately equal to the reference (`@test_broken` is
  used instead of `@test` for fields listed in `broken_grad_fields`). When the
  field of `l` is a `CuArray`, the gradient must be a `CuArray` too and is copied
  to an `Array` before the comparison;
- otherwise the field is treated as a nested struct and compared recursively,
  with every keyword argument forwarded unchanged.

# Keywords
- `atol`, `rtol`: tolerances passed to the approximate comparisons.
- `broken_grad_fields`: field names whose comparison is marked `@test_broken`.
- `exclude_grad_fields`: field names skipped at every nesting level.
- `verbose`: print the field, its value and both gradients, plus the branch taken.

Returns `true`. Mismatches are reported through `@test`.
"""
function test_approx_structs(l, l̄, l̄2; atol=1e-5, rtol=1e-5,
                             broken_grad_fields=[],
                             exclude_grad_fields=[],
                             verbose=false)

    for f in fieldnames(typeof(l))
        f ∈ exclude_grad_fields && continue
        f̄, f̄2 = getfield(l̄, f), getfield(l̄2, f)
        x = getfield(l, f)
        if verbose
            println()
            @show f x f̄ f̄2
        end
        if isnothing(f̄)
            # No gradient was produced for this field: the reference must vanish.
            verbose && println("A")
            @test !(f̄2 isa AbstractArray) || isapprox(f̄2, fill!(similar(f̄2), 0); atol=atol, rtol=rtol)
        elseif f̄ isa Union{AbstractArray, Number}
            verbose && println("B")
            @test eltype(f̄) == eltype(x)
            if x isa CuArray
                @test f̄ isa CuArray
                # Copy to an Array before comparing against the reference.
                f̄ = Array(f̄)
            end
            if f ∈ broken_grad_fields
                @test_broken f̄ ≈ f̄2 atol=atol rtol=rtol
            else
                @test f̄ ≈ f̄2 atol=atol rtol=rtol
            end
        else
            # Nested struct: recurse and forward *all* keywords. Previously only
            # `broken_grad_fields` was passed on, so nested fields silently fell
            # back to the default tolerances and ignored `exclude_grad_fields`
            # and `verbose`.
            verbose && println("C")
            test_approx_structs(x, f̄, f̄2; atol, rtol, broken_grad_fields,
                                exclude_grad_fields, verbose)
        end
    end
    return true
end
94
143
95
144
145
+
96
146
"""
97
147
to32(m)
98
148
0 commit comments