Skip to content

Commit be06179

Browse files
committed
Fix type inference and performance problems of munge_data
1 parent e409e9c commit be06179

File tree

3 files changed

+54
-34
lines changed

3 files changed

+54
-34
lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ EnumX = "1.0.4"
3232
FindFirstFunctions = "1.3"
3333
FiniteDifferences = "0.12.31"
3434
ForwardDiff = "0.10.36"
35+
JET = "0.9.17"
3536
LinearAlgebra = "1.10"
3637
Optim = "1.6"
3738
PrettyTables = "2"
@@ -53,6 +54,7 @@ BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
5354
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
5455
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
5556
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
57+
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
5658
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
5759
QuadGK = "1fd47b50-473d-5c70-9696-f719f8f3bcdc"
5860
RegularizationTools = "29dad682-9a27-4bc3-9c72-016788665182"
@@ -64,4 +66,4 @@ Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
6466
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
6567

6668
[targets]
67-
test = ["Aqua", "BenchmarkTools", "SafeTestsets", "ChainRulesCore", "Optim", "RegularizationTools", "Test", "StableRNGs", "FiniteDifferences", "QuadGK", "ForwardDiff", "Symbolics", "Unitful", "Zygote"]
69+
test = ["Aqua", "BenchmarkTools", "JET", "SafeTestsets", "ChainRulesCore", "Optim", "RegularizationTools", "Test", "StableRNGs", "FiniteDifferences", "QuadGK", "ForwardDiff", "Symbolics", "Unitful", "Zygote"]

src/interpolation_utils.jl

Lines changed: 31 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -104,51 +104,49 @@ function quadratic_spline_params(t::AbstractVector, sc::AbstractVector)
104104
end
105105

106106
# helper function for data manipulation
107-
function munge_data(u::AbstractVector{<:Real}, t::AbstractVector{<:Real})
108-
return u, t
109-
end
110-
111107
function munge_data(u::AbstractVector, t::AbstractVector)
112-
Tu = Base.nonmissingtype(eltype(u))
113-
Tt = Base.nonmissingtype(eltype(t))
114-
@assert length(t) == length(u)
115-
non_missing_indices = collect(
116-
i for i in 1:length(t)
117-
if !ismissing(u[i]) && !ismissing(t[i])
118-
)
108+
Tu = nonmissingtype(eltype(u))
109+
Tt = nonmissingtype(eltype(t))
110+
if Tu === eltype(u) && Tt === eltype(t)
111+
return u, t
112+
end
119113

120-
u = Tu.([u[i] for i in non_missing_indices])
121-
t = Tt.([t[i] for i in non_missing_indices])
114+
@assert length(t) == length(u)
115+
non_missing_mask = map((ui, ti) -> !ismissing(ui) && !ismissing(ti), u, t)
116+
u = convert(AbstractVector{Tu}, u[non_missing_mask])
117+
t = convert(AbstractVector{Tt}, t[non_missing_mask])
122118

123119
return u, t
124120
end
125121

126-
function munge_data(U::StridedMatrix, t::AbstractVector)
127-
TU = Base.nonmissingtype(eltype(U))
128-
Tt = Base.nonmissingtype(eltype(t))
129-
@assert length(t) == size(U, 2)
130-
non_missing_indices = collect(
131-
i for i in 1:length(t)
132-
if !any(ismissing, U[:, i]) && !ismissing(t[i])
133-
)
122+
function munge_data(U::AbstractMatrix, t::AbstractVector)
123+
TU = nonmissingtype(eltype(U))
124+
Tt = nonmissingtype(eltype(t))
125+
if TU === eltype(U) && Tt === eltype(t)
126+
return U, t
127+
end
134128

135-
U = hcat([TU.(U[:, i]) for i in non_missing_indices]...)
136-
t = Tt.([t[i] for i in non_missing_indices])
129+
@assert length(t) == size(U, 2)
130+
non_missing_mask = map(
131+
(uis, ti) -> !any(ismissing, uis) && !ismissing(ti), eachcol(U), t)
132+
U = convert(AbstractMatrix{TU}, U[:, non_missing_mask])
133+
t = convert(AbstractVector{Tt}, t[non_missing_mask])
137134

138135
return U, t
139136
end
140137

141138
function munge_data(U::AbstractArray{T, N}, t) where {T, N}
142-
TU = Base.nonmissingtype(eltype(U))
143-
Tt = Base.nonmissingtype(eltype(t))
144-
@assert length(t) == size(U, ndims(U))
145-
ax = axes(U)[1:(end - 1)]
146-
non_missing_indices = collect(
147-
i for i in 1:length(t)
148-
if !any(ismissing, U[ax..., i]) && !ismissing(t[i])
149-
)
150-
U = cat([TU.(U[ax..., i]) for i in non_missing_indices]...; dims = ndims(U))
151-
t = Tt.([t[i] for i in non_missing_indices])
139+
TU = nonmissingtype(eltype(U))
140+
Tt = nonmissingtype(eltype(t))
141+
if TU === eltype(U) && Tt === eltype(t)
142+
return U, t
143+
end
144+
145+
@assert length(t) == size(U, N)
146+
non_missing_mask = map(
147+
(uis, ti) -> !any(ismissing, uis) && !ismissing(ti), eachslice(U; dims = N), t)
148+
U = convert(AbstractArray{TU, N}, copy(selectdim(U, N, non_missing_mask)))
149+
t = convert(AbstractVector{Tt}, t[non_missing_mask])
152150

153151
return U, t
154152
end

test/interpolation_tests.jl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ using StableRNGs
44
using Optim, ForwardDiff
55
using BenchmarkTools
66
using Unitful
7+
using JET
78

89
function test_interpolation_type(T)
910
@test T <: DataInterpolations.AbstractInterpolation
@@ -920,3 +921,22 @@ f_cubic_spline = c -> square(CubicSpline, c)
920921
@test ForwardDiff.derivative(f_quadratic_spline, 4.0) 8.0
921922
@test ForwardDiff.derivative(f_cubic_spline, 2.0) 4.0
922923
@test ForwardDiff.derivative(f_cubic_spline, 4.0) 8.0
924+
925+
@testset "munge_data" begin
926+
t0 = [0.1, 0.2, 0.3]
927+
u0 = ["A", "B", "C"]
928+
929+
for T in (String, Union{String, Missing}), dims in 1:3
930+
_u0 = convert(Array{T}, reshape(u0, ntuple(i -> i == dims ? 3 : 1, dims)))
931+
932+
u, t = @inferred(DataInterpolations.munge_data(_u0, t0))
933+
@test u isa Array{String, dims}
934+
@test t isa Vector{Float64}
935+
if T === String
936+
@test u === _u0
937+
@test t === t
938+
end
939+
940+
@test_call DataInterpolations.munge_data(_u0, t0)
941+
end
942+
end

0 commit comments

Comments
 (0)