Skip to content

Commit 8a6802b

Browse files
authored
optimize test for heterograph (#370)
1 parent 9e0ad4a commit 8a6802b

File tree

1 file changed

+3
-4
lines changed

1 file changed

+3
-4
lines changed

test/layers/heteroconv.jl

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
@testset "HeteroGraphConv" begin
22
d, n = 3, 5
33
g = rand_bipartite_heterograph(n, 2*n, 15)
4+
hg = rand_bipartite_heterograph((2,3), 6)
45

56
model = HeteroGraphConv([(:A,:to,:B) => GraphConv(d => d),
67
(:B,:to,:A) => GraphConv(d => d)])
@@ -93,20 +94,18 @@
9394
end
9495

9596
@testset "CGConv" begin
96-
g = rand_bipartite_heterograph((2,3), 6)
9797
x = (A = rand(Float32, 4,2), B = rand(Float32, 4, 3))
9898
layers = HeteroGraphConv( (:A, :to, :B) => CGConv(4 => 2, relu),
9999
(:B, :to, :A) => CGConv(4 => 2, relu));
100-
y = layers(g, x);
100+
y = layers(hg, x);
101101
@test size(y.A) == (2,2) && size(y.B) == (2,3)
102102
end
103103

104104
@testset "EdgeConv" begin
105-
g = rand_bipartite_heterograph((2,3), 6)
106105
x = (A = rand(Float32, 4,2), B = rand(Float32, 4, 3))
107106
layers = HeteroGraphConv( (:A, :to, :B) => EdgeConv(Dense(2 * 4, 2), aggr = +),
108107
(:B, :to, :A) => EdgeConv(Dense(2 * 4, 2), aggr = +));
109-
y = layers(g, x);
108+
y = layers(hg, x);
110109
@test size(y.A) == (2,2) && size(y.B) == (2,3)
111110
end
112111

0 commit comments

Comments
 (0)