|
1 | 1 | @testset "utils" begin
|
2 |
| - using KernelFunctions: ColVecs |
3 |
| - rng, N, D = MersenneTwister(123456), 10, 2 |
| 2 | + using KernelFunctions: vec_of_vecs, ColVecs, RowVecs |
| 3 | + rng, N, D = MersenneTwister(123456), 10, 4 |
4 | 4 | x, X = randn(rng, N), randn(rng, D, N)
|
5 |
| - |
| 5 | + @testset "VecOfVecs" begin |
| 6 | + @test vec_of_vecs(X, obsdim = 2) == ColVecs(X) |
| 7 | + @test vec_of_vecs(X, obsdim = 1) == RowVecs(X) |
| 8 | + end |
6 | 9 | # Test Matrix data sets.
|
7 | 10 | @testset "ColVecs" begin
|
8 | 11 | DX = ColVecs(X)
|
|
22 | 25 | DX, back = Zygote.pullback(ColVecs, X)
|
23 | 26 | @test back((X=ones(size(X)),))[1] == ones(size(X))
|
24 | 27 |
|
| 28 | + @test Zygote.pullback(DX->DX.X, DX)[1] == X |
| 29 | + X_, back = Zygote.pullback(DX->DX.X, DX) |
| 30 | + @test back(ones(size(X)))[1].X == ones(size(X)) |
| 31 | + end |
| 32 | + end |
| 33 | + @testset "RowVecs" begin |
| 34 | + DX = RowVecs(X) |
| 35 | + @test DX == DX |
| 36 | + @test size(DX) == (D,) |
| 37 | + @test length(DX) == D |
| 38 | + @test getindex(DX, 2) isa AbstractVector |
| 39 | + @test getindex(DX, 2) == X[2, :] |
| 40 | + @test getindex(DX, 1:3) isa RowVecs |
| 41 | + @test getindex(DX, 1:3) == RowVecs(X[1:3, :]) |
| 42 | + @test getindex(DX, :) == RowVecs(X) |
| 43 | + @test eachindex(DX) == 1:D |
| 44 | + @test first(DX) == X[1, :] |
| 45 | + |
| 46 | + let |
| 47 | + @test Zygote.pullback(RowVecs, X)[1] == DX |
| 48 | + DX, back = Zygote.pullback(RowVecs, X) |
| 49 | + @test back((X=ones(size(X)),))[1] == ones(size(X)) |
| 50 | + |
25 | 51 | @test Zygote.pullback(DX->DX.X, DX)[1] == X
|
26 | 52 | X_, back = Zygote.pullback(DX->DX.X, DX)
|
27 | 53 | @test back(ones(size(X)))[1].X == ones(size(X))
|
|
0 commit comments