@@ -30,66 +30,103 @@ function test_layer(l, g::GNNGraph; atol = 1e-6, rtol = 1e-5,
30
30
31
31
x = node_features (g)
32
32
e = edge_features (g)
33
+ use_edge_feat = ! isnothing (e)
33
34
34
35
x64, e64, l64, g64 = to64 .([x, e, l, g]) # needed for accurate FiniteDifferences' grad
35
36
xgpu, egpu, lgpu, ggpu = gpu .([x, e, l, g])
36
37
37
38
f (l, g:: GNNGraph ) = l (g)
38
- f (l, g:: GNNGraph , x:: AbstractArray{Float32} ) = isnothing (e) ? l (g, x) : l (g, x, e)
39
- f (l, g:: GNNGraph , x:: AbstractArray{Float64} ) = isnothing (e64) ? l (g, x) : l (g, x, e64)
40
- f (l, g:: GNNGraph , x:: CuArray ) = isnothing (e64) ? l (g, x) : l (g, x, egpu)
39
+ f (l, g:: GNNGraph , x, e) = use_edge_feat ? l (g, x, e) : l (g, x)
41
40
42
41
loss (l, g:: GNNGraph ) = if outtype == :node
43
42
sum (node_features (f (l, g)))
44
43
elseif outtype == :edge
45
44
sum (edge_features (f (l, g)))
46
45
elseif outtype == :graph
47
46
sum (graph_features (f (l, g)))
47
+ elseif outtype == :node_edge
48
+ gnew = f (l, g)
49
+ sum (node_features (gnew)) + sum (edge_features (gnew))
48
50
end
49
51
50
- loss (l, g:: GNNGraph , x) = sum (f (l, g, x))
51
- loss (l, g:: GNNGraph , x, e) = sum (l (g, x, e))
52
+ function loss (l, g:: GNNGraph , x, e)
53
+ y = f (l, g, x, e)
54
+ if outtype == :node_edge
55
+ return sum (y[1 ]) + sum (y[2 ])
56
+ else
57
+ return sum (y)
58
+ end
59
+ end
52
60
53
61
54
62
# TEST OUTPUT
55
- y = f (l, g, x)
56
- @test eltype (y) == eltype (x)
57
- @test all (isfinite, y)
58
- if ! isnothing (outsize)
59
- @test size (y) == outsize
63
+ y = f (l, g, x, e)
64
+ if outtype == :node_edge
65
+ @assert y isa Tuple
66
+ @test eltype (y[1 ]) == eltype (x)
67
+ @test eltype (y[2 ]) == eltype (e)
68
+ @test all (isfinite, y[1 ])
69
+ @test all (isfinite, y[2 ])
70
+ if ! isnothing (outsize)
71
+ @test size (y[1 ]) == outsize[1 ]
72
+ @test size (y[2 ]) == outsize[2 ]
73
+ end
74
+ else
75
+ @test eltype (y) == eltype (x)
76
+ @test all (isfinite, y)
77
+ if ! isnothing (outsize)
78
+ @test size (y) == outsize
79
+ end
60
80
end
61
81
62
82
# test same output on different graph formats
63
83
gcoo = GNNGraph (g, graph_type= :coo )
64
- ycoo = f (l, gcoo, x)
65
- @test ycoo ≈ y
66
-
84
+ ycoo = f (l, gcoo, x, e)
85
+ if outtype == :node_edge
86
+ @test ycoo[1 ] ≈ y[1 ]
87
+ @test ycoo[2 ] ≈ y[2 ]
88
+ else
89
+ @test ycoo ≈ y
90
+ end
91
+
67
92
g′ = f (l, g)
68
93
if outtype == :node
69
94
@test g′. ndata. x ≈ y
70
95
elseif outtype == :edge
71
96
@test g′. edata. e ≈ y
72
97
elseif outtype == :graph
73
98
@test g′. gdata. u ≈ y
99
+ elseif outtype == :node_edge
100
+ @test g′. ndata. x ≈ y[1 ]
101
+ @test g′. edata. e ≈ y[2 ]
74
102
else
75
103
@error " wrong outtype $outtype "
76
104
end
77
105
if test_gpu
78
- ygpu = f (lgpu, ggpu, xgpu)
79
- @test ygpu isa CuArray
80
- @test eltype (ygpu) == eltype (xgpu)
81
- @test Array (ygpu) ≈ y
106
+ ygpu = f (lgpu, ggpu, xgpu, egpu)
107
+ if outtype == :node_edge
108
+ @test ygpu[1 ] isa CuArray
109
+ @test eltype (ygpu[1 ]) == eltype (xgpu)
110
+ @test Array (ygpu[1 ]) ≈ y[1 ]
111
+ @test ygpu[2 ] isa CuArray
112
+ @test eltype (ygpu[2 ]) == eltype (xgpu)
113
+ @test Array (ygpu[2 ]) ≈ y[2 ]
114
+ else
115
+ @test ygpu isa CuArray
116
+ @test eltype (ygpu) == eltype (xgpu)
117
+ @test Array (ygpu) ≈ y
118
+ end
82
119
end
83
120
84
121
85
122
# TEST x INPUT GRADIENT
86
- x̄ = gradient (x -> loss (l, g, x), x)[1 ]
87
- x̄_fd = FiniteDifferences. grad (fdm, x64 -> loss (l64, g64, x64), x64)[1 ]
123
+ x̄ = gradient (x -> loss (l, g, x, e ), x)[1 ]
124
+ x̄_fd = FiniteDifferences. grad (fdm, x64 -> loss (l64, g64, x64, e64 ), x64)[1 ]
88
125
@test eltype (x̄) == eltype (x)
89
126
@test x̄ ≈ x̄_fd atol= atol rtol= rtol
90
127
91
128
if test_gpu
92
- x̄gpu = gradient (xgpu -> loss (lgpu, ggpu, xgpu), xgpu)[1 ]
129
+ x̄gpu = gradient (xgpu -> loss (lgpu, ggpu, xgpu, egpu ), xgpu)[1 ]
93
130
@test x̄gpu isa CuArray
94
131
@test eltype (x̄gpu) == eltype (x)
95
132
@test Array (x̄gpu) ≈ x̄ atol= atol rtol= rtol
@@ -112,13 +149,13 @@ function test_layer(l, g::GNNGraph; atol = 1e-6, rtol = 1e-5,
112
149
end
113
150
114
151
115
- # TEST LAYER GRADIENT - l(g, x)
116
- l̄ = gradient (l -> loss (l, g, x), l)[1 ]
117
- l̄_fd = FiniteDifferences. grad (fdm, l64 -> loss (l64, g64, x64), l64)[1 ]
152
+ # TEST LAYER GRADIENT - l(g, x, e )
153
+ l̄ = gradient (l -> loss (l, g, x, e ), l)[1 ]
154
+ l̄_fd = FiniteDifferences. grad (fdm, l64 -> loss (l64, g64, x64, e64 ), l64)[1 ]
118
155
test_approx_structs (l, l̄, l̄_fd; atol, rtol, broken_grad_fields, exclude_grad_fields, verbose)
119
156
120
157
if test_gpu
121
- l̄gpu = gradient (lgpu -> loss (lgpu, ggpu, xgpu), lgpu)[1 ]
158
+ l̄gpu = gradient (lgpu -> loss (lgpu, ggpu, xgpu, egpu ), lgpu)[1 ]
122
159
test_approx_structs (lgpu, l̄gpu, l̄; atol, rtol, broken_grad_fields, exclude_grad_fields, verbose)
123
160
end
124
161
0 commit comments