Skip to content

Commit fb37557

Browse files
authored
Merge pull request #114 from theogf/test_AD
Series of tests for AD
2 parents 3b0cf61 + e94973e commit fb37557

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

48 files changed

+375
-229
lines changed

.github/workflows/CompatHelper.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,4 +16,4 @@ jobs:
1616
- name: CompatHelper.main()
1717
env:
1818
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
19-
run: julia -e 'using CompatHelper; CompatHelper.main()'
19+
run: julia -e 'using CompatHelper; CompatHelper.main(; subdirs = ["", "test"])'

Project.toml

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -22,15 +22,3 @@ StatsBase = "0.32, 0.33"
2222
StatsFuns = "0.8, 0.9"
2323
ZygoteRules = "0.2"
2424
julia = "1.3"
25-
26-
[extras]
27-
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
28-
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
29-
Kronecker = "2c470bb0-bcc8-11e8-3dad-c9649493f05e"
30-
PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
31-
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
32-
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
33-
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
34-
35-
[targets]
36-
test = ["Random", "Test", "FiniteDifferences", "Zygote", "PDMats", "Kronecker", "Flux"]

src/KernelFunctions.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,8 @@ export spectral_mixture_kernel, spectral_mixture_product_kernel
3434
using Compat
3535
using Requires
3636
using Distances, LinearAlgebra
37-
using SpecialFunctions: logabsgamma, besselk
38-
using ZygoteRules: @adjoint
37+
using SpecialFunctions: logabsgamma, besselk, polygamma
38+
using ZygoteRules: @adjoint, pullback
3939
using StatsFuns: logtwo
4040
using InteractiveUtils: subtypes
4141
using StatsBase

src/basekernels/matern.jl

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,11 @@ end
1717

1818
@inline function kappa::MaternKernel, d::Real)
1919
ν = first.ν)
20-
iszero(d) ? one(d) :
21-
exp(
22-
(one(d) - ν) * logtwo - logabsgamma(ν)[1] +
23-
ν * log(sqrt(2ν) * d) +
24-
log(besselk(ν, sqrt(2ν) * d))
25-
)
20+
iszero(d) ? one(d) : _matern(ν, d)
21+
end
22+
23+
function _matern::Real, d::Real)
24+
exp((one(d) - ν) * logtwo - loggamma(ν) + ν * log(sqrt(2ν) * d) + log(besselk(ν, sqrt(2ν) * d)))
2625
end
2726

2827
metric(::MaternKernel) = Euclidean()

src/distances/delta.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
struct Delta <: Distances.PreMetric
22
end
33

4-
@inline function Distances._evaluate(::Delta,a::AbstractVector{T},b::AbstractVector{T}) where {T}
4+
@inline function Distances._evaluate(::Delta, a::AbstractVector, b::AbstractVector) where {T}
55
@boundscheck if length(a) != length(b)
66
throw(DimensionMismatch("first array has length $(length(a)) which does not match the length of the second, $(length(b))."))
77
end
88
return a == b
99
end
1010

11+
Distances.result_type(::Delta, Ta::Type, Tb::Type) = promote_type(Ta, Tb)
12+
1113
@inline (dist::Delta)(a::AbstractArray, b::AbstractArray) = Distances._evaluate(dist, a, b)
12-
@inline (dist::Delta)(a::Number,b::Number) = a == b
14+
@inline (dist::Delta)(a::Number, b::Number) = a == b

src/distances/dotproduct.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
struct DotProduct <: Distances.PreMetric end
22
# struct DotProduct <: Distances.UnionSemiMetric end
33

4-
@inline function Distances._evaluate(::DotProduct, a::AbstractVector{T}, b::AbstractVector{T}) where {T}
4+
@inline function Distances._evaluate(::DotProduct, a::AbstractVector, b::AbstractVector)
55
@boundscheck if length(a) != length(b)
66
throw(DimensionMismatch("first array has length $(length(a)) which does not match the length of the second, $(length(b))."))
77
end
88
return dot(a,b)
99
end
1010

