Skip to content

Commit b646750

Browse files
committed
Reduce allocations by not calling dotted version of withUnit or ustrip
Instead, attach units to step and acc with two separate calls to withUnit, and determine the number type N of method coefs intead of first adding units and then stripping them. A fair amount of allocations still remain for Unitful calls, though.
1 parent 2b447e0 commit b646750

File tree

1 file changed

+32
-15
lines changed

1 file changed

+32
-15
lines changed

src/methods.jl

Lines changed: 32 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -261,7 +261,8 @@ function _compute_estimate(
261261
# https://github.com/JuliaLang/julia/issues/39151.
262262
#
263263
# We strip units because the estimate coefficients are just weights for values of f.
264-
_coefs = ustrip.(T.(coefs))
264+
N = numType(T)
265+
_coefs = N.(coefs)
265266
return sum(fs .* _coefs) ./ T(step) ^ Q
266267
end
267268

@@ -358,27 +359,26 @@ estimate of the derivative.
358359
function estimate_step(
359360
m::UnadaptedFiniteDifferenceMethod, f::TF, x::T,
360361
) where {TF,T<:Number}
361-
step, acc = withUnit.(
362-
unit(x),
363-
_compute_step_acc_default(m, x)
364-
)
362+
step, acc = _compute_step_acc_default(m, x)
363+
xunit = unit(x)
364+
step = withUnit(xunit, step)
365+
acc = withUnit(xunit,acc)
365366
return _limit_step(m, x, step, acc)
366367
end
367368
function estimate_step(
368369
m::AdaptedFiniteDifferenceMethod{P,Q}, f::TF, x::T,
369370
) where {P,Q,TF,T<:Number}
370371
∇f_magnitude, f_magnitude = _estimate_magnitudes(m.bound_estimator, f, x)
371-
step, acc = withUnit.(
372-
(
373-
unit(x),
374-
unit(first(f(x))) / unit(x) ^ Q
375-
),
372+
xunit = unit(x)
373+
dfunit = unit(first(f(x))) / unit(x) ^ Q
374+
step, acc =
376375
if ∇f_magnitude == withUnit(unit(∇f_magnitude),0.0) || f_magnitude == withUnit(unit(f_magnitude), 0.0)
377376
_compute_step_acc_default(m, x)
378377
else
379378
_compute_step_acc(m, ∇f_magnitude, eps(f_magnitude))
380379
end
381-
)
380+
step = withUnit(xunit, step)
381+
acc = withUnit(dfunit, acc)
382382
return _limit_step(m, x, step, acc)
383383
end
384384

@@ -431,10 +431,8 @@ function _limit_step(
431431
end
432432
# Second, prevent very large step sizes, which can occur for high-order methods or
433433
# slowly-varying functions.
434-
step_default, _ = withUnit.(
435-
xunit,
436-
_compute_step_acc_default(m, x)
437-
)
434+
step_default, _ = _compute_step_acc_default(m, x)
435+
step_default = withUnit(xunit, step_default)
438436
step_max_default = 1000step_default
439437
if step > step_max_default
440438
step = step_max_default
@@ -628,3 +626,22 @@ function withUnit(targetUnit, value)
628626
end # if
629627

630628
end # function
629+
630+
"""
631+
Retrieves the number type of a quantity, or returns the type itself in the case of a raw number.
632+
"""
633+
function numType(x::Number)
634+
typeof(x)
635+
end # function
636+
637+
function numType(x::Type{<:Number})
638+
x
639+
end
640+
641+
function numType(x::Unitful.AbstractQuantity)
642+
Unitful.numtype(typeof(x))
643+
end # function
644+
645+
function numType(x::Type{<:Unitful.AbstractQuantity})
646+
Unitful.numtype(x)
647+
end

0 commit comments

Comments
 (0)