@@ -200,7 +200,7 @@ Base.BigFloat(::Irrational{:ω}) = omega_const(BigFloat)
200
200
# solve for α₂. We get α₂ = 0.
201
201
# Compute array of coefficients μ in (4.22).
202
202
# m[1] is μ₀
203
- function compute_branch_point_coeffs (T:: Type{<:Number} , n:: Integer )
203
+ function lambertw_coeffs (T:: Type{<:Number} , n:: Integer )
204
204
a = Vector {T} (undef, n)
205
205
m = Vector {T} (undef, n)
206
206
@@ -227,68 +227,43 @@ function compute_branch_point_coeffs(T::Type{<:Number}, n::Integer)
227
227
return m
228
228
end
229
229
230
- const BRANCH_POINT_COEFFS_FLOAT64 = compute_branch_point_coeffs (Float64, 500 )
230
+ const LAMBERTW_COEFFS_FLOAT64 = lambertw_coeffs (Float64, 500 )
231
231
232
- # Base.Math.@horner requires literal coefficients
233
- # It cannot be used here because we have an array of computed coefficients
234
- function horner (x, coeffs:: AbstractArray , n)
235
- n += 1
236
- ex = coeffs[n]
237
- for i = (n - 1 ): - 1 : 2
238
- ex = :(muladd (t, $ ex, $ (coeffs[i])))
239
- end
240
- ex = :( t * $ ex)
241
- return Expr (:block , :(t = $ x), ex)
242
- end
243
-
244
- # write functions that evaluate the branch point series
245
- # with `num_terms` number of terms.
246
- for (func_name, num_terms) in (
247
- (:wser3 , 3 ), (:wser5 , 5 ), (:wser7 , 7 ), (:wser12 , 12 ),
248
- (:wser19 , 19 ), (:wser26 , 26 ), (:wser32 , 32 ),
249
- (:wser50 , 50 ), (:wser100 , 100 ), (:wser290 , 290 ))
250
- iex = horner (:x , BRANCH_POINT_COEFFS_FLOAT64, num_terms)
251
- @eval function ($ func_name)(x) $ iex end
252
- end
232
+ (lambertwbp_evalpoly (x:: T , :: Val{N} ):: T ) where {T<: Number , N} =
233
+ # assume that Julia compiler is smart to decide for which N to unroll at compile time
234
+ # note that we skip μ₀=-1
235
+ evalpoly (x, ntuple (i -> LAMBERTW_COEFFS_FLOAT64[i+ 1 ], N- 1 ))* x
253
236
254
- # Converges to Float64 precision
237
+ # how many coefficients of the series to use
238
+ # to converge to Float64 precision for given x
255
239
# We could get finer tuning by separating k=0, -1 branches.
256
- # Why is wser5 omitted ?
257
- # p is the argument to the series which is computed
258
- # from x before calling `branch_point_series`.
259
- function branch_point_series (p:: Real , x:: Real )
260
- x < 4e-11 && return wser3 (p)
261
- x < 1e-5 && return wser7 (p)
262
- x < 1e-3 && return wser12 (p)
263
- x < 1e-2 && return wser19 (p)
264
- x < 3e-2 && return wser26 (p)
265
- x < 5e-2 && return wser32 (p)
266
- x < 1e-1 && return wser50 (p)
267
- x < 1.9e-1 && return wser100 (p)
268
- x > 1 / MathConstants. e && throw (DomainError (x)) # radius of convergence
269
- return wser290 (p) # good for x approx .32
240
+ function lambertwbp_series_length (x:: Real )
241
+ x < 4e-11 && return 3
242
+ # Why N = 5 is omitted?
243
+ x < 1e-5 && return 7
244
+ x < 1e-3 && return 12
245
+ x < 1e-2 && return 19
246
+ x < 3e-2 && return 26
247
+ x < 5e-2 && return 32
248
+ x < 1e-1 && return 50
249
+ x < 1.9e-1 && return 100
250
+ x > inv (MathConstants. e) && throw (DomainError (x)) # radius of convergence
251
+ return 290 # good for x approx .32
270
252
end
271
253
272
254
# These may need tuning.
273
- function branch_point_series (p:: Complex{T} , z:: Complex{T} ) where T<: Real
274
- x = abs (z)
275
- x < 4e-11 && return wser3 (p)
276
- x < 1e-5 && return wser7 (p)
277
- x < 1e-3 && return wser12 (p)
278
- x < 1e-2 && return wser19 (p)
279
- x < 3e-2 && return wser26 (p)
280
- x < 5e-2 && return wser32 (p)
281
- x < 1e-1 && return wser50 (p)
282
- x < 1.9e-1 && return wser100 (p)
283
- x > 1 / MathConstants. e && throw (DomainError (x)) # radius of convergence
284
- return wser290 (p)
285
- end
255
+ lambertwbp_series_length (z:: Complex ) = lambertwbp_series_length (abs (z))
256
+
257
+ # p is the argument to the series which is computed from x,
258
+ # see `_lambertwbp()`.
259
+ lambertwbp_series (p:: Number , x:: Number ) =
260
+ lambertwbp_evalpoly (p, Val {lambertwbp_series_length(x)} ())
286
261
287
262
_lambertwbp (x:: Number , :: Val{0} ) =
288
- branch_point_series (sqrt (2 * MathConstants. e * x), x)
263
+ lambertwbp_series (sqrt (2 * MathConstants. e * x), x)
289
264
290
265
_lambertwbp (x:: Number , :: Val{-1} ) =
291
- branch_point_series (- sqrt (2 * MathConstants. e * x), x)
266
+ lambertwbp_series (- sqrt (2 * MathConstants. e * x), x)
292
267
293
268
_lambertwbp (_:: Number , k:: Val ) =
294
269
throw (ArgumentError (" lambertw() expansion about branch point for k=$k not implemented (only implemented for 0 and -1)." ))
0 commit comments