Skip to content

Commit e530b4d

Browse files
jlapeyrealyst
authored andcommitted
made changes request in PR review
1 parent 01917cf commit e530b4d

File tree

4 files changed

+146
-156
lines changed

4 files changed

+146
-156
lines changed

LICENSE

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,3 +22,35 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
2222
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
2323
SOFTWARE.
2424

25+
Portions of this code are derived from SciPy and are licensed under
26+
the Scipy License:
27+
28+
> Copyright (c) 2001, 2002 Enthought, Inc.
29+
> All rights reserved.
30+
31+
> Copyright (c) 2003-2012 SciPy Developers.
32+
> All rights reserved.
33+
34+
> Redistribution and use in source and binary forms, with or without
35+
> modification, are permitted provided that the following conditions are met:
36+
37+
> a. Redistributions of source code must retain the above copyright notice,
38+
> this list of conditions and the following disclaimer.
39+
> b. Redistributions in binary form must reproduce the above copyright
40+
> notice, this list of conditions and the following disclaimer in the
41+
> documentation and/or other materials provided with the distribution.
42+
> c. Neither the name of Enthought nor the names of the SciPy Developers
43+
> may be used to endorse or promote products derived from this software
44+
> without specific prior written permission.
45+
>
46+
> THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
47+
> AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
48+
> IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
49+
> ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS
50+
> BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY,
51+
> OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
52+
> SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
53+
> INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
54+
> CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
55+
> ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
56+
> THE POSSIBILITY OF SUCH DAMAGE.

src/SpecialFunctions.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,13 @@ export
8282
lambertw,
8383
lambertwbp
8484

85+
const omega_const_bf_ = Ref{BigFloat}()
86+
function __init__()
87+
# allocate storage for this BigFloat constant each time this module is loaded
88+
omega_const_bf_[] =
89+
parse(BigFloat,"0.5671432904097838729999686622103555497538157871865125081351310792230457930866845666932194")
90+
end
91+
8592
include("bessel.jl")
8693
include("erf.jl")
8794
include("ellip.jl")

src/lambertw.jl

Lines changed: 71 additions & 99 deletions
Original file line numberDiff line numberDiff line change
@@ -1,94 +1,76 @@
11
import Base: convert
2-
#export lambertw, lambertwbp
3-
using Compat
4-
5-
const euler =
6-
if isdefined(Base, :MathConstants)
7-
Base.MathConstants.e
8-
else
9-
e
10-
end
112

12-
const omega_const_bf_ = Ref{BigFloat}()
13-
14-
function __init__()
15-
omega_const_bf_[] =
16-
parse(BigFloat,"0.5671432904097838729999686622103555497538157871865125081351310792230457930866845666932194")
17-
end
3+
using Compat
4+
import Compat.MathConstants # For clarity, we use MathConstants.e for Euler's number
185

196
#### Lambert W function ####
207

21-
const LAMBERTW_USE_NAN = false
22-
23-
macro baddomain(v)
24-
if LAMBERTW_USE_NAN
25-
return :(return(NaN))
26-
else
27-
return esc(:(throw(DomainError($v))))
28-
end
29-
end
30-
31-
# Use Halley's root-finding method to find x = lambertw(z) with
32-
# initial point x.
33-
function _lambertw(z::T, x::T) where T <: Number
8+
# Use Halley's root-finding method to find
9+
# x = lambertw(z) with initial point x.
10+
function _lambertw(z::T, x::T, maxits) where T <: Number
3411
two_t = convert(T,2)
3512
lastx = x
3613
lastdiff = zero(T)
37-
for i in 1:100
14+
converged::Bool = false
15+
for i in 1:maxits
3816
ex = exp(x)
3917
xexz = x * ex - z
4018
x1 = x + 1
41-
x = x - xexz / (ex * x1 - (x + two_t) * xexz / (two_t * x1 ) )
19+
x -= xexz / (ex * x1 - (x + two_t) * xexz / (two_t * x1 ) )
4220
xdiff = abs(lastx - x)
43-
xdiff <= 2*eps(abs(lastx)) && break
44-
lastdiff == diff && break
21+
if xdiff <= 3*eps(abs(lastx)) || lastdiff == xdiff # second condition catches two-value cycle
22+
converged = true
23+
break
24+
end
4525
lastx = x
4626
lastdiff = xdiff
4727
end
48-
x
28+
converged || warn("lambertw with z=", z, " did not converge in ", maxits, " iterations.")
29+
return x
4930
end
5031

5132
### Real z ###
5233

5334
# Real x, k = 0
54-
35+
# This appears to be inferrable with T=Float64 and T=BigFloat, including if x=Inf.
5536
# The fancy initial condition selection does not seem to help speed, but we leave it for now.
56-
function lambertwk0(x::T)::T where T<:AbstractFloat
57-
x == Inf && return Inf
37+
function lambertwk0(x::T, maxits)::T where T<:AbstractFloat
38+
isnan(x) && return(NaN)
39+
x == Inf && return Inf # appears to return convert(BigFloat,Inf) for x == BigFloat(Inf)
5840
one_t = one(T)
59-
oneoe = -one_t/convert(T,euler)
41+
oneoe = -one_t/convert(T,MathConstants.e) # The branch point
6042
x == oneoe && return -one_t
43+
oneoe <= x || throw(DomainError(x))
6144
itwo_t = 1/convert(T,2)
62-
oneoe <= x || @baddomain(x)
6345
if x > one_t
6446
lx = log(x)
6547
llx = log(lx)
6648
x1 = lx - llx - log(one_t - llx/lx) * itwo_t
6749
else
6850
x1 = (567//1000) * x
6951
end
70-
_lambertw(x,x1)
52+
return _lambertw(x, x1, maxits)
7153
end
7254

7355
# Real x, k = -1
74-
function _lambertwkm1(x::T) where T<:Real
75-
oneoe = -one(T)/convert(T,euler)
76-
x == oneoe && return -one(T)
77-
oneoe <= x || @baddomain(x)
78-
x == zero(T) && return -convert(T,Inf)
79-
x < zero(T) || @baddomain(x)
80-
_lambertw(x,log(-x))
56+
function lambertwkm1(x::T, maxits) where T<:Real
57+
oneoe = -one(T)/convert(T,MathConstants.e)
58+
x == oneoe && return -one(T) # W approaches -1 as x -> -1/e from above
59+
oneoe <= x || throw(DomainError(x)) # branch domain exludes x < -1/e
60+
x == zero(T) && return -convert(T,Inf) # W decreases w/o bound as x -> 0 from below
61+
x < zero(T) || throw(DomainError(x))
62+
return _lambertw(x, log(-x), maxits)
8163
end
8264

83-
8465
"""
85-
lambertw(z::Complex{T}, k::V=0) where {T<:Real, V<:Integer}
86-
lambertw(z::T, k::V=0) where {T<:Real, V<:Integer}
66+
lambertw(z::Complex{T}, k::V=0, maxits=1000) where {T<:Real, V<:Integer}
67+
lambertw(z::T, k::V=0, maxits=1000) where {T<:Real, V<:Integer}
8768
8869
Compute the `k`th branch of the Lambert W function of `z`. If `z` is real, `k` must be
8970
either `0` or `-1`. For `Real` `z`, the domain of the branch `k = -1` is `[-1/e,0]` and the
9071
domain of the branch `k = 0` is `[-1/e,Inf]`. For `Complex` `z`, and all `k`, the domain is
91-
the complex plane.
72+
the complex plane. When using root finding to compute `W`, a value for `W` is returned
73+
with a warning if it has not converged after `maxits` iterations.
9274
9375
```jldoctest
9476
julia> lambertw(-1/e,-1)
@@ -107,33 +89,31 @@ julia> lambertw(Complex(-10.0,3.0), 4)
10789
-0.9274337508660128 + 26.37693445371142im
10890
```
10991
110-
!!! note
111-
The constant `LAMBERTW_USE_NAN` at the top of the source file controls whether arguments
112-
outside the domain throw `DomainError` or return `NaN`. The default is `DomainError`.
11392
"""
114-
function lambertw(x::Real, k::Integer)
115-
k == 0 && return lambertwk0(x)
116-
k == -1 && return _lambertwkm1(x)
117-
@baddomain(k) # more informative message like below ?
118-
# error("lambertw: real x must have k == 0 or k == -1")
93+
lambertw(z, k::Integer=0, maxits::Integer=1000) = lambertw_(z, k, maxits)
94+
95+
function lambertw_(x::Real, k, maxits)
96+
k == 0 && return lambertwk0(x, maxits)
97+
k == -1 && return lambertwkm1(x, maxits)
98+
throw(DomainError(k, "lambertw: real x must have branch k == 0 or k == -1"))
11999
end
120100

121-
function lambertw(x::Union{Integer,Rational}, k::Integer)
101+
function lambertw_(x::Union{Integer,Rational}, k, maxits)
122102
if k == 0
123103
x == 0 && return float(zero(x))
124-
x == 1 && return convert(typeof(float(x)),omega) # must be more efficient way
104+
x == 1 && return convert(typeof(float(x)), omega) # must be a more efficient way
125105
end
126-
lambertw(float(x),k)
106+
return lambertw_(float(x), k, maxits)
127107
end
128108

129109
### Complex z ###
130110

131111
# choose initial value inside correct branch for root finding
132-
function lambertw(z::Complex{T}, k::Integer) where T<:Real
112+
function lambertw_(z::Complex{T}, k, maxits) where T<:Real
133113
one_t = one(T)
134114
local w::Complex{T}
135115
pointseven = 7//10
136-
if abs(z) <= one_t/convert(T,euler)
116+
if abs(z) <= one_t/convert(T,MathConstants.e)
137117
if z == 0
138118
k == 0 && return z
139119
return complex(-convert(T,Inf),zero(T))
@@ -157,35 +137,27 @@ function lambertw(z::Complex{T}, k::Integer) where T<:Real
157137
w = log(z)
158138
k != 0 ? w += complex(0, 2*k*pi) : nothing
159139
end
160-
_lambertw(z,w)
140+
return _lambertw(z, w, maxits)
161141
end
162142

163-
lambertw(z::Complex{T}, k::Integer) where T<:Integer = lambertw(float(z),k)
143+
lambertw_(z::Complex{T}, k, maxits) where T<:Integer = lambertw_(float(z), k, maxits)
144+
lambertw_(n::Irrational, k, maxits) = lambertw_(float(n), k, maxits)
164145

165146
# lambertw(e + 0im,k) is ok for all k
166-
#function lambertw(::Irrational{:e}, k::T) where T<:Integer
167-
function lambertw(::typeof(euler), k::T) where T<:Integer
147+
# Maybe this should return a float. But, this should cause no type instability in any case
148+
function lambertw_(::typeof(MathConstants.e), k, maxits)
168149
k == 0 && return 1
169-
@baddomain(k)
150+
throw(DomainError(k))
170151
end
171152

172-
# Maybe this should return a float
173-
lambertw(::typeof(euler)) = 1
174-
#lambertw(::Irrational{:e}) = 1
175-
176-
#lambertw{T<:Number}(x::T) = lambertw(x,0)
177-
lambertw(x::Number) = lambertw(x,0)
178-
179-
lambertw(n::Irrational, args::Integer...) = lambertw(float(n),args...)
180-
181153
### omega constant ###
182154

183155
const omega_const_ = 0.567143290409783872999968662210355
184156
# The BigFloat `omega_const_bf_` is set via a literal in the function __init__ to prevent a segfault
185157

186158
# maybe compute higher precision. converges very quickly
187159
function omega_const(::Type{BigFloat})
188-
@compat precision(BigFloat) <= 256 && return omega_const_bf_[]
160+
precision(BigFloat) <= 256 && return omega_const_bf_[]
189161
myeps = eps(BigFloat)
190162
oc = omega_const_bf_[]
191163
for i in 1:100
@@ -200,7 +172,7 @@ end
200172
omega
201173
ω
202174
203-
A constant defined by `ω exp(ω) = 1`.
175+
The constant defined by `ω exp(ω) = 1`.
204176
205177
```jldoctest
206178
julia> ω
@@ -219,7 +191,7 @@ julia> big(omega)
219191
const ω = Irrational{:ω}()
220192
@doc (@doc ω) omega = ω
221193

222-
# The following three lines may be removed when support for v0.6 is dropped
194+
# The following two lines may be removed when support for v0.6 is dropped
223195
Base.convert(::Type{AbstractFloat}, o::Irrational{:ω}) = Float64(o)
224196
Base.convert(::Type{Float16}, o::Irrational{:ω}) = Float16(o)
225197
Base.convert(::Type{T}, o::Irrational{:ω}) where T <:Number = T(o)
@@ -236,7 +208,7 @@ Base.BigFloat(o::Irrational{:ω}) = omega_const(BigFloat)
236208
# (4.23) and (4.24) for all μ are also given. This code implements the
237209
# recursion relations.
238210

239-
# (4.23) and (4.24) give zero based coefficients
211+
# (4.23) and (4.24) give zero based coefficients.
240212
cset(a,i,v) = a[i+1] = v
241213
cget(a,i) = a[i+1]
242214

@@ -247,7 +219,7 @@ function compa(k,m,a)
247219
sum0 += cget(m,j) * cget(m,k+1-j)
248220
end
249221
cset(a,k,sum0)
250-
sum0
222+
return sum0
251223
end
252224

253225
# (4.23)
@@ -256,7 +228,7 @@ function compm(k,m,a)
256228
mk = (kt-1)/(kt+1) *(cget(m,k-2)/2 + cget(a,k-2)/4) -
257229
cget(a,k)/2 - cget(m,k-1)/(kt+1)
258230
cset(m,k,mk)
259-
mk
231+
return mk
260232
end
261233

262234
# We plug the known value μ₂ == -1//3 for (4.22) into (4.23) and
@@ -283,19 +255,21 @@ end
283255

284256
const LAMWMU_FLOAT64 = lamwcoeff(Float64,500)
285257

286-
function horner(x, p::AbstractArray,n)
258+
# Base.Math.@horner requires literal coefficients
259+
# But, we have an array `p` of computed coefficients
260+
function horner(x, p::AbstractArray, n)
287261
n += 1
288262
ex = p[n]
289263
for i = n-1:-1:2
290-
ex = :($(p[i]) + t * $ex)
264+
ex = :(muladd(t, $ex, $(p[i])))
291265
end
292266
ex = :( t * $ex)
293-
Expr(:block, :(t = $x), ex)
267+
return Expr(:block, :(t = $x), ex)
294268
end
295269

296270
function mkwser(name, n)
297271
iex = horner(:x,LAMWMU_FLOAT64,n)
298-
:(function ($name)(x) $iex end)
272+
return :(function ($name)(x) $iex end)
299273
end
300274

301275
eval(mkwser(:wser3, 3))
@@ -320,7 +294,7 @@ function wser(p,x)
320294
x < 5e-2 && return wser32(p)
321295
x < 1e-1 && return wser50(p)
322296
x < 1.9e-1 && return wser100(p)
323-
x > 1/euler && @baddomain(x) # radius of convergence
297+
x > 1/MathConstants.e && throw(DomainError(x)) # radius of convergence
324298
return wser290(p) # good for x approx .32
325299
end
326300

@@ -335,28 +309,28 @@ function wser(p::Complex{T},z) where T<:Real
335309
x < 5e-2 && return wser32(p)
336310
x < 1e-1 && return wser50(p)
337311
x < 1.9e-1 && return wser100(p)
338-
x > 1/euler && @baddomain(x) # radius of convergence
312+
x > 1/MathConstants.e && throw(DomainError(x)) # radius of convergence
339313
return wser290(p)
340314
end
341315

342316
@inline function _lambertw0(x) # 1 + W(-1/e + x) , k = 0
343-
ps = 2*euler*x;
317+
ps = 2*MathConstants.e*x;
344318
p = sqrt(ps)
345-
wser(p,x)
319+
return wser(p,x)
346320
end
347321

348322
@inline function _lambertwm1(x) # 1 + W(-1/e + x) , k = -1
349-
ps = 2*euler*x;
323+
ps = 2*MathConstants.e*x;
350324
p = -sqrt(ps)
351-
wser(p,x)
325+
return wser(p,x)
352326
end
353327

354328
"""
355329
lambertwbp(z,k=0)
356330
357-
Accurate value of `1 + W(-1/e + z)`, for `abs(z)` in `[0,1/e]` for `k` either `0` or `-1`.
358-
Accurate to Float64 precision for abs(z) < 0.32.
359-
If `k=-1` and `imag(z) < 0`, the value on the branch `k=1` is returned. `lambertwbp` is vectorized.
331+
Compute accurate value of `1 + W(-1/e + z)`, for `abs(z)` in `[0,1/e]` for `k` either `0` or `-1`.
332+
The result is accurate to Float64 precision for abs(z) < 0.32.
333+
If `k=-1` and `imag(z) < 0`, the value on the branch `k=1` is returned.
360334
361335
```jldoctest
362336
julia> lambertw(-1/e + 1e-18, -1)
@@ -378,9 +352,7 @@ julia> convert(Float64,(lambertw(-BigFloat(1)/e + BigFloat(10)^(-18),-1) + 1))
378352
function lambertwbp(x::Number,k::Integer)
379353
k == 0 && return _lambertw0(x)
380354
k == -1 && return _lambertwm1(x)
381-
error("expansion about branch point only implemented for k = 0 and -1")
355+
throw(ArgumentError("expansion about branch point only implemented for k = 0 and -1."))
382356
end
383357

384358
lambertwbp(x::Number) = _lambertw0(x)
385-
386-
nothing

0 commit comments

Comments
 (0)