|
59 | 59 | @test z[:, 1:2] ≈ NNlib.softmax(e2[:, 1:2], dims = 2)
|
60 | 60 | @test z[:, 3:4] ≈ NNlib.softmax(e2[:, 3:4], dims = 2)
|
61 | 61 | end
|
| 62 | + |
| 63 | + @testset "topk_nodes" begin |
| 64 | + A = [0.0297 0.8307 0.9140 0.6702 0.3346; |
| 65 | + 0.5901 0.3030 0.9280 0.6893 0.7997; |
| 66 | + 0.0880 0.6515 0.4451 0.7507 0.5297; |
| 67 | + 0.5171 0.6379 0.2695 0.8954 0.5197] |
| 68 | + B = [0.3168 0.3174 0.5303 0.0804 0.3808; |
| 69 | + 0.1752 0.9105 0.5692 0.8489 0.0539; |
| 70 | + 0.1931 0.4954 0.3455 0.3934 0.0857; |
| 71 | + 0.5065 0.5182 0.5418 0.1520 0.3872] |
| 72 | + C = [0.0297 0.0297 0.8307 0.9140 0.6702 0.3346; |
| 73 | + 0.5901 0.5901 0.3030 0.9280 0.6893 0.7997; |
| 74 | + 0.0880 0.0880 0.6515 0.4451 0.7507 0.5297; |
| 75 | + 0.5171 0.5171 0.6379 0.2695 0.8954 0.5197] |
| 76 | + g1 = rand_graph(5, 6, ndata = (w = A,)) |
| 77 | + g2 = rand_graph(5, 6, ndata = (w = B,)) |
| 78 | + g3 = rand_graph(5, 6, edata = (e = C,)) |
| 79 | + g = Flux.batch([g1, g2]) |
| 80 | + output1 = topk_nodes(g1, :w, 3) |
| 81 | + output2 = topk_nodes(g1, :w, 3; sortby = 5) |
| 82 | + output3 = topk_edges(g3, :e, 3; sortby = 6) |
| 83 | + output_batch = topk_nodes(g, :w, 3; sortby = 5) |
| 84 | + correctout1 = [0.5901 0.8307 0.9280 0.8954 0.7997; |
| 85 | + 0.5171 0.6515 0.9140 0.7507 0.5297; |
| 86 | + 0.0880 0.6379 0.4451 0.6893 0.5197] |
| 87 | + correctout2 = [0.5901 0.3030 0.9280 0.6893 0.7997; |
| 88 | + 0.0880 0.6515 0.4451 0.7507 0.5297; |
| 89 | + 0.5171 0.6379 0.2695 0.8954 0.5197] |
| 90 | + correctout3 = [0.5901 0.5901 0.3030 0.9280 0.6893 0.7997; |
| 91 | + 0.0880 0.0880 0.6515 0.4451 0.7507 0.5297; |
| 92 | + 0.5171 0.5171 0.6379 0.2695 0.8954 0.5197] |
| 93 | + correctout_batch = [0.5901 0.3030 0.9280 0.6893 0.7997 0.5065 0.5182 0.5418 0.1520 0.3872; |
| 94 | + 0.0880 0.6515 0.4451 0.7507 0.5297 0.3168 0.3174 0.5303 0.0804 0.3808; |
| 95 | + 0.5171 0.6379 0.2695 0.8954 0.5197 0.1931 0.4954 0.3455 0.3934 0.0857] |
| 96 | + @test output1 == correctout1 |
| 97 | + @test output2 == correctout2 |
| 98 | + @test output3 == correctout3 |
| 99 | + @test output_batch == correctout_batch |
| 100 | + end |
62 | 101 | end
|
0 commit comments