Skip to content

Commit 90ee139

Browse files
committed
lambertw: annotate function args with types
1 parent cefddcf commit 90ee139

File tree

1 file changed

+16
-18
lines changed

1 file changed

+16
-18
lines changed

src/lambertw.jl

Lines changed: 16 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
#### Lambert W function ####
22

33
"""
4-
lambertw(z::Complex{T}, k::V=0, maxits=1000) where {T<:Real, V<:Integer}
5-
lambertw(z::T, k::V=0, maxits=1000) where {T<:Real, V<:Integer}
4+
lambertw(z::Number, k::Integer=0, maxits=1000)
65
76
Compute the `k`th branch of the Lambert W function of `z`. If `z` is real, `k` must be
87
either `0` or `-1`. For `Real` `z`, the domain of the branch `k = -1` is `[-1/e, 0]` and the
@@ -26,16 +25,16 @@ julia> lambertw(Complex(-10.0, 3.0), 4)
2625
-0.9274337508660128 + 26.37693445371142im
2726
```
2827
"""
29-
lambertw(z, k::Integer=0, maxits::Integer=1000) = _lambertw(float(z), k, maxits)
28+
lambertw(z::Number, k::Integer=0, maxits::Integer=1000) = _lambertw(float(z), k, maxits)
3029

3130
# lambertw(e + 0im, k) is ok for all k
3231
# Maybe this should return a float. But, this should cause no type instability in any case
33-
function _lambertw(::typeof(MathConstants.e), k, maxits)
32+
function _lambertw(::typeof(MathConstants.e), k::Integer, maxits::Integer)
3433
k == 0 && return 1
3534
throw(DomainError(k))
3635
end
37-
_lambertw(x::Irrational, k, maxits) = _lambertw(float(x), k, maxits)
38-
function _lambertw(x::Union{Integer, Rational}, k, maxits)
36+
_lambertw(x::Irrational, k::Integer, maxits::Integer) = _lambertw(float(x), k, maxits)
37+
function _lambertw(x::Union{Integer, Rational}, k::Integer, maxits::Integer)
3938
if k == 0
4039
x == 0 && return float(zero(x))
4140
x == 1 && return convert(typeof(float(x)), omega) # must be a more efficient way
@@ -45,7 +44,7 @@ end
4544

4645
### Real x
4746

48-
function _lambertw(x::Real, k, maxits)
47+
function _lambertw(x::Real, k::Integer, maxits::Integer)
4948
k == 0 && return lambertw_branch_zero(x, maxits)
5049
k == -1 && return lambertw_branch_one(x, maxits)
5150
throw(DomainError(k, "lambertw: real x must have branch k == 0 or k == -1"))
@@ -54,7 +53,7 @@ end
5453
# Real x, k = 0
5554
# There is a magic number here. It could be noted, or possibly removed.
5655
# In particular, the fancy initial condition selection does not seem to help speed.
57-
function lambertw_branch_zero(x::T, maxits) where T<:Real
56+
function lambertw_branch_zero(x::T, maxits::Integer) where T<:Real
5857
isfinite(x) || return x
5958
one_t = one(T)
6059
oneoe = -inv(convert(T, MathConstants.e)) # The branch point
@@ -72,7 +71,7 @@ function lambertw_branch_zero(x::T, maxits) where T<:Real
7271
end
7372

7473
# Real x, k = -1
75-
function lambertw_branch_one(x::T, maxits) where T<:Real
74+
function lambertw_branch_one(x::T, maxits::Integer) where T<:Real
7675
oneoe = -inv(convert(T, MathConstants.e))
7776
x == oneoe && return -one(T) # W approaches -1 as x -> -1/e from above
7877
oneoe < x || throw(DomainError(x)) # branch domain exludes x < -1/e
@@ -83,10 +82,9 @@ end
8382

8483
### Complex z
8584

86-
_lambertw(z::Complex{<:Integer}, k, maxits) = _lambertw(float(z), k, maxits)
85+
_lambertw(z::Complex{<:Integer}, k::Integer, maxits::Integer) = _lambertw(float(z), k, maxits)
8786
# choose initial value inside correct branch for root finding
88-
function _lambertw(z::Complex{T}, k, maxits) where T<:Real
89-
one_t = one(T)
87+
function _lambertw(z::Complex{T}, k::Integer, maxits::Integer) where T<:Real
9088
local w::Complex{T}
9189
pointseven = 7//10
9290
if abs(z) <= inv(convert(T, MathConstants.e))
@@ -120,7 +118,7 @@ end
120118

121119
# Use Halley's root-finding method to find
122120
# x = lambertw(z) with initial point x0.
123-
function lambertw_root_finding(z::T, x0::T, maxits) where T <: Number
121+
function lambertw_root_finding(z::T, x0::T, maxits::Integer) where T <: Number
124122
two_t = convert(T, 2)
125123
x = x0
126124
lastx = x
@@ -203,7 +201,7 @@ Base.BigFloat(::Irrational{:ω}) = omega_const(BigFloat)
203201
# solve for α₂. We get α₂ = 0.
204202
# Compute array of coefficients μ in (4.22).
205203
# m[1] is μ₀
206-
function compute_branch_point_coeffs(T::DataType, n::Int)
204+
function compute_branch_point_coeffs(T::Type{<:Number}, n::Integer)
207205
a = Vector{T}(undef, n)
208206
m = Vector{T}(undef, n)
209207

@@ -259,7 +257,7 @@ end
259257
# Why is wser5 omitted ?
260258
# p is the argument to the series which is computed
261259
# from x before calling `branch_point_series`.
262-
function branch_point_series(p, x)
260+
function branch_point_series(p::Real, x::Real)
263261
x < 4e-11 && return wser3(p)
264262
x < 1e-5 && return wser7(p)
265263
x < 1e-3 && return wser12(p)
@@ -273,7 +271,7 @@ function branch_point_series(p, x)
273271
end
274272

275273
# These may need tuning.
276-
function branch_point_series(p::Complex{T}, z) where T<:Real
274+
function branch_point_series(p::Complex{T}, z::Complex{T}) where T<:Real
277275
x = abs(z)
278276
x < 4e-11 && return wser3(p)
279277
x < 1e-5 && return wser7(p)
@@ -287,13 +285,13 @@ function branch_point_series(p::Complex{T}, z) where T<:Real
287285
return wser290(p)
288286
end
289287

290-
function _lambertw0(x) # 1 + W(-1/e + x) , k = 0
288+
function _lambertw0(x::Number) # 1 + W(-1/e + x) , k = 0
291289
ps = 2 * MathConstants.e * x
292290
series_arg = sqrt(ps)
293291
branch_point_series(series_arg, x)
294292
end
295293

296-
function _lambertwm1(x) # 1 + W(-1/e + x) , k = -1
294+
function _lambertwm1(x::Number) # 1 + W(-1/e + x) , k = -1
297295
ps = 2 * MathConstants.e * x
298296
series_arg = -sqrt(ps)
299297
branch_point_series(series_arg, x)

0 commit comments

Comments
 (0)