|
1 | 1 | @testset "utils" begin
|
| 2 | + using KernelFunctions: ColVecs |
| 3 | + rng, N, D = MersenneTwister(123456), 10, 2 |
| 4 | + x, X = randn(rng, N), randn(rng, D, N) |
2 | 5 |
|
| 6 | + # Test Matrix data sets. |
| 7 | + @testset "ColVecs" begin |
| 8 | + DX = ColVecs(X) |
| 9 | + @test DX == DX |
| 10 | + @test size(DX) == (N,) |
| 11 | + @test length(DX) == N |
| 12 | + @test getindex(DX, 5) isa AbstractVector |
| 13 | + @test getindex(DX, 5) == X[:, 5] |
| 14 | + @test getindex(DX, 1:2:6) isa ColVecs |
| 15 | + @test getindex(DX, 1:2:6) == ColVecs(X[:, 1:2:6]) |
| 16 | + @test getindex(DX, :) == ColVecs(X) |
| 17 | + @test eachindex(DX) == 1:N |
| 18 | + @test first(DX) == X[:, 1] |
| 19 | + |
| 20 | + let |
| 21 | + @test Zygote.pullback(ColVecs, X)[1] == DX |
| 22 | + DX, back = Zygote.pullback(ColVecs, X) |
| 23 | + @test back((X=ones(size(X)),))[1] == ones(size(X)) |
| 24 | + |
| 25 | + @test Zygote.pullback(DX->DX.X, DX)[1] == X |
| 26 | + X_, back = Zygote.pullback(DX->DX.X, DX) |
| 27 | + @test back(ones(size(X)))[1].X == ones(size(X)) |
| 28 | + end |
| 29 | + end |
3 | 30 | end
|
0 commit comments