Skip to content

Commit 9de994f

Browse files
committed
Simplify test
1 parent 4d788f2 commit 9de994f

File tree

1 file changed

+26
-36
lines changed

1 file changed

+26
-36
lines changed

test/utils.jl

Lines changed: 26 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -61,41 +61,31 @@
6161
end
6262

6363
@testset "topk_nodes" begin
64-
A = [0.0297 0.8307 0.9140 0.6702 0.3346;
65-
0.5901 0.3030 0.9280 0.6893 0.7997;
66-
0.0880 0.6515 0.4451 0.7507 0.5297;
67-
0.5171 0.6379 0.2695 0.8954 0.5197]
68-
B = [0.3168 0.3174 0.5303 0.0804 0.3808;
69-
0.1752 0.9105 0.5692 0.8489 0.0539;
70-
0.1931 0.4954 0.3455 0.3934 0.0857;
71-
0.5065 0.5182 0.5418 0.1520 0.3872]
72-
C = [0.0297 0.0297 0.8307 0.9140 0.6702 0.3346;
73-
0.5901 0.5901 0.3030 0.9280 0.6893 0.7997;
74-
0.0880 0.0880 0.6515 0.4451 0.7507 0.5297;
75-
0.5171 0.5171 0.6379 0.2695 0.8954 0.5197]
76-
g1 = rand_graph(5, 6, ndata = (w = A,))
77-
g2 = rand_graph(5, 6, ndata = (w = B,))
78-
g3 = rand_graph(5, 6, edata = (e = C,))
64+
A = [1.0 5.0 9.0; 2.0 6.0 10.0; 3.0 7.0 11.0; 4.0 8.0 12.0]
65+
B = [0.318907 0.189981 0.991791;
66+
0.547022 0.977349 0.680538;
67+
0.921823 0.35132 0.494715;
68+
0.451793 0.00704976 0.0189275]
69+
g1 = rand_graph(3, 6, ndata = (x = A,))
70+
g2 = rand_graph(3, 6, ndata = B)
71+
72+
#
73+
output1 = topk_nodes(g1, :x, 2)
74+
output2 = topk_nodes(g2, :x, 1, sortby = 2)
75+
76+
@test output1 == [9.0 5.0;
77+
10.0 6.0;
78+
11.0 7.0;
79+
12.0 8.0]
80+
@test output2 == [0.189981;
81+
0.977349;
82+
0.35132;
83+
0.00704976;;]
7984
g = Flux.batch([g1, g2])
80-
output1 = topk_nodes(g1, :w, 3)
81-
output2 = topk_nodes(g1, :w, 3; sortby = 5)
82-
output3 = topk_edges(g3, :e, 3; sortby = 6)
83-
output_batch = topk_nodes(g, :w, 3; sortby = 5)
84-
correctout1 = [0.5901 0.8307 0.9280 0.8954 0.7997;
85-
0.5171 0.6515 0.9140 0.7507 0.5297;
86-
0.0880 0.6379 0.4451 0.6893 0.5197]
87-
correctout2 = [0.5901 0.3030 0.9280 0.6893 0.7997;
88-
0.0880 0.6515 0.4451 0.7507 0.5297;
89-
0.5171 0.6379 0.2695 0.8954 0.5197]
90-
correctout3 = [0.5901 0.5901 0.3030 0.9280 0.6893 0.7997;
91-
0.0880 0.0880 0.6515 0.4451 0.7507 0.5297;
92-
0.5171 0.5171 0.6379 0.2695 0.8954 0.5197]
93-
correctout_batch = [0.5901 0.3030 0.9280 0.6893 0.7997 0.5065 0.5182 0.5418 0.1520 0.3872;
94-
0.0880 0.6515 0.4451 0.7507 0.5297 0.3168 0.3174 0.5303 0.0804 0.3808;
95-
0.5171 0.6379 0.2695 0.8954 0.5197 0.1931 0.4954 0.3455 0.3934 0.0857]
96-
@test output1 == correctout1
97-
@test output2 == correctout2
98-
@test output3 == correctout3
99-
@test output_batch == correctout_batch
85+
output3 = topk_nodes(g, :x, 2; sortby = 4)
86+
@test output3 == [9.0 5.0 0.318907 0.991791;
87+
10.0 6.0 0.547022 0.680538;
88+
11.0 7.0 0.921823 0.494715;
89+
12.0 8.0 0.451793 0.0189275]
10090
end
101-
end
91+
end;

0 commit comments

Comments
 (0)