11+
Distances.result_type(::DotProduct, Ta::Type, Tb::Type) = promote_type(Ta, Tb)
12+
1113
@inline Distances.eval_op(::DotProduct, a::Real, b::Real) = a * b
1214
@inline (dist::DotProduct)(a::AbstractArray,b::AbstractArray) = Distances._evaluate(dist, a, b)
1315
@inline (dist::DotProduct)(a::Number,b::Number) = a * b

src/distances/sinus.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,9 @@ Distances.parameters(d::Sinus) = d.r
88
@inline (dist::Sinus)(a::AbstractArray, b::AbstractArray) = Distances._evaluate(dist, a, b)
99
@inline (dist::Sinus)(a::Number, b::Number) = abs2(sinpi(a - b) / first(dist.r))
1010

11-
@inline function Distances._evaluate(d::Sinus, a::AbstractVector{T}, b::AbstractVector{T}) where {T}
11+
Distances.result_type(::Sinus{T}, Ta::Type, Tb::Type) where {T} = promote_type(T, Ta, Tb)
12+
13+
@inline function Distances._evaluate(d::Sinus, a::AbstractVector, b::AbstractVector) where {T}
1214
@boundscheck if (length(a) != length(b)) || length(a) != length(d.r)
1315
throw(DimensionMismatch("Dimensions of the inputs are not matching : a = $(length(a)), b = $(length(b)), r = $(length(d.r))"))
1416
end

src/transform/ardtransform.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,9 @@ dim(t::ARDTransform) = length(t.v)
2424
(t::ARDTransform)(x::Real) = first(t.v) * x
2525
(t::ARDTransform)(x) = t.v .* x
2626

27-
Base.map(t::ARDTransform, x::AbstractVector{<:Real}) = t.v' .* x
28-
Base.map(t::ARDTransform, x::ColVecs) = ColVecs(t.v .* x.X)
29-
Base.map(t::ARDTransform, x::RowVecs) = RowVecs(t.v' .* x.X)
27+
_map(t::ARDTransform, x::AbstractVector{<:Real}) = t.v' .* x
28+
_map(t::ARDTransform, x::ColVecs) = ColVecs(t.v .* x.X)
29+
_map(t::ARDTransform, x::RowVecs) = RowVecs(t.v' .* x.X)
3030

3131
Base.isequal(t::ARDTransform, t2::ARDTransform) = isequal(t.v, t2.v)
3232

src/transform/chaintransform.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ Base.:∘(tc::ChainTransform, t::Transform) = ChainTransform(vcat(t, tc.transfor
2727

2828
(t::ChainTransform)(x) = foldl((x, t) -> t(x), t.transforms; init=x)
2929

30-
function Base.map(t::ChainTransform, x::AbstractVector)
30+
function _map(t::ChainTransform, x::AbstractVector)
3131
return foldl((x, t) -> map(t, x), t.transforms; init=x)
3232
end
3333

src/transform/functiontransform.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,9 @@ end
1515

1616
(t::FunctionTransform)(x) = t.f(x)
1717

18-
Base.map(t::FunctionTransform, x::AbstractVector{<:Real}) = map(t.f, x)
19-
Base.map(t::FunctionTransform, x::ColVecs) = ColVecs(mapslices(t.f, x.X; dims=1))
20-
Base.map(t::FunctionTransform, x::RowVecs) = RowVecs(mapslices(t.f, x.X; dims=2))
18+
_map(t::FunctionTransform, x::AbstractVector{<:Real}) = map(t.f, x)
19+
_map(t::FunctionTransform, x::ColVecs) = ColVecs(mapslices(t.f, x.X; dims=1))
20+
_map(t::FunctionTransform, x::RowVecs) = RowVecs(mapslices(t.f, x.X; dims=2))
2121

2222
duplicate(t::FunctionTransform,f) = FunctionTransform(f)
2323

0 commit comments

Comments
 (0)