1
1
import Base: convert
2
- #export lambertw, lambertwbp
3
- using Compat
4
-
5
- const euler =
6
- if isdefined(Base, :MathConstants)
7
- Base.MathConstants.e
8
- else
9
- e
10
- end
11
2
12
- const omega_const_bf_ = Ref{BigFloat}()
13
-
14
- function __init__()
15
- omega_const_bf_[] =
16
- parse(BigFloat,"0.5671432904097838729999686622103555497538157871865125081351310792230457930866845666932194")
17
- end
3
+ using Compat
4
+ import Compat.MathConstants # For clarity, we use MathConstants.e for Euler's number
18
5
19
6
#### Lambert W function ####
20
7
21
- const LAMBERTW_USE_NAN = false
22
-
23
- macro baddomain(v)
24
- if LAMBERTW_USE_NAN
25
- return :(return(NaN))
26
- else
27
- return esc(:(throw(DomainError($v))))
28
- end
29
- end
30
-
31
- # Use Halley's root-finding method to find x = lambertw(z) with
32
- # initial point x.
33
- function _lambertw(z::T, x::T) where T <: Number
8
+ # Use Halley's root-finding method to find
9
+ # x = lambertw(z) with initial point x.
10
+ function _lambertw(z::T, x::T, maxits) where T <: Number
34
11
two_t = convert(T,2)
35
12
lastx = x
36
13
lastdiff = zero(T)
37
- for i in 1:100
14
+ converged::Bool = false
15
+ for i in 1:maxits
38
16
ex = exp(x)
39
17
xexz = x * ex - z
40
18
x1 = x + 1
41
- x = x - xexz / (ex * x1 - (x + two_t) * xexz / (two_t * x1 ) )
19
+ x -= xexz / (ex * x1 - (x + two_t) * xexz / (two_t * x1 ) )
42
20
xdiff = abs(lastx - x)
43
- xdiff <= 2*eps(abs(lastx)) && break
44
- lastdiff == diff && break
21
+ if xdiff <= 3*eps(abs(lastx)) || lastdiff == xdiff # second condition catches two-value cycle
22
+ converged = true
23
+ break
24
+ end
45
25
lastx = x
46
26
lastdiff = xdiff
47
27
end
48
- x
28
+ converged || warn("lambertw with z=", z, " did not converge in ", maxits, " iterations.")
29
+ return x
49
30
end
50
31
51
32
### Real z ###
52
33
53
34
# Real x, k = 0
54
-
35
+ # This appears to be inferrable with T=Float64 and T=BigFloat, including if x=Inf.
55
36
# The fancy initial condition selection does not seem to help speed, but we leave it for now.
56
- function lambertwk0(x::T)::T where T<:AbstractFloat
57
- x == Inf && return Inf
37
+ function lambertwk0(x::T, maxits)::T where T<:AbstractFloat
38
+ isnan(x) && return(NaN)
39
+ x == Inf && return Inf # appears to return convert(BigFloat,Inf) for x == BigFloat(Inf)
58
40
one_t = one(T)
59
- oneoe = -one_t/convert(T,euler)
41
+ oneoe = -one_t/convert(T,MathConstants.e) # The branch point
60
42
x == oneoe && return -one_t
43
+ oneoe <= x || throw(DomainError(x))
61
44
itwo_t = 1/convert(T,2)
62
- oneoe <= x || @baddomain(x)
63
45
if x > one_t
64
46
lx = log(x)
65
47
llx = log(lx)
66
48
x1 = lx - llx - log(one_t - llx/lx) * itwo_t
67
49
else
68
50
x1 = (567//1000) * x
69
51
end
70
- _lambertw(x,x1 )
52
+ return _lambertw(x, x1, maxits )
71
53
end
72
54
73
55
# Real x, k = -1
74
- function _lambertwkm1 (x::T) where T<:Real
75
- oneoe = -one(T)/convert(T,euler )
76
- x == oneoe && return -one(T)
77
- oneoe <= x || @baddomain(x)
78
- x == zero(T) && return -convert(T,Inf)
79
- x < zero(T) || @baddomain(x )
80
- _lambertw(x,log(-x))
56
+ function lambertwkm1 (x::T, maxits ) where T<:Real
57
+ oneoe = -one(T)/convert(T,MathConstants.e )
58
+ x == oneoe && return -one(T) # W approaches -1 as x -> -1/e from above
59
+ oneoe <= x || throw(DomainError(x)) # branch domain exludes x < -1/e
60
+ x == zero(T) && return -convert(T,Inf) # W decreases w/o bound as x -> 0 from below
61
+ x < zero(T) || throw(DomainError(x) )
62
+ return _lambertw(x, log(-x), maxits )
81
63
end
82
64
83
-
84
65
"""
85
- lambertw(z::Complex{T}, k::V=0) where {T<:Real, V<:Integer}
86
- lambertw(z::T, k::V=0) where {T<:Real, V<:Integer}
66
+ lambertw(z::Complex{T}, k::V=0, maxits=1000 ) where {T<:Real, V<:Integer}
67
+ lambertw(z::T, k::V=0, maxits=1000 ) where {T<:Real, V<:Integer}
87
68
88
69
Compute the `k`th branch of the Lambert W function of `z`. If `z` is real, `k` must be
89
70
either `0` or `-1`. For `Real` `z`, the domain of the branch `k = -1` is `[-1/e,0]` and the
90
71
domain of the branch `k = 0` is `[-1/e,Inf]`. For `Complex` `z`, and all `k`, the domain is
91
- the complex plane.
72
+ the complex plane. When using root finding to compute `W`, a value for `W` is returned
73
+ with a warning if it has not converged after `maxits` iterations.
92
74
93
75
```jldoctest
94
76
julia> lambertw(-1/e,-1)
@@ -107,33 +89,31 @@ julia> lambertw(Complex(-10.0,3.0), 4)
107
89
-0.9274337508660128 + 26.37693445371142im
108
90
```
109
91
110
- !!! note
111
- The constant `LAMBERTW_USE_NAN` at the top of the source file controls whether arguments
112
- outside the domain throw `DomainError` or return `NaN`. The default is `DomainError`.
113
92
"""
114
- function lambertw(x::Real, k::Integer)
115
- k == 0 && return lambertwk0(x)
116
- k == -1 && return _lambertwkm1(x)
117
- @baddomain(k) # more informative message like below ?
118
- # error("lambertw: real x must have k == 0 or k == -1")
93
+ lambertw(z, k::Integer=0, maxits::Integer=1000) = lambertw_(z, k, maxits)
94
+
95
+ function lambertw_(x::Real, k, maxits)
96
+ k == 0 && return lambertwk0(x, maxits)
97
+ k == -1 && return lambertwkm1(x, maxits)
98
+ throw(DomainError(k, "lambertw: real x must have branch k == 0 or k == -1"))
119
99
end
120
100
121
- function lambertw (x::Union{Integer,Rational}, k::Integer )
101
+ function lambertw_ (x::Union{Integer,Rational}, k, maxits )
122
102
if k == 0
123
103
x == 0 && return float(zero(x))
124
- x == 1 && return convert(typeof(float(x)),omega) # must be more efficient way
104
+ x == 1 && return convert(typeof(float(x)), omega) # must be a more efficient way
125
105
end
126
- lambertw (float(x),k )
106
+ return lambertw_ (float(x), k, maxits )
127
107
end
128
108
129
109
### Complex z ###
130
110
131
111
# choose initial value inside correct branch for root finding
132
- function lambertw (z::Complex{T}, k::Integer ) where T<:Real
112
+ function lambertw_ (z::Complex{T}, k, maxits ) where T<:Real
133
113
one_t = one(T)
134
114
local w::Complex{T}
135
115
pointseven = 7//10
136
- if abs(z) <= one_t/convert(T,euler )
116
+ if abs(z) <= one_t/convert(T,MathConstants.e )
137
117
if z == 0
138
118
k == 0 && return z
139
119
return complex(-convert(T,Inf),zero(T))
@@ -157,35 +137,27 @@ function lambertw(z::Complex{T}, k::Integer) where T<:Real
157
137
w = log(z)
158
138
k != 0 ? w += complex(0, 2*k*pi) : nothing
159
139
end
160
- _lambertw(z,w )
140
+ return _lambertw(z, w, maxits )
161
141
end
162
142
163
- lambertw(z::Complex{T}, k::Integer) where T<:Integer = lambertw(float(z),k)
143
+ lambertw_(z::Complex{T}, k, maxits) where T<:Integer = lambertw_(float(z), k, maxits)
144
+ lambertw_(n::Irrational, k, maxits) = lambertw_(float(n), k, maxits)
164
145
165
146
# lambertw(e + 0im,k) is ok for all k
166
- #function lambertw(::Irrational{:e}, k::T) where T<:Integer
167
- function lambertw (::typeof(euler ), k::T) where T<:Integer
147
+ # Maybe this should return a float. But, this should cause no type instability in any case
148
+ function lambertw_ (::typeof(MathConstants.e ), k, maxits)
168
149
k == 0 && return 1
169
- @baddomain(k )
150
+ throw(DomainError(k) )
170
151
end
171
152
172
- # Maybe this should return a float
173
- lambertw(::typeof(euler)) = 1
174
- #lambertw(::Irrational{:e}) = 1
175
-
176
- #lambertw{T<:Number}(x::T) = lambertw(x,0)
177
- lambertw(x::Number) = lambertw(x,0)
178
-
179
- lambertw(n::Irrational, args::Integer...) = lambertw(float(n),args...)
180
-
181
153
### omega constant ###
182
154
183
155
const omega_const_ = 0.567143290409783872999968662210355
184
156
# The BigFloat `omega_const_bf_` is set via a literal in the function __init__ to prevent a segfault
185
157
186
158
# maybe compute higher precision. converges very quickly
187
159
function omega_const(::Type{BigFloat})
188
- @compat precision(BigFloat) <= 256 && return omega_const_bf_[]
160
+ precision(BigFloat) <= 256 && return omega_const_bf_[]
189
161
myeps = eps(BigFloat)
190
162
oc = omega_const_bf_[]
191
163
for i in 1:100
200
172
omega
201
173
ω
202
174
203
- A constant defined by `ω exp(ω) = 1`.
175
+ The constant defined by `ω exp(ω) = 1`.
204
176
205
177
```jldoctest
206
178
julia> ω
@@ -219,7 +191,7 @@ julia> big(omega)
219
191
const ω = Irrational{:ω}()
220
192
@doc (@doc ω) omega = ω
221
193
222
- # The following three lines may be removed when support for v0.6 is dropped
194
+ # The following two lines may be removed when support for v0.6 is dropped
223
195
Base.convert(::Type{AbstractFloat}, o::Irrational{:ω}) = Float64(o)
224
196
Base.convert(::Type{Float16}, o::Irrational{:ω}) = Float16(o)
225
197
Base.convert(::Type{T}, o::Irrational{:ω}) where T <:Number = T(o)
@@ -236,7 +208,7 @@ Base.BigFloat(o::Irrational{:ω}) = omega_const(BigFloat)
236
208
# (4.23) and (4.24) for all μ are also given. This code implements the
237
209
# recursion relations.
238
210
239
- # (4.23) and (4.24) give zero based coefficients
211
+ # (4.23) and (4.24) give zero based coefficients.
240
212
cset(a,i,v) = a[i+1] = v
241
213
cget(a,i) = a[i+1]
242
214
@@ -247,7 +219,7 @@ function compa(k,m,a)
247
219
sum0 += cget(m,j) * cget(m,k+1-j)
248
220
end
249
221
cset(a,k,sum0)
250
- sum0
222
+ return sum0
251
223
end
252
224
253
225
# (4.23)
@@ -256,7 +228,7 @@ function compm(k,m,a)
256
228
mk = (kt-1)/(kt+1) *(cget(m,k-2)/2 + cget(a,k-2)/4) -
257
229
cget(a,k)/2 - cget(m,k-1)/(kt+1)
258
230
cset(m,k,mk)
259
- mk
231
+ return mk
260
232
end
261
233
262
234
# We plug the known value μ₂ == -1//3 for (4.22) into (4.23) and
@@ -283,19 +255,21 @@ end
283
255
284
256
const LAMWMU_FLOAT64 = lamwcoeff(Float64,500)
285
257
286
- function horner(x, p::AbstractArray,n)
258
+ # Base.Math.@horner requires literal coefficients
259
+ # But, we have an array `p` of computed coefficients
260
+ function horner(x, p::AbstractArray, n)
287
261
n += 1
288
262
ex = p[n]
289
263
for i = n-1:-1:2
290
- ex = :($ (p[i]) + t * $ex )
264
+ ex = :(muladd(t, $ex, $ (p[i])) )
291
265
end
292
266
ex = :( t * $ex)
293
- Expr(:block, :(t = $x), ex)
267
+ return Expr(:block, :(t = $x), ex)
294
268
end
295
269
296
270
function mkwser(name, n)
297
271
iex = horner(:x,LAMWMU_FLOAT64,n)
298
- :(function ($name)(x) $iex end)
272
+ return :(function ($name)(x) $iex end)
299
273
end
300
274
301
275
eval(mkwser(:wser3, 3))
@@ -320,7 +294,7 @@ function wser(p,x)
320
294
x < 5e-2 && return wser32(p)
321
295
x < 1e-1 && return wser50(p)
322
296
x < 1.9e-1 && return wser100(p)
323
- x > 1/euler && @baddomain(x ) # radius of convergence
297
+ x > 1/MathConstants.e && throw(DomainError(x) ) # radius of convergence
324
298
return wser290(p) # good for x approx .32
325
299
end
326
300
@@ -335,28 +309,28 @@ function wser(p::Complex{T},z) where T<:Real
335
309
x < 5e-2 && return wser32(p)
336
310
x < 1e-1 && return wser50(p)
337
311
x < 1.9e-1 && return wser100(p)
338
- x > 1/euler && @baddomain(x ) # radius of convergence
312
+ x > 1/MathConstants.e && throw(DomainError(x) ) # radius of convergence
339
313
return wser290(p)
340
314
end
341
315
342
316
@inline function _lambertw0(x) # 1 + W(-1/e + x) , k = 0
343
- ps = 2*euler *x;
317
+ ps = 2*MathConstants.e *x;
344
318
p = sqrt(ps)
345
- wser(p,x)
319
+ return wser(p,x)
346
320
end
347
321
348
322
@inline function _lambertwm1(x) # 1 + W(-1/e + x) , k = -1
349
- ps = 2*euler *x;
323
+ ps = 2*MathConstants.e *x;
350
324
p = -sqrt(ps)
351
- wser(p,x)
325
+ return wser(p,x)
352
326
end
353
327
354
328
"""
355
329
lambertwbp(z,k=0)
356
330
357
- Accurate value of `1 + W(-1/e + z)`, for `abs(z)` in `[0,1/e]` for `k` either `0` or `-1`.
358
- Accurate to Float64 precision for abs(z) < 0.32.
359
- If `k=-1` and `imag(z) < 0`, the value on the branch `k=1` is returned. `lambertwbp` is vectorized.
331
+ Compute accurate value of `1 + W(-1/e + z)`, for `abs(z)` in `[0,1/e]` for `k` either `0` or `-1`.
332
+ The result is accurate to Float64 precision for abs(z) < 0.32.
333
+ If `k=-1` and `imag(z) < 0`, the value on the branch `k=1` is returned.
360
334
361
335
```jldoctest
362
336
julia> lambertw(-1/e + 1e-18, -1)
@@ -378,9 +352,7 @@ julia> convert(Float64,(lambertw(-BigFloat(1)/e + BigFloat(10)^(-18),-1) + 1))
378
352
function lambertwbp(x::Number,k::Integer)
379
353
k == 0 && return _lambertw0(x)
380
354
k == -1 && return _lambertwm1(x)
381
- error( "expansion about branch point only implemented for k = 0 and -1" )
355
+ throw(ArgumentError( "expansion about branch point only implemented for k = 0 and -1.") )
382
356
end
383
357
384
358
lambertwbp(x::Number) = _lambertw0(x)
385
-
386
- nothing
0 commit comments