Skip to content

Commit f33897d

Browse files
committed
lambertw: use dispatch to compute diff. branches
1 parent 2aefee4 commit f33897d

File tree

1 file changed

+14
-24
lines changed

1 file changed

+14
-24
lines changed

src/lambertw.jl

Lines changed: 14 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -44,16 +44,12 @@ end
4444

4545
### Real x
4646

47-
function _lambertw(x::Real, k::Integer, maxits::Integer)
48-
k == 0 && return lambertw_branch_zero(x, maxits)
49-
k == -1 && return lambertw_branch_one(x, maxits)
50-
throw(DomainError(k, "lambertw: real x must have branch k == 0 or k == -1"))
51-
end
47+
_lambertw(x::Real, k::Integer, maxits::Integer) = _lambertw(x, Val(Int(k)), maxits)
5248

5349
# Real x, k = 0
5450
# There is a magic number here. It could be noted, or possibly removed.
5551
# In particular, the fancy initial condition selection does not seem to help speed.
56-
function lambertw_branch_zero(x::T, maxits::Integer) where T<:Real
52+
function _lambertw(x::T, ::Val{0}, maxits::Integer) where T<:Real
5753
isfinite(x) || return x
5854
one_t = one(T)
5955
oneoe = -inv(convert(T, MathConstants.e)) # The branch point
@@ -71,7 +67,7 @@ function lambertw_branch_zero(x::T, maxits::Integer) where T<:Real
7167
end
7268

7369
# Real x, k = -1
74-
function lambertw_branch_one(x::T, maxits::Integer) where T<:Real
70+
function _lambertw(x::T, ::Val{-1}, maxits::Integer) where T<:Real
7571
oneoe = -inv(convert(T, MathConstants.e))
7672
x == oneoe && return -one(T) # W approaches -1 as x -> -1/e from above
7773
oneoe < x || throw(DomainError(x)) # branch domain exludes x < -1/e
@@ -80,6 +76,9 @@ function lambertw_branch_one(x::T, maxits::Integer) where T<:Real
8076
return lambertw_root_finding(x, log(-x), maxits)
8177
end
8278

79+
_lambertw(x::Real, k::Val, maxits::Integer) =
80+
throw(DomainError(x, "lambertw: for branch k=$k not defined, real x must have branch k == 0 or k == -1"))
81+
8382
### Complex z
8483

8584
_lambertw(z::Complex{<:Integer}, k::Integer, maxits::Integer) = _lambertw(float(z), k, maxits)
@@ -285,17 +284,14 @@ function branch_point_series(p::Complex{T}, z::Complex{T}) where T<:Real
285284
return wser290(p)
286285
end
287286

288-
function _lambertw0(x::Number) # 1 + W(-1/e + x) , k = 0
289-
ps = 2 * MathConstants.e * x
290-
series_arg = sqrt(ps)
291-
branch_point_series(series_arg, x)
292-
end
287+
_lambertwbp(x::Number, ::Val{0}) =
288+
branch_point_series(sqrt(2 * MathConstants.e * x), x)
293289

294-
function _lambertwm1(x::Number) # 1 + W(-1/e + x) , k = -1
295-
ps = 2 * MathConstants.e * x
296-
series_arg = -sqrt(ps)
297-
branch_point_series(series_arg, x)
298-
end
290+
_lambertwbp(x::Number, ::Val{-1}) =
291+
branch_point_series(-sqrt(2 * MathConstants.e * x), x)
292+
293+
_lambertwbp(_::Number, k::Val) =
294+
throw(ArgumentError("lambertw() expansion about branch point for k=$k not implemented (only implemented for 0 and -1)."))
299295

300296
"""
301297
lambertwbp(z, k=0)
@@ -323,10 +319,4 @@ julia> convert(Float64, (lambertw(-BigFloat(1)/e + BigFloat(10)^(-18), -1) + 1))
323319
The loss of precision in `lambertw` is analogous to the loss of precision
324320
in computing the `sqrt(1-x)` for `x` close to `1`.
325321
"""
326-
function lambertwbp(x::Number, k::Integer)
327-
k == 0 && return _lambertw0(x)
328-
k == -1 && return _lambertwm1(x)
329-
throw(ArgumentError("expansion about branch point only implemented for k = 0 and -1."))
330-
end
331-
332-
lambertwbp(x::Number) = _lambertw0(x)
322+
lambertwbp(x::Number, k::Integer=0) = _lambertwbp(x, Val(Int(k)))

0 commit comments

Comments
 (0)