@@ -34,14 +34,21 @@ function test_layer(l, g::GNNGraph; atol = 1e-6, rtol = 1e-5,
34
34
x64, e64, l64, g64 = to64 .([x, e, l, g]) # needed for accurate FiniteDifferences' grad
35
35
xgpu, egpu, lgpu, ggpu = gpu .([x, e, l, g])
36
36
37
- f (l, g) = l (g)
38
- f (l, g, x:: AbstractArray{Float32} ) = isnothing (e) ? l (g, x) : l (g, x, e)
39
- f (l, g, x:: AbstractArray{Float64} ) = isnothing (e64) ? l (g, x) : l (g, x, e64)
40
- f (l, g, x:: CuArray ) = isnothing (e64) ? l (g, x) : l (g, x, egpu)
37
+ f (l, g:: GNNGraph ) = l (g)
38
+ f (l, g:: GNNGraph , x:: AbstractArray{Float32} ) = isnothing (e) ? l (g, x) : l (g, x, e)
39
+ f (l, g:: GNNGraph , x:: AbstractArray{Float64} ) = isnothing (e64) ? l (g, x) : l (g, x, e64)
40
+ f (l, g:: GNNGraph , x:: CuArray ) = isnothing (e64) ? l (g, x) : l (g, x, egpu)
41
41
42
- loss (l, g) = sum (node_features (f (l, g)))
43
- loss (l, g, x) = sum (f (l, g, x))
44
- loss (l, g, x, e) = sum (l (g, x, e))
42
+ loss (l, g:: GNNGraph ) = if outtype == :node
43
+ sum (node_features (f (l, g)))
44
+ elseif outtype == :edge
45
+ sum (edge_features (f (l, g)))
46
+ elseif outtype == :graph
47
+ sum (graph_features (f (l, g)))
48
+ end
49
+
50
+ loss (l, g:: GNNGraph , x) = sum (f (l, g, x))
51
+ loss (l, g:: GNNGraph , x, e) = sum (l (g, x, e))
45
52
46
53
47
54
# TEST OUTPUT
@@ -117,7 +124,6 @@ function test_layer(l, g::GNNGraph; atol = 1e-6, rtol = 1e-5,
117
124
118
125
# TEST LAYER GRADIENT - l(g)
119
126
l̄ = gradient (l -> loss (l, g), l)[1 ]
120
- l̄_fd = FiniteDifferences. grad (fdm, l64 -> loss (l64, g64), l64)[1 ]
121
127
test_approx_structs (l, l̄, l̄_fd; atol, rtol, broken_grad_fields, exclude_grad_fields, verbose)
122
128
123
129
return true
0 commit comments