Skip to content

Commit 3918d25

Browse files
Remove type instabilities and allow general floating point types (#121)
1 parent 2645c60 commit 3918d25

File tree

15 files changed

+209
-116
lines changed

15 files changed

+209
-116
lines changed

CHANGELOG.md

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,20 @@ KernelInterpolation.jl follows the interpretation of
55
used in the Julia ecosystem. Notable changes will be documented in this file
66
for human readability.
77

8+
## Changes when updating to v0.3 from v0.2.x
9+
10+
#### Added
11+
12+
- General floating point support ([#121]).
13+
14+
#### Changed
15+
16+
- The functions `random_hypersphere` and `random_hypersphere_boundary` not require a `Tuple` for
17+
the argument `center`. Before, e.g., a `Vector` was allowed ([#121]).
18+
- The element type of `NodeSet`s will now always be converted to a floating point type, i.e., also when
19+
integer values are passed. This is more consistent for an interpolation framework makes many things easier.
20+
A similar approach is also used in the Meshes.jl/CoordRefSystems.jl ecosystem ([#121]).
21+
822
## Changes in the v0.2 lifecycle
923

1024
#### Added

Project.toml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,19 +28,19 @@ Meshes = "eacbb407-ea5a-433e-ab97-5258b1ca43fa"
2828
KernelInterpolationMeshesExt = "Meshes"
2929

3030
[compat]
31-
DiffEqCallbacks = "3, 4"
31+
DiffEqCallbacks = "4"
3232
ForwardDiff = "0.10.36"
3333
LinearAlgebra = "1"
3434
Meshes = "0.52.1, 0.53"
3535
Printf = "1"
3636
Random = "1"
3737
ReadVTK = "0.2"
3838
RecipesBase = "1.3.4"
39-
Reexport = "1.2"
40-
SciMLBase = "2.56"
39+
Reexport = "1.2.2"
40+
SciMLBase = "2.78"
4141
SimpleUnPack = "1.1"
4242
SpecialFunctions = "2"
43-
StaticArrays = "1.9"
43+
StaticArrays = "1.9.7"
4444
TimerOutputs = "0.5.23"
4545
TrixiBase = "0.1.3"
4646
TypedPolynomials = "0.4.1"

examples/interpolation/interpolation_2d_sphere.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ using Plots
55
f(x) = x[1] * x[2]
66

77
r = 2.0
8-
center = [-1.0, 2.0]
8+
center = (-1.0, 2.0)
99
n = 40
1010
nodeset = random_hypersphere(n, r, center)
1111
nodeset_boundary = random_hypersphere_boundary(20, r, center)

src/KernelInterpolation.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ module KernelInterpolation
1313

1414
using DiffEqCallbacks: PeriodicCallback, PeriodicCallbackAffect
1515
using ForwardDiff: ForwardDiff
16-
using LinearAlgebra: Symmetric, I, norm, tr, muladd, dot, diagind
16+
using LinearAlgebra: Symmetric, I, norm, tr, dot, diagind
1717
using Printf: @sprintf
1818
using Random: Random
1919
using ReadVTK: VTKFile, get_points, get_point_data, get_data

src/basis.jl

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -59,14 +59,14 @@ The basis functions are given by
5959
6060
where `K` is the kernel and `x_j` are the nodes in `centers`.
6161
"""
62-
struct StandardBasis{Kernel} <: AbstractBasis
63-
centers::NodeSet
62+
struct StandardBasis{Dim, RealT, Kernel} <: AbstractBasis
63+
centers::NodeSet{Dim, RealT}
6464
kernel::Kernel
6565
function StandardBasis(centers::NodeSet, kernel::Kernel) where {Kernel}
6666
if dim(kernel) != dim(centers)
6767
throw(DimensionMismatch("The dimension of the kernel and the centers must be the same"))
6868
end
69-
new{typeof(kernel)}(centers, kernel)
69+
new{dim(centers), eltype(centers), typeof(kernel)}(centers, kernel)
7070
end
7171
end
7272

@@ -85,36 +85,35 @@ already includes polynomial augmentation of degree `m` defaulting to `order(kern
8585
which means that the [`kernel_matrix`](@ref) of this basis is the identity matrix making it suitable for interpolation. Since the
8686
basis already includes polynomials no additional polynomial augmentation is needed for interpolation with this basis.
8787
"""
88-
struct LagrangeBasis{Kernel, I <: AbstractInterpolation, Monomials, PolyVars} <:
88+
struct LagrangeBasis{Dim, RealT, Kernel, I <: AbstractInterpolation, Monomials, PolyVars} <:
8989
AbstractBasis
90-
centers::NodeSet
90+
centers::NodeSet{Dim, RealT}
9191
kernel::Kernel
9292
basis_functions::Vector{I}
9393
ps::Monomials
9494
xx::PolyVars
95-
function LagrangeBasis(centers::NodeSet, kernel::Kernel;
96-
m = order(kernel)) where {Kernel}
95+
function LagrangeBasis(centers::NodeSet, kernel::AbstractKernel;
96+
m = order(kernel))
9797
if dim(kernel) != dim(centers)
9898
throw(DimensionMismatch("The dimension of the kernel and the centers must be the same"))
9999
end
100+
RealT = eltype(centers)
100101
K = length(centers)
101-
values = zeros(K)
102-
values[1] = 1.0
102+
values = zeros(RealT, K)
103+
values[1] = one(RealT)
103104
b = interpolate(centers, values, kernel; m = m)
104105
basis_functions = Vector{typeof(b)}(undef, K)
105106
basis_functions[1] = b
106107
for i in 2:K
107-
values[i - 1] = 0.0
108-
values[i] = 1.0
108+
values[i - 1] = zero(RealT)
109+
values[i] = one(RealT)
109110
basis_functions[i] = interpolate(centers, values, kernel; m = m)
110111
end
111112
# All basis functions have same polynomials
112113
ps = first(basis_functions).ps
113114
xx = first(basis_functions).xx
114-
new{typeof(kernel), eltype(basis_functions), typeof(ps), typeof(xx)}(centers,
115-
kernel,
116-
basis_functions,
117-
ps, xx)
115+
new{dim(centers), eltype(centers), typeof(kernel), eltype(basis_functions),
116+
typeof(ps), typeof(xx)}(centers, kernel, basis_functions, ps, xx)
118117
end
119118
end
120119

src/discretization.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,8 @@ function rhs!(dc, c, semi, t)
176176
boundary_condition.(Ref(t), nodeset_boundary)]
177177
end
178178
# dc = -pde_boundary_matrix * c + rhs_vector
179-
@trixi_timeit timer() "muladd" dc[:]=muladd(pde_boundary_matrix, -c, rhs_vector)
179+
@trixi_timeit timer() "muladd" dc[:]=Base.muladd(pde_boundary_matrix, -c,
180+
rhs_vector)
180181
end
181182
return nothing
182183
end

src/interpolation.jl

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ Otherwise, `nodeset` is set to `centers(basis)` or `centers`.
159159
A regularization can be applied to the kernel matrix using the `regularization` argument, cf. [`regularize!`](@ref).
160160
"""
161161
function interpolate(basis::AbstractBasis, values::Vector{RealT},
162-
nodeset::NodeSet{Dim} = centers(basis);
162+
nodeset::NodeSet{Dim, RealT} = centers(basis);
163163
m = order(basis),
164164
regularization = NoRegularization()) where {Dim, RealT}
165165
@assert dim(basis) == Dim
@@ -174,27 +174,29 @@ function interpolate(basis::AbstractBasis, values::Vector{RealT},
174174
else
175175
system_matrix = least_squares_matrix(basis, nodeset, ps, regularization)
176176
end
177-
b = [values; zeros(q)]
177+
b = [values; zeros(RealT, q)]
178178
c = system_matrix \ b
179-
return Interpolation(basis, nodeset, c, system_matrix, ps, xx)
179+
return Interpolation{typeof(basis), dim(basis), eltype(nodeset), typeof(system_matrix),
180+
typeof(ps), typeof(xx)}(basis, nodeset, c, system_matrix, ps, xx)
180181
end
181182
function interpolate(centers::NodeSet{Dim, RealT}, nodeset::NodeSet{Dim, RealT},
182-
values::AbstractVector{RealT}, kernel = GaussKernel{Dim}();
183+
values::AbstractVector{RealT}, kernel = GaussKernel{Dim, RealT}();
183184
kwargs...) where {Dim, RealT}
184185
interpolate(StandardBasis(centers, kernel), values, nodeset; kwargs...)
185186
end
186187

187188
function interpolate(centers::NodeSet{Dim, RealT},
188-
values::AbstractVector{RealT}, kernel = GaussKernel{Dim}();
189+
values::AbstractVector{RealT},
190+
kernel = GaussKernel{Dim}(; shape_parameter = RealT(1.0));
189191
kwargs...) where {Dim, RealT}
190192
interpolate(StandardBasis(centers, kernel), values; kwargs...)
191193
end
192194

193195
# Evaluate interpolant
194196
function (itp::Interpolation)(x)
195-
s = 0
196197
bas = basis(itp)
197198
c = kernel_coefficients(itp)
199+
s = zero(eltype(x))
198200
for j in eachindex(c)
199201
s += c[j] * bas[j](x)
200202
end

src/kernel_matrices.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ function interpolation_matrix(basis::AbstractBasis, ps,
7676
regularize!(k_matrix, regularization)
7777
p_matrix = polynomial_matrix(centers(basis), ps)
7878
system_matrix = [k_matrix p_matrix
79-
p_matrix' zeros(q, q)]
79+
p_matrix' zeros(eltype(k_matrix), q, q)]
8080
return Symmetric(system_matrix)
8181
end
8282

@@ -112,7 +112,7 @@ function least_squares_matrix(basis::AbstractBasis, nodeset::NodeSet, ps,
112112
p_matrix1 = polynomial_matrix(nodeset, ps)
113113
p_matrix2 = polynomial_matrix(centers(basis), ps)
114114
system_matrix = [k_matrix p_matrix1
115-
p_matrix2' zeros(q, q)]
115+
p_matrix2' zeros(eltype(k_matrix), q, q)]
116116
return system_matrix
117117
end
118118

src/kernels/radialsymmetric_kernel.jl

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -278,21 +278,21 @@ function Base.show(io::IO, kernel::WendlandKernel{Dim}) where {Dim}
278278
kernel.shape_parameter, ", d = ", kernel.d, ")")
279279
end
280280

281-
function phi(kernel::WendlandKernel, r::Real)
281+
function phi(kernel::WendlandKernel, r::RealT) where {RealT <: Real}
282282
a_r = kernel.shape_parameter * r
283283
if a_r >= 1
284-
return 0.0
284+
return RealT(0.0)
285285
end
286-
l = floor(kernel.d / 2) + kernel.k + 1
286+
l = floor(Int, kernel.d / 2) + kernel.k + 1
287287
if kernel.k == 0
288288
return (1 - a_r)^l
289289
elseif kernel.k == 1
290290
return (1 - a_r)^(l + 1) * ((l + 1) * a_r + 1)
291291
elseif kernel.k == 2
292-
return 1 / 3 * (1 - a_r)^(l + 2) *
292+
return 1 // 3 * (1 - a_r)^(l + 2) *
293293
((l^2 + 4 * l + 3) * a_r^2 + (3 * l + 6) * a_r + 3)
294294
elseif kernel.k == 3
295-
return 1 / 15 * (1 - a_r)^(l + 3) *
295+
return 1 // 15 * (1 - a_r)^(l + 3) *
296296
((l^3 + 9 * l^2 + 23 * l + 15) * a_r^3 + (6 * l^2 + 36 * l + 45) * a_r^2 +
297297
(15 * l + 45) * a_r + 15)
298298
end
@@ -346,10 +346,10 @@ function Base.show(io::IO, kernel::WuKernel{Dim}) where {Dim}
346346
", shape_parameter = ", kernel.shape_parameter, ")")
347347
end
348348

349-
function phi(kernel::WuKernel, r::Real)
349+
function phi(kernel::WuKernel, r::RealT) where {RealT <: Real}
350350
a_r = kernel.shape_parameter * r
351351
if a_r >= 1
352-
return 0.0
352+
return RealT(0.0)
353353
end
354354
if kernel.l == 0
355355
# k = 0
@@ -358,29 +358,29 @@ function phi(kernel::WuKernel, r::Real)
358358
if kernel.k == 0
359359
return (1 - a_r)^3 * (a_r^2 + 3 * a_r + 1)
360360
elseif kernel.k == 1
361-
return 1 / 2 * (1 - a_r)^2 * (a_r + 2)
361+
return 1 // 2 * (1 - a_r)^2 * (a_r + 2)
362362
end
363363
elseif kernel.l == 2
364364
if kernel.k == 0
365365
return (1 - a_r)^5 * (a_r^4 + 5 * a_r^3 + 9 * a_r^2 + 5 * a_r + 1)
366366
elseif kernel.k == 1
367-
return 1 / 4 * (1 - a_r)^4 * (3 * a_r^3 + 12 * a_r^2 + 16 * a_r + 4)
367+
return 1 // 4 * (1 - a_r)^4 * (3 * a_r^3 + 12 * a_r^2 + 16 * a_r + 4)
368368
elseif kernel.k == 2
369-
return 1 / 8 * (1 - a_r)^3 * (3 * a_r^2 + 9 * a_r + 8)
369+
return 1 // 8 * (1 - a_r)^3 * (3 * a_r^2 + 9 * a_r + 8)
370370
end
371371
elseif kernel.l == 3
372372
if kernel.k == 0
373-
return 1 / 5 * (1 - a_r)^7 *
373+
return 1 // 5 * (1 - a_r)^7 *
374374
(5 * a_r^6 + 35 * a_r^5 + 101 * a_r^4 + 147 * a_r^3 + 101 * a_r^2 +
375375
35 * a_r + 5)
376376
elseif kernel.k == 1
377-
return 1 / 6 * (1 - a_r)^6 *
377+
return 1 // 6 * (1 - a_r)^6 *
378378
(5 * a_r^5 + 30 * a_r^4 + 72 * a_r^3 + 82 * a_r^2 + 36 * a_r + 6)
379379
elseif kernel.k == 2
380-
return 1 / 8 * (1 - a_r)^5 *
380+
return 1 // 8 * (1 - a_r)^5 *
381381
(5 * a_r^4 + 25 * a_r^3 + 48 * a_r^2 + 40 * a_r + 8)
382382
elseif kernel.k == 3
383-
return 1 / 16 * (1 - a_r)^4 * (5 * a_r^3 + 20 * a_r^2 + 29 * a_r + 16)
383+
return 1 // 16 * (1 - a_r)^4 * (5 * a_r^3 + 20 * a_r^2 + 29 * a_r + 16)
384384
end
385385
end
386386
end
@@ -544,8 +544,8 @@ function Base.show(io::IO, kernel::Matern32Kernel{Dim}) where {Dim}
544544
kernel.shape_parameter, ")")
545545
end
546546

547-
function phi(kernel::Matern32Kernel, r::Real)
548-
y = sqrt(3) * kernel.shape_parameter * r
547+
function phi(kernel::Matern32Kernel, r::RealT) where {RealT <: Real}
548+
y = RealT(sqrt(3)) * kernel.shape_parameter * r
549549
return (1 + y) * exp(-y)
550550
end
551551
order(::Matern32Kernel) = 0
@@ -580,9 +580,9 @@ function Base.show(io::IO, kernel::Matern52Kernel{Dim}) where {Dim}
580580
print(io, "Matern52Kernel{", Dim, "}(shape_parameter = ", kernel.shape_parameter, ")")
581581
end
582582

583-
function phi(kernel::Matern52Kernel, r::Real)
584-
y = sqrt(5) * kernel.shape_parameter * r
585-
return 1 / 3 * (3 + 3 * y + y^2) * exp(-y)
583+
function phi(kernel::Matern52Kernel, r::RealT) where {RealT <: Real}
584+
y = RealT(sqrt(5)) * kernel.shape_parameter * r
585+
return 1 // 3 * (3 + 3 * y + y^2) * exp(-y)
586586
end
587587
order(::Matern52Kernel) = 0
588588

@@ -616,8 +616,8 @@ function Base.show(io::IO, kernel::Matern72Kernel{Dim}) where {Dim}
616616
print(io, "Matern72Kernel{", Dim, "}(shape_parameter = ", kernel.shape_parameter, ")")
617617
end
618618

619-
function phi(kernel::Matern72Kernel, r::Real)
620-
y = sqrt(7) * kernel.shape_parameter * r
619+
function phi(kernel::Matern72Kernel, r::RealT) where {RealT <: Real}
620+
y = RealT(sqrt(7)) * kernel.shape_parameter * r
621621
return (1 + y + 6 * y^2 / 15 + y^3 / 15) * exp(-y)
622622
end
623623
order(::Matern72Kernel) = 0

src/kernels/special_kernel.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ end
5757

5858
function (kernel::ProductKernel)(x, y)
5959
@assert length(x) == length(y)
60-
res = 1.0
60+
res = eltype(x)(1.0)
6161
for k in kernel.kernels
6262
res *= k(x, y)
6363
end

0 commit comments

Comments
 (0)