Skip to content

Commit 285dc4b

Browse files
committed
Corrected behaviour for vectors and for derivatives on input data
1 parent 94ba979 commit 285dc4b

File tree

1 file changed

+8
-3
lines changed

1 file changed

+8
-3
lines changed

src/kernelmatrix.jl

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ function kernel(
7070
obsdim::Int = defaultobs
7171
) where {T,T₁<:Real,T₂<:Real}
7272
# TODO Verify dimensions
73-
_kappamatrix!(κ, pairwise(metric(κ),X,Y,dims=obsdim))
73+
kappa(κ, evaluate(metric(κ),x,y))
7474
end
7575

7676
"""
@@ -85,7 +85,9 @@ function kernelmatrix(
8585
obsdim::Int = defaultobs,
8686
symmetrize::Bool = true
8787
) where {T,T₁<:Real}
88-
return kernelmatrix!(Matrix{promote_float(T,T₁)}(undef,size(X,obsdim),size(X,obsdim)),κ,X,obsdim=obsdim,symmetrize=symmetrize)
88+
Tₛ = typeof(zero(eltype(X))*zero(T))
89+
m = size(X,obsdim)
90+
return kernelmatrix!(Matrix{promote_float(T,T₁)}(undef,m,m),κ,X,obsdim=obsdim,symmetrize=symmetrize)
8991
end
9092

9193
"""
@@ -100,7 +102,10 @@ function kernelmatrix(
100102
Y::AbstractMatrix{T₂};
101103
obsdim=defaultobs
102104
) where {T,T₁<:Real,T₂<:Real}
103-
kernelmatrix!(Matrix{promote_float(T,T₁,T₂)}(undef,size(X,obsdim),size(Y,obsdim)),κ,X,Y,obsdim=obsdim)
105+
Tₛ = typeof(zero(eltype(X))*zero(eltype(Y))*zero(T))
106+
m = size(X,obsdim)
107+
n = size(Y,obsdim)
108+
kernelmatrix!(Matrix{Tₛ}(undef,m,n),κ,X,Y,obsdim=obsdim)
104109
end
105110

106111

0 commit comments

Comments
 (0)