11import Base: convert
2- # export lambertw, lambertwbp
3- using Compat
4-
5- const euler =
6- if isdefined (Base, :MathConstants )
7- Base. MathConstants. e
8- else
9- e
10- end
112
12- const omega_const_bf_ = Ref {BigFloat} ()
13-
14- function __init__ ()
15- omega_const_bf_[] =
16- parse (BigFloat," 0.5671432904097838729999686622103555497538157871865125081351310792230457930866845666932194" )
17- end
3+ using Compat
4+ import Compat. MathConstants # For clarity, we use MathConstants.e for Euler's number
185
196# ### Lambert W function ####
207
21- const LAMBERTW_USE_NAN = false
22-
23- macro baddomain (v)
24- if LAMBERTW_USE_NAN
25- return :(return (NaN ))
26- else
27- return esc (:(throw (DomainError ($ v))))
28- end
29- end
30-
31- # Use Halley's root-finding method to find x = lambertw(z) with
32- # initial point x.
33- function _lambertw (z:: T , x:: T ) where T <: Number
8+ # Use Halley's root-finding method to find
9+ # x = lambertw(z) with initial point x.
10+ function _lambertw (z:: T , x:: T , maxits) where T <: Number
3411 two_t = convert (T,2 )
3512 lastx = x
3613 lastdiff = zero (T)
37- for i in 1 : 100
14+ converged:: Bool = false
15+ for i in 1 : maxits
3816 ex = exp (x)
3917 xexz = x * ex - z
4018 x1 = x + 1
41- x = x - xexz / (ex * x1 - (x + two_t) * xexz / (two_t * x1 ) )
19+ x -= xexz / (ex * x1 - (x + two_t) * xexz / (two_t * x1 ) )
4220 xdiff = abs (lastx - x)
43- xdiff <= 2 * eps (abs (lastx)) && break
44- lastdiff == diff && break
21+ if xdiff <= 3 * eps (abs (lastx)) || lastdiff == xdiff # second condition catches two-value cycle
22+ converged = true
23+ break
24+ end
4525 lastx = x
4626 lastdiff = xdiff
4727 end
48- x
28+ converged || warn (" lambertw with z=" , z, " did not converge in " , maxits, " iterations." )
29+ return x
4930end
5031
5132# ## Real z ###
5233
5334# Real x, k = 0
54-
35+ # This appears to be inferrable with T=Float64 and T=BigFloat, including if x=Inf.
5536# The fancy initial condition selection does not seem to help speed, but we leave it for now.
56- function lambertwk0 (x:: T ):: T where T<: AbstractFloat
57- x == Inf && return Inf
37+ function lambertwk0 (x:: T , maxits):: T where T<: AbstractFloat
38+ isnan (x) && return (NaN )
39+ x == Inf && return Inf # appears to return convert(BigFloat,Inf) for x == BigFloat(Inf)
5840 one_t = one (T)
59- oneoe = - one_t/ convert (T,euler)
41+ oneoe = - one_t/ convert (T,MathConstants . e) # The branch point
6042 x == oneoe && return - one_t
43+ oneoe <= x || throw (DomainError (x))
6144 itwo_t = 1 / convert (T,2 )
62- oneoe <= x || @baddomain (x)
6345 if x > one_t
6446 lx = log (x)
6547 llx = log (lx)
6648 x1 = lx - llx - log (one_t - llx/ lx) * itwo_t
6749 else
6850 x1 = (567 // 1000 ) * x
6951 end
70- _lambertw (x,x1 )
52+ return _lambertw (x, x1, maxits )
7153end
7254
7355# Real x, k = -1
74- function _lambertwkm1 (x:: T ) where T<: Real
75- oneoe = - one (T)/ convert (T,euler )
76- x == oneoe && return - one (T)
77- oneoe <= x || @baddomain (x)
78- x == zero (T) && return - convert (T,Inf )
79- x < zero (T) || @baddomain (x )
80- _lambertw (x,log (- x))
56+ function lambertwkm1 (x:: T , maxits ) where T<: Real
57+ oneoe = - one (T)/ convert (T,MathConstants . e )
58+ x == oneoe && return - one (T) # W approaches -1 as x -> -1/e from above
59+ oneoe <= x || throw ( DomainError (x)) # branch domain exludes x < -1/e
60+ x == zero (T) && return - convert (T,Inf ) # W decreases w/o bound as x -> 0 from below
61+ x < zero (T) || throw ( DomainError (x) )
62+ return _lambertw (x, log (- x), maxits )
8163end
8264
83-
8465"""
85- lambertw(z::Complex{T}, k::V=0) where {T<:Real, V<:Integer}
86- lambertw(z::T, k::V=0) where {T<:Real, V<:Integer}
66+ lambertw(z::Complex{T}, k::V=0, maxits=1000 ) where {T<:Real, V<:Integer}
67+ lambertw(z::T, k::V=0, maxits=1000 ) where {T<:Real, V<:Integer}
8768
8869Compute the `k`th branch of the Lambert W function of `z`. If `z` is real, `k` must be
8970either `0` or `-1`. For `Real` `z`, the domain of the branch `k = -1` is `[-1/e,0]` and the
9071domain of the branch `k = 0` is `[-1/e,Inf]`. For `Complex` `z`, and all `k`, the domain is
91- the complex plane.
72+ the complex plane. When using root finding to compute `W`, a value for `W` is returned
73+ with a warning if it has not converged after `maxits` iterations.
9274
9375```jldoctest
9476julia> lambertw(-1/e,-1)
@@ -107,33 +89,31 @@ julia> lambertw(Complex(-10.0,3.0), 4)
10789-0.9274337508660128 + 26.37693445371142im
10890```
10991
110- !!! note
111- The constant `LAMBERTW_USE_NAN` at the top of the source file controls whether arguments
112- outside the domain throw `DomainError` or return `NaN`. The default is `DomainError`.
11392"""
114- function lambertw (x:: Real , k:: Integer )
115- k == 0 && return lambertwk0 (x)
116- k == - 1 && return _lambertwkm1 (x)
117- @baddomain (k) # more informative message like below ?
118- # error("lambertw: real x must have k == 0 or k == -1")
93+ lambertw (z, k:: Integer = 0 , maxits:: Integer = 1000 ) = lambertw_ (z, k, maxits)
94+
95+ function lambertw_ (x:: Real , k, maxits)
96+ k == 0 && return lambertwk0 (x, maxits)
97+ k == - 1 && return lambertwkm1 (x, maxits)
98+ throw (DomainError (k, " lambertw: real x must have branch k == 0 or k == -1" ))
11999end
120100
121- function lambertw (x:: Union{Integer,Rational} , k:: Integer )
101+ function lambertw_ (x:: Union{Integer,Rational} , k, maxits )
122102 if k == 0
123103 x == 0 && return float (zero (x))
124- x == 1 && return convert (typeof (float (x)),omega) # must be more efficient way
104+ x == 1 && return convert (typeof (float (x)), omega) # must be a more efficient way
125105 end
126- lambertw (float (x),k )
106+ return lambertw_ (float (x), k, maxits )
127107end
128108
129109# ## Complex z ###
130110
131111# choose initial value inside correct branch for root finding
132- function lambertw (z:: Complex{T} , k:: Integer ) where T<: Real
112+ function lambertw_ (z:: Complex{T} , k, maxits ) where T<: Real
133113 one_t = one (T)
134114 local w:: Complex{T}
135115 pointseven = 7 // 10
136- if abs (z) <= one_t/ convert (T,euler )
116+ if abs (z) <= one_t/ convert (T,MathConstants . e )
137117 if z == 0
138118 k == 0 && return z
139119 return complex (- convert (T,Inf ),zero (T))
@@ -157,35 +137,27 @@ function lambertw(z::Complex{T}, k::Integer) where T<:Real
157137 w = log (z)
158138 k != 0 ? w += complex (0 , 2 * k* pi ) : nothing
159139 end
160- _lambertw (z,w )
140+ return _lambertw (z, w, maxits )
161141end
162142
163- lambertw (z:: Complex{T} , k:: Integer ) where T<: Integer = lambertw (float (z),k)
143+ lambertw_ (z:: Complex{T} , k, maxits) where T<: Integer = lambertw_ (float (z), k, maxits)
144+ lambertw_ (n:: Irrational , k, maxits) = lambertw_ (float (n), k, maxits)
164145
165146# lambertw(e + 0im,k) is ok for all k
166- # function lambertw(::Irrational{:e}, k::T) where T<:Integer
167- function lambertw (:: typeof (euler ), k:: T ) where T <: Integer
147+ # Maybe this should return a float. But, this should cause no type instability in any case
148+ function lambertw_ (:: typeof (MathConstants . e ), k, maxits)
168149 k == 0 && return 1
169- @baddomain (k )
150+ throw ( DomainError (k) )
170151end
171152
172- # Maybe this should return a float
173- lambertw (:: typeof (euler)) = 1
174- # lambertw(::Irrational{:e}) = 1
175-
176- # lambertw{T<:Number}(x::T) = lambertw(x,0)
177- lambertw (x:: Number ) = lambertw (x,0 )
178-
179- lambertw (n:: Irrational , args:: Integer... ) = lambertw (float (n),args... )
180-
181153# ## omega constant ###
182154
183155const omega_const_ = 0.567143290409783872999968662210355
184156# The BigFloat `omega_const_bf_` is set via a literal in the function __init__ to prevent a segfault
185157
186158# maybe compute higher precision. converges very quickly
187159function omega_const (:: Type{BigFloat} )
188- @compat precision (BigFloat) <= 256 && return omega_const_bf_[]
160+ precision (BigFloat) <= 256 && return omega_const_bf_[]
189161 myeps = eps (BigFloat)
190162 oc = omega_const_bf_[]
191163 for i in 1 : 100
200172 omega
201173 ω
202174
203- A constant defined by `ω exp(ω) = 1`.
175+ The constant defined by `ω exp(ω) = 1`.
204176
205177```jldoctest
206178julia> ω
@@ -219,7 +191,7 @@ julia> big(omega)
219191const ω = Irrational {:ω} ()
220192@doc (@doc ω) omega = ω
221193
222- # The following three lines may be removed when support for v0.6 is dropped
194+ # The following two lines may be removed when support for v0.6 is dropped
223195Base. convert (:: Type{AbstractFloat} , o:: Irrational{:ω} ) = Float64 (o)
224196Base. convert (:: Type{Float16} , o:: Irrational{:ω} ) = Float16 (o)
225197Base. convert (:: Type{T} , o:: Irrational{:ω} ) where T <: Number = T (o)
@@ -236,7 +208,7 @@ Base.BigFloat(o::Irrational{:ω}) = omega_const(BigFloat)
236208# (4.23) and (4.24) for all μ are also given. This code implements the
237209# recursion relations.
238210
239- # (4.23) and (4.24) give zero based coefficients
211+ # (4.23) and (4.24) give zero based coefficients.
240212cset (a,i,v) = a[i+ 1 ] = v
241213cget (a,i) = a[i+ 1 ]
242214
@@ -247,7 +219,7 @@ function compa(k,m,a)
247219 sum0 += cget (m,j) * cget (m,k+ 1 - j)
248220 end
249221 cset (a,k,sum0)
250- sum0
222+ return sum0
251223end
252224
253225# (4.23)
@@ -256,7 +228,7 @@ function compm(k,m,a)
256228 mk = (kt- 1 )/ (kt+ 1 ) * (cget (m,k- 2 )/ 2 + cget (a,k- 2 )/ 4 ) -
257229 cget (a,k)/ 2 - cget (m,k- 1 )/ (kt+ 1 )
258230 cset (m,k,mk)
259- mk
231+ return mk
260232end
261233
262234# We plug the known value μ₂ == -1//3 for (4.22) into (4.23) and
@@ -283,19 +255,21 @@ end
283255
284256const LAMWMU_FLOAT64 = lamwcoeff (Float64,500 )
285257
286- function horner (x, p:: AbstractArray ,n)
258+ # Base.Math.@horner requires literal coefficients
259+ # But, we have an array `p` of computed coefficients
260+ function horner (x, p:: AbstractArray , n)
287261 n += 1
288262 ex = p[n]
289263 for i = n- 1 : - 1 : 2
290- ex = :($ (p[i]) + t * $ ex )
264+ ex = :(muladd (t, $ ex, $ (p[i])) )
291265 end
292266 ex = :( t * $ ex)
293- Expr (:block , :(t = $ x), ex)
267+ return Expr (:block , :(t = $ x), ex)
294268end
295269
296270function mkwser (name, n)
297271 iex = horner (:x ,LAMWMU_FLOAT64,n)
298- :(function ($ name)(x) $ iex end )
272+ return :(function ($ name)(x) $ iex end )
299273end
300274
301275eval (mkwser (:wser3 , 3 ))
@@ -320,7 +294,7 @@ function wser(p,x)
320294 x < 5e-2 && return wser32 (p)
321295 x < 1e-1 && return wser50 (p)
322296 x < 1.9e-1 && return wser100 (p)
323- x > 1 / euler && @baddomain (x ) # radius of convergence
297+ x > 1 / MathConstants . e && throw ( DomainError (x) ) # radius of convergence
324298 return wser290 (p) # good for x approx .32
325299end
326300
@@ -335,28 +309,28 @@ function wser(p::Complex{T},z) where T<:Real
335309 x < 5e-2 && return wser32 (p)
336310 x < 1e-1 && return wser50 (p)
337311 x < 1.9e-1 && return wser100 (p)
338- x > 1 / euler && @baddomain (x ) # radius of convergence
312+ x > 1 / MathConstants . e && throw ( DomainError (x) ) # radius of convergence
339313 return wser290 (p)
340314end
341315
342316@inline function _lambertw0 (x) # 1 + W(-1/e + x) , k = 0
343- ps = 2 * euler * x;
317+ ps = 2 * MathConstants . e * x;
344318 p = sqrt (ps)
345- wser (p,x)
319+ return wser (p,x)
346320end
347321
348322@inline function _lambertwm1 (x) # 1 + W(-1/e + x) , k = -1
349- ps = 2 * euler * x;
323+ ps = 2 * MathConstants . e * x;
350324 p = - sqrt (ps)
351- wser (p,x)
325+ return wser (p,x)
352326end
353327
354328"""
355329 lambertwbp(z,k=0)
356330
357- Accurate value of `1 + W(-1/e + z)`, for `abs(z)` in `[0,1/e]` for `k` either `0` or `-1`.
358- Accurate to Float64 precision for abs(z) < 0.32.
359- If `k=-1` and `imag(z) < 0`, the value on the branch `k=1` is returned. `lambertwbp` is vectorized.
331+ Compute accurate value of `1 + W(-1/e + z)`, for `abs(z)` in `[0,1/e]` for `k` either `0` or `-1`.
332+ The result is accurate to Float64 precision for abs(z) < 0.32.
333+ If `k=-1` and `imag(z) < 0`, the value on the branch `k=1` is returned.
360334
361335```jldoctest
362336julia> lambertw(-1/e + 1e-18, -1)
@@ -378,9 +352,7 @@ julia> convert(Float64,(lambertw(-BigFloat(1)/e + BigFloat(10)^(-18),-1) + 1))
378352function lambertwbp (x:: Number ,k:: Integer )
379353 k == 0 && return _lambertw0 (x)
380354 k == - 1 && return _lambertwm1 (x)
381- error ( " expansion about branch point only implemented for k = 0 and -1" )
355+ throw ( ArgumentError ( " expansion about branch point only implemented for k = 0 and -1. " ) )
382356end
383357
384358lambertwbp (x:: Number ) = _lambertw0 (x)
385-
386- nothing
0 commit comments