Skip to content

Commit 209762f

Browse files
Merge pull request #392 from SciML/dw/munge_data
Fix type inference and performance problems of `munge_data`
2 parents 5391a3c + f2a975f commit 209762f

File tree

2 files changed

+50
-33
lines changed

2 files changed

+50
-33
lines changed

src/interpolation_utils.jl

Lines changed: 31 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -104,51 +104,49 @@ function quadratic_spline_params(t::AbstractVector, sc::AbstractVector)
104104
end
105105

106106
# helper function for data manipulation
107-
function munge_data(u::AbstractVector{<:Real}, t::AbstractVector{<:Real})
108-
return u, t
109-
end
110-
111107
function munge_data(u::AbstractVector, t::AbstractVector)
112-
Tu = Base.nonmissingtype(eltype(u))
113-
Tt = Base.nonmissingtype(eltype(t))
114-
@assert length(t) == length(u)
115-
non_missing_indices = collect(
116-
i for i in 1:length(t)
117-
if !ismissing(u[i]) && !ismissing(t[i])
118-
)
108+
Tu = nonmissingtype(eltype(u))
109+
Tt = nonmissingtype(eltype(t))
110+
if Tu === eltype(u) && Tt === eltype(t)
111+
return u, t
112+
end
119113

120-
u = Tu.([u[i] for i in non_missing_indices])
121-
t = Tt.([t[i] for i in non_missing_indices])
114+
@assert length(t) == length(u)
115+
non_missing_mask = map((ui, ti) -> !ismissing(ui) && !ismissing(ti), u, t)
116+
u = convert(AbstractVector{Tu}, u[non_missing_mask])
117+
t = convert(AbstractVector{Tt}, t[non_missing_mask])
122118

123119
return u, t
124120
end
125121

126-
function munge_data(U::StridedMatrix, t::AbstractVector)
127-
TU = Base.nonmissingtype(eltype(U))
128-
Tt = Base.nonmissingtype(eltype(t))
129-
@assert length(t) == size(U, 2)
130-
non_missing_indices = collect(
131-
i for i in 1:length(t)
132-
if !any(ismissing, U[:, i]) && !ismissing(t[i])
133-
)
122+
function munge_data(U::AbstractMatrix, t::AbstractVector)
123+
TU = nonmissingtype(eltype(U))
124+
Tt = nonmissingtype(eltype(t))
125+
if TU === eltype(U) && Tt === eltype(t)
126+
return U, t
127+
end
134128

135-
U = hcat([TU.(U[:, i]) for i in non_missing_indices]...)
136-
t = Tt.([t[i] for i in non_missing_indices])
129+
@assert length(t) == size(U, 2)
130+
non_missing_mask = map(
131+
(uis, ti) -> !any(ismissing, uis) && !ismissing(ti), eachcol(U), t)
132+
U = convert(AbstractMatrix{TU}, U[:, non_missing_mask])
133+
t = convert(AbstractVector{Tt}, t[non_missing_mask])
137134

138135
return U, t
139136
end
140137

141138
function munge_data(U::AbstractArray{T, N}, t) where {T, N}
142-
TU = Base.nonmissingtype(eltype(U))
143-
Tt = Base.nonmissingtype(eltype(t))
144-
@assert length(t) == size(U, ndims(U))
145-
ax = axes(U)[1:(end - 1)]
146-
non_missing_indices = collect(
147-
i for i in 1:length(t)
148-
if !any(ismissing, U[ax..., i]) && !ismissing(t[i])
149-
)
150-
U = cat([TU.(U[ax..., i]) for i in non_missing_indices]...; dims = ndims(U))
151-
t = Tt.([t[i] for i in non_missing_indices])
139+
TU = nonmissingtype(eltype(U))
140+
Tt = nonmissingtype(eltype(t))
141+
if TU === eltype(U) && Tt === eltype(t)
142+
return U, t
143+
end
144+
145+
@assert length(t) == size(U, N)
146+
non_missing_mask = map(
147+
(uis, ti) -> !any(ismissing, uis) && !ismissing(ti), eachslice(U; dims = N), t)
148+
U = convert(AbstractArray{TU, N}, copy(selectdim(U, N, non_missing_mask)))
149+
t = convert(AbstractVector{Tt}, t[non_missing_mask])
152150

153151
return U, t
154152
end

test/interpolation_tests.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -940,3 +940,22 @@ f_cubic_spline = c -> square(CubicSpline, c)
940940
@test ForwardDiff.derivative(f_quadratic_spline, 4.0) ≈ 8.0
941941
@test ForwardDiff.derivative(f_cubic_spline, 2.0) ≈ 4.0
942942
@test ForwardDiff.derivative(f_cubic_spline, 4.0) ≈ 8.0
943+
944+
@testset "munge_data" begin
945+
t0 = [0.1, 0.2, 0.3]
946+
u0 = ["A", "B", "C"]
947+
iszero_allocations(u, t) = iszero(@allocated(DataInterpolations.munge_data(u, t)))
948+
949+
for T in (String, Union{String, Missing}), dims in 1:3
950+
_u0 = convert(Array{T}, reshape(u0, ntuple(i -> i == dims ? 3 : 1, dims)))
951+
952+
u, t = @inferred(DataInterpolations.munge_data(_u0, t0))
953+
@test u isa Array{String, dims}
954+
@test t isa Vector{Float64}
955+
if T === String
956+
@test iszero_allocations(_u0, t0)
957+
@test u === _u0
958+
@test t === t
959+
end
960+
end
961+
end

0 commit comments

Comments
 (0)