Skip to content

Commit 5c2fecf

Browse files
committed
reduce array allocations
1 parent 463eb0a commit 5c2fecf

File tree

2 files changed

+17
-13
lines changed

2 files changed

+17
-13
lines changed

src/modelframe.jl

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -53,22 +53,26 @@ end
5353
_missing_omit(x::AbstractVector{T}) where T = copyto!(similar(x, nonmissingtype(T)), x)
5454
_missing_omit(x::AbstractVector, rows) = _missing_omit(view(x, rows))
5555

56-
function missing_omit(d::T) where T<:ColumnTable
56+
function _maybe_missing_omit(d::T) where T<:ColumnTable
5757
nonmissings = trues(length(first(d)))
58-
for col in d
59-
_nonmissing!(nonmissings, col)
60-
end
61-
d_nonmissing = if all(nonmissings)
62-
map(_missing_omit, d)
58+
if any(eltype(col) >: Missing for col in d)
59+
for col in d
60+
_nonmissing!(nonmissings, col)
61+
end
62+
d_nonmissing = if all(nonmissings)
63+
map(_missing_omit, d)
64+
else
65+
rows = findall(nonmissings)
66+
map(Base.Fix2(_missing_omit, rows), d)
67+
end
68+
return d_nonmissing, nonmissings
6369
else
64-
rows = findall(nonmissings)
65-
map(Base.Fix2(_missing_omit, rows), d)
70+
return d, nonmissings
6671
end
67-
d_nonmissing, nonmissings
6872
end
6973

70-
missing_omit(data::T, formula::AbstractTerm) where T<:ColumnTable =
71-
missing_omit(NamedTuple{tuple(termvars(formula)...)}(data))
74+
_maybe_missing_omit(data::T, formula::AbstractTerm) where T<:ColumnTable =
75+
_maybe_missing_omit(NamedTuple{tuple(termvars(formula)...)}(data))
7276

7377
function ModelFrame(f::FormulaTerm, data::ColumnTable;
7478
model::Type{M}=StatisticalModel, contrasts=Dict{Symbol,Any}()) where M
@@ -78,7 +82,7 @@ function ModelFrame(f::FormulaTerm, data::ColumnTable;
7882
throw(ArgumentError(msg))
7983
end
8084

81-
data, _ = missing_omit(data, f)
85+
data, _ = _maybe_missing_omit(data, f)
8286

8387
sch = schema(f, data, contrasts)
8488
f = apply_schema(f, sch, M)

src/statsmodel.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,7 @@ function StatsBase.predict(mm::TableRegressionModel, data; kwargs...)
174174
throw(ArgumentError("expected data in a Table, got $(typeof(data))"))
175175

176176
f = mm.mf.f
177-
cols, nonmissings = missing_omit(columntable(data), f.rhs)
177+
cols, nonmissings = _maybe_missing_omit(columntable(data), f.rhs)
178178
new_x = modelcols(f.rhs, cols)
179179
y_pred = predict(mm.model, reshape(new_x, size(new_x, 1), :);
180180
kwargs...)

0 commit comments

Comments
 (0)