Skip to content

Commit 7ccfd83

Browse files
clenaup
1 parent eed4d5b commit 7ccfd83

File tree

3 files changed

+9
-11
lines changed

3 files changed

+9
-11
lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ authors = ["Carlo Lucibello and contributors"]
44
version = "0.1.1"
55

66
[deps]
7+
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
78
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
89
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
910
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"

src/layers/conv.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -340,7 +340,7 @@ update_node(l::GatedGraphConv, m, x) = m
340340
# remove after https://github.com/JuliaDiff/ChainRules.jl/pull/521
341341
@non_differentiable fill!(x...)
342342

343-
function (l::GatedGraphConv)(g::GNNGraph, H::AbstractMatrix{S}) where {T<:AbstractVector,S<:Real}
343+
function (l::GatedGraphConv)(g::GNNGraph, H::AbstractMatrix{S}) where {S<:Real}
344344
check_num_nodes(g, H)
345345
m, n = size(H)
346346
@assert (m <= l.out_ch) "number of input features must less or equals to output features."

test/layers/conv.jl

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -46,14 +46,11 @@
4646
@test size(l.bias) == (out_channel,)
4747
@test l.k == k
4848
for g in test_graphs
49-
if g === g_single_vertex && GRAPH_T == :dense
50-
@test_broken test_layer(l, g, rtol=1e-5, broken_grad_fields=[:weight], test_gpu=false)
51-
else
52-
test_layer(l, g, rtol=1e-5, broken_grad_fields=[:weight], test_gpu=false)
53-
if TEST_GPU
54-
@test_broken test_layer(l, g, rtol=1e-5, broken_grad_fields=[:weight], test_gpu=true)
55-
end
56-
end
49+
g = add_self_loops(g)
50+
test_layer(l, g, rtol=1e-5, test_gpu=false, outsize=(out_channel, g.num_nodes))
51+
if TEST_GPU
52+
@test_broken test_layer(l, g, rtol=1e-5, test_gpu=true, outsize=(out_channel, g.num_nodes))
53+
end
5754
end
5855

5956
@testset "bias=false" begin
@@ -81,10 +78,10 @@
8178

8279
@testset "GATConv" begin
8380

84-
for heads in (1, 3), concat in (true, false)
81+
for heads in (1, 2), concat in (true, false)
8582
l = GATConv(in_channel => out_channel; heads, concat)
8683
for g in test_graphs
87-
test_layer(l, g, rtol=1e-4,
84+
test_layer(l, g, rtol=1e-4,
8885
outsize=(concat ? heads*out_channel : out_channel, g.num_nodes))
8986
end
9087
end

0 commit comments

Comments
 (0)