Skip to content

Commit 92e7314

Browse files
committed
Add test
1 parent f72cacb commit 92e7314

File tree

1 file changed

+39
-0
lines changed

1 file changed

+39
-0
lines changed

test/utils.jl

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,4 +59,43 @@
5959
@test z[:, 1:2] NNlib.softmax(e2[:, 1:2], dims = 2)
6060
@test z[:, 3:4] NNlib.softmax(e2[:, 3:4], dims = 2)
6161
end
62+
63+
@testset "topk_nodes" begin
64+
A = [0.0297 0.8307 0.9140 0.6702 0.3346;
65+
0.5901 0.3030 0.9280 0.6893 0.7997;
66+
0.0880 0.6515 0.4451 0.7507 0.5297;
67+
0.5171 0.6379 0.2695 0.8954 0.5197]
68+
B = [0.3168 0.3174 0.5303 0.0804 0.3808;
69+
0.1752 0.9105 0.5692 0.8489 0.0539;
70+
0.1931 0.4954 0.3455 0.3934 0.0857;
71+
0.5065 0.5182 0.5418 0.1520 0.3872]
72+
C = [0.0297 0.0297 0.8307 0.9140 0.6702 0.3346;
73+
0.5901 0.5901 0.3030 0.9280 0.6893 0.7997;
74+
0.0880 0.0880 0.6515 0.4451 0.7507 0.5297;
75+
0.5171 0.5171 0.6379 0.2695 0.8954 0.5197]
76+
g1 = rand_graph(5, 6, ndata = (w = A,))
77+
g2 = rand_graph(5, 6, ndata = (w = B,))
78+
g3 = rand_graph(5, 6, edata = (e = C,))
79+
g = Flux.batch([g1, g2])
80+
output1 = topk_nodes(g1, :w, 3)
81+
output2 = topk_nodes(g1, :w, 3; sortby = 5)
82+
output3 = topk_edges(g3, :e, 3; sortby = 6)
83+
output_batch = topk_nodes(g, :w, 3; sortby = 5)
84+
correctout1 = [0.5901 0.8307 0.9280 0.8954 0.7997;
85+
0.5171 0.6515 0.9140 0.7507 0.5297;
86+
0.0880 0.6379 0.4451 0.6893 0.5197]
87+
correctout2 = [0.5901 0.3030 0.9280 0.6893 0.7997;
88+
0.0880 0.6515 0.4451 0.7507 0.5297;
89+
0.5171 0.6379 0.2695 0.8954 0.5197]
90+
correctout3 = [0.5901 0.5901 0.3030 0.9280 0.6893 0.7997;
91+
0.0880 0.0880 0.6515 0.4451 0.7507 0.5297;
92+
0.5171 0.5171 0.6379 0.2695 0.8954 0.5197]
93+
correctout_batch = [0.5901 0.3030 0.9280 0.6893 0.7997 0.5065 0.5182 0.5418 0.1520 0.3872;
94+
0.0880 0.6515 0.4451 0.7507 0.5297 0.3168 0.3174 0.5303 0.0804 0.3808;
95+
0.5171 0.6379 0.2695 0.8954 0.5197 0.1931 0.4954 0.3455 0.3934 0.0857]
96+
@test output1 == correctout1
97+
@test output2 == correctout2
98+
@test output3 == correctout3
99+
@test output_batch == correctout_batch
100+
end
62101
end

0 commit comments

Comments
 (0)