Skip to content

Commit 8df4a41

Browse files
committed
lambertwbp_series(): simplify
- use Base.evalpoly() instead of horner macro reimplementation - use Val()-based dispatch instead of function generation
1 parent d846ea4 commit 8df4a41

File tree

1 file changed

+28
-53
lines changed

1 file changed

+28
-53
lines changed

src/lambertw.jl

Lines changed: 28 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,7 @@ Base.BigFloat(::Irrational{:ω}) = omega_const(BigFloat)
200200
# solve for α₂. We get α₂ = 0.
201201
# Compute array of coefficients μ in (4.22).
202202
# m[1] is μ₀
203-
function compute_branch_point_coeffs(T::Type{<:Number}, n::Integer)
203+
function lambertw_coeffs(T::Type{<:Number}, n::Integer)
204204
a = Vector{T}(undef, n)
205205
m = Vector{T}(undef, n)
206206

@@ -227,68 +227,43 @@ function compute_branch_point_coeffs(T::Type{<:Number}, n::Integer)
227227
return m
228228
end
229229

230-
const BRANCH_POINT_COEFFS_FLOAT64 = compute_branch_point_coeffs(Float64, 500)
230+
const LAMBERTW_COEFFS_FLOAT64 = lambertw_coeffs(Float64, 500)
231231

232-
# Base.Math.@horner requires literal coefficients
233-
# It cannot be used here because we have an array of computed coefficients
234-
function horner(x, coeffs::AbstractArray, n)
235-
n += 1
236-
ex = coeffs[n]
237-
for i = (n - 1):-1:2
238-
ex = :(muladd(t, $ex, $(coeffs[i])))
239-
end
240-
ex = :( t * $ex)
241-
return Expr(:block, :(t = $x), ex)
242-
end
243-
244-
# write functions that evaluate the branch point series
245-
# with `num_terms` number of terms.
246-
for (func_name, num_terms) in (
247-
(:wser3, 3), (:wser5, 5), (:wser7, 7), (:wser12, 12),
248-
(:wser19, 19), (:wser26, 26), (:wser32, 32),
249-
(:wser50, 50), (:wser100, 100), (:wser290, 290))
250-
iex = horner(:x, BRANCH_POINT_COEFFS_FLOAT64, num_terms)
251-
@eval function ($func_name)(x) $iex end
252-
end
232+
(lambertwbp_evalpoly(x::T, ::Val{N})::T) where {T<:Number, N} =
233+
# assume that Julia compiler is smart to decide for which N to unroll at compile time
234+
# note that we skip μ₀=-1
235+
evalpoly(x, ntuple(i -> LAMBERTW_COEFFS_FLOAT64[i+1], N-1))*x
253236

254-
# Converges to Float64 precision
237+
# how many coefficients of the series to use
238+
# to converge to Float64 precision for given x
255239
# We could get finer tuning by separating k=0, -1 branches.
256-
# Why is wser5 omitted ?
257-
# p is the argument to the series which is computed
258-
# from x before calling `branch_point_series`.
259-
function branch_point_series(p::Real, x::Real)
260-
x < 4e-11 && return wser3(p)
261-
x < 1e-5 && return wser7(p)
262-
x < 1e-3 && return wser12(p)
263-
x < 1e-2 && return wser19(p)
264-
x < 3e-2 && return wser26(p)
265-
x < 5e-2 && return wser32(p)
266-
x < 1e-1 && return wser50(p)
267-
x < 1.9e-1 && return wser100(p)
268-
x > 1 / MathConstants.e && throw(DomainError(x)) # radius of convergence
269-
return wser290(p) # good for x approx .32
240+
function lambertwbp_series_length(x::Real)
241+
x < 4e-11 && return 3
242+
# Why N = 5 is omitted?
243+
x < 1e-5 && return 7
244+
x < 1e-3 && return 12
245+
x < 1e-2 && return 19
246+
x < 3e-2 && return 26
247+
x < 5e-2 && return 32
248+
x < 1e-1 && return 50
249+
x < 1.9e-1 && return 100
250+
x > inv(MathConstants.e) && throw(DomainError(x)) # radius of convergence
251+
return 290 # good for x approx .32
270252
end
271253

272254
# These may need tuning.
273-
function branch_point_series(p::Complex{T}, z::Complex{T}) where T<:Real
274-
x = abs(z)
275-
x < 4e-11 && return wser3(p)
276-
x < 1e-5 && return wser7(p)
277-
x < 1e-3 && return wser12(p)
278-
x < 1e-2 && return wser19(p)
279-
x < 3e-2 && return wser26(p)
280-
x < 5e-2 && return wser32(p)
281-
x < 1e-1 && return wser50(p)
282-
x < 1.9e-1 && return wser100(p)
283-
x > 1 / MathConstants.e && throw(DomainError(x)) # radius of convergence
284-
return wser290(p)
285-
end
255+
lambertwbp_series_length(z::Complex) = lambertwbp_series_length(abs(z))
256+
257+
# p is the argument to the series which is computed from x,
258+
# see `_lambertwbp()`.
259+
lambertwbp_series(p::Number, x::Number) =
260+
lambertwbp_evalpoly(p, Val{lambertwbp_series_length(x)}())
286261

287262
_lambertwbp(x::Number, ::Val{0}) =
288-
branch_point_series(sqrt(2 * MathConstants.e * x), x)
263+
lambertwbp_series(sqrt(2 * MathConstants.e * x), x)
289264

290265
_lambertwbp(x::Number, ::Val{-1}) =
291-
branch_point_series(-sqrt(2 * MathConstants.e * x), x)
266+
lambertwbp_series(-sqrt(2 * MathConstants.e * x), x)
292267

293268
_lambertwbp(_::Number, k::Val) =
294269
throw(ArgumentError("lambertw() expansion about branch point for k=$k not implemented (only implemented for 0 and -1)."))

0 commit comments

Comments
 (0)