|
61 | 61 | end
|
62 | 62 |
|
63 | 63 | @testset "topk_nodes" begin
|
64 |
| - A = [0.0297 0.8307 0.9140 0.6702 0.3346; |
65 |
| - 0.5901 0.3030 0.9280 0.6893 0.7997; |
66 |
| - 0.0880 0.6515 0.4451 0.7507 0.5297; |
67 |
| - 0.5171 0.6379 0.2695 0.8954 0.5197] |
68 |
| - B = [0.3168 0.3174 0.5303 0.0804 0.3808; |
69 |
| - 0.1752 0.9105 0.5692 0.8489 0.0539; |
70 |
| - 0.1931 0.4954 0.3455 0.3934 0.0857; |
71 |
| - 0.5065 0.5182 0.5418 0.1520 0.3872] |
72 |
| - C = [0.0297 0.0297 0.8307 0.9140 0.6702 0.3346; |
73 |
| - 0.5901 0.5901 0.3030 0.9280 0.6893 0.7997; |
74 |
| - 0.0880 0.0880 0.6515 0.4451 0.7507 0.5297; |
75 |
| - 0.5171 0.5171 0.6379 0.2695 0.8954 0.5197] |
76 |
| - g1 = rand_graph(5, 6, ndata = (w = A,)) |
77 |
| - g2 = rand_graph(5, 6, ndata = (w = B,)) |
78 |
| - g3 = rand_graph(5, 6, edata = (e = C,)) |
| 64 | + A = [1.0 5.0 9.0; 2.0 6.0 10.0; 3.0 7.0 11.0; 4.0 8.0 12.0] |
| 65 | + B = [0.318907 0.189981 0.991791; |
| 66 | + 0.547022 0.977349 0.680538; |
| 67 | + 0.921823 0.35132 0.494715; |
| 68 | + 0.451793 0.00704976 0.0189275] |
| 69 | + g1 = rand_graph(3, 6, ndata = (x = A,)) |
| 70 | + g2 = rand_graph(3, 6, ndata = B) |
| 71 | + |
| 72 | + # |
| 73 | + output1 = topk_nodes(g1, :x, 2) |
| 74 | + output2 = topk_nodes(g2, :x, 1, sortby = 2) |
| 75 | + |
| 76 | + @test output1 == [9.0 5.0; |
| 77 | + 10.0 6.0; |
| 78 | + 11.0 7.0; |
| 79 | + 12.0 8.0] |
| 80 | + @test output2 == [0.189981; |
| 81 | + 0.977349; |
| 82 | + 0.35132; |
| 83 | + 0.00704976;;] |
79 | 84 | g = Flux.batch([g1, g2])
|
80 |
| - output1 = topk_nodes(g1, :w, 3) |
81 |
| - output2 = topk_nodes(g1, :w, 3; sortby = 5) |
82 |
| - output3 = topk_edges(g3, :e, 3; sortby = 6) |
83 |
| - output_batch = topk_nodes(g, :w, 3; sortby = 5) |
84 |
| - correctout1 = [0.5901 0.8307 0.9280 0.8954 0.7997; |
85 |
| - 0.5171 0.6515 0.9140 0.7507 0.5297; |
86 |
| - 0.0880 0.6379 0.4451 0.6893 0.5197] |
87 |
| - correctout2 = [0.5901 0.3030 0.9280 0.6893 0.7997; |
88 |
| - 0.0880 0.6515 0.4451 0.7507 0.5297; |
89 |
| - 0.5171 0.6379 0.2695 0.8954 0.5197] |
90 |
| - correctout3 = [0.5901 0.5901 0.3030 0.9280 0.6893 0.7997; |
91 |
| - 0.0880 0.0880 0.6515 0.4451 0.7507 0.5297; |
92 |
| - 0.5171 0.5171 0.6379 0.2695 0.8954 0.5197] |
93 |
| - correctout_batch = [0.5901 0.3030 0.9280 0.6893 0.7997 0.5065 0.5182 0.5418 0.1520 0.3872; |
94 |
| - 0.0880 0.6515 0.4451 0.7507 0.5297 0.3168 0.3174 0.5303 0.0804 0.3808; |
95 |
| - 0.5171 0.6379 0.2695 0.8954 0.5197 0.1931 0.4954 0.3455 0.3934 0.0857] |
96 |
| - @test output1 == correctout1 |
97 |
| - @test output2 == correctout2 |
98 |
| - @test output3 == correctout3 |
99 |
| - @test output_batch == correctout_batch |
| 85 | + output3 = topk_nodes(g, :x, 2; sortby = 4) |
| 86 | + @test output3 == [9.0 5.0 0.318907 0.991791; |
| 87 | + 10.0 6.0 0.547022 0.680538; |
| 88 | + 11.0 7.0 0.921823 0.494715; |
| 89 | + 12.0 8.0 0.451793 0.0189275] |
100 | 90 | end
|
101 |
| -end |
| 91 | +end; |
0 commit comments