Skip to content

Commit 2c5df13

Browse files
authored
Merge pull request #18 from JuliaMath/more-base
More things from Base
2 parents 71812c3 + a1ef500 commit 2c5df13

File tree

2 files changed

+139
-60
lines changed

2 files changed

+139
-60
lines changed

src/gamma.jl

Lines changed: 107 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
# This file contains code that was formerly a part of Julia. License is MIT: http://julialang.org/license
22

3-
using Base.Math: signflip, f16, f32, f64
43
using Base.MPFR: ROUNDING_MODE, big_ln2
54

5+
typealias ComplexOrReal{T} Union{T,Complex{T}}
6+
67
# Bernoulli numbers B_{2k}, using tabulated numerators and denominators from
78
# the online encyclopedia of integer sequences. (They actually have data
89
# up to k=249, but we stop here at k=20.) Used for generating the polygamma
@@ -16,7 +17,7 @@ using Base.MPFR: ROUNDING_MODE, big_ln2
1617
1718
Compute the digamma function of `x` (the logarithmic derivative of `gamma(x)`).
1819
"""
19-
function digamma(z::Union{Float64,Complex{Float64}})
20+
function digamma(z::ComplexOrReal{Float64})
2021
# Based on eq. (12), without looking at the accompanying source
2122
# code, of: K. S. Kölbig, "Programs for computing the logarithm of
2223
# the gamma function, and the digamma function, for complex
@@ -56,7 +57,7 @@ end
5657
5758
Compute the trigamma function of `x` (the logarithmic second derivative of `gamma(x)`).
5859
"""
59-
function trigamma(z::Union{Float64,Complex{Float64}})
60+
function trigamma(z::ComplexOrReal{Float64})
6061
# via the derivative of the Kölbig digamma formulation
6162
x = real(z)
6263
if x <= 0 # reflection formula
@@ -79,6 +80,9 @@ function trigamma(z::Union{Float64,Complex{Float64}})
7980
ψ += t*w * @evalpoly(w,0.16666666666666666,-0.03333333333333333,0.023809523809523808,-0.03333333333333333,0.07575757575757576,-0.2531135531135531,1.1666666666666667,-7.092156862745098)
8081
end
8182

83+
signflip(m::Number, z) = (-1+0im)^m * z
84+
signflip(m::Integer, z) = iseven(m) ? z : -z
85+
8286
# (-1)^m d^m/dz^m cot(z) = p_m(cot z), where p_m is a polynomial
8387
# that satisfies the recurrence p_{m+1}(x) = p_m′(x) * (1 + x^2).
8488
# Note that p_m(x) has only even powers of x if m is odd, and
@@ -213,8 +217,7 @@ this definition is equivalent to the Hurwitz zeta function
213217
``\\sum_{k=0}^\\infty (k+z)^{-s}``. For ``z=1``, it yields
214218
the Riemann zeta function ``\\zeta(s)``.
215219
"""
216-
function zeta(s::Union{Int,Float64,Complex{Float64}},
217-
z::Union{Float64,Complex{Float64}})
220+
function zeta(s::ComplexOrReal{Float64}, z::ComplexOrReal{Float64})
218221
ζ = zero(promote_type(typeof(s), typeof(z)))
219222

220223
(z == 1 || z == 0) && return oftype(ζ, zeta(s))
@@ -263,7 +266,8 @@ function zeta(s::Union{Int,Float64,Complex{Float64}},
263266
minus_z = -z
264267
ζ += pow_oftype(ζ, minus_z, minus_s) # ν = 0 term
265268
if xf != z
266-
ζ += pow_oftype(ζ, z - nx, minus_s) # real(z - nx) > 0, so use correct branch cut
269+
ζ += pow_oftype(ζ, z - nx, minus_s)
270+
# real(z - nx) > 0, so use correct branch cut
267271
# otherwise, if xf==z, then the definition skips this term
268272
end
269273
# do loop in different order, depending on the sign of s,
@@ -316,10 +320,10 @@ end
316320
"""
317321
polygamma(m, x)
318322
319-
Compute the polygamma function of order `m` of argument `x` (the `(m+1)th` derivative of the
323+
Compute the polygamma function of order `m` of argument `x` (the `(m+1)`th derivative of the
320324
logarithm of `gamma(x)`)
321325
"""
322-
function polygamma(m::Integer, z::Union{Float64,Complex{Float64}})
326+
function polygamma(m::Integer, z::ComplexOrReal{Float64})
323327
m == 0 && return digamma(z)
324328
m == 1 && return trigamma(z)
325329

@@ -337,40 +341,25 @@ function polygamma(m::Integer, z::Union{Float64,Complex{Float64}})
337341
# constants. We throw a DomainError() since the definition is unclear.
338342
real(m) < 0 && throw(DomainError())
339343

340-
s = m+1
344+
s = Float64(m+1)
345+
# It is safe to convert any integer (including `BigInt`) to a float here
346+
# as underflow occurs before precision issues.
341347
if real(z) <= 0 # reflection formula
342348
(zeta(s, 1-z) + signflip(m, cotderiv(m,z))) * (-gamma(s))
343349
else
344350
signflip(m, zeta(s,z) * (-gamma(s)))
345351
end
346352
end
347353

348-
# If we really cared about single precision, we could make a faster
349-
# Float32 version by truncating the Stirling series at a smaller cutoff.
350-
for (f,T) in ((:f32,Float32),(:f16,Float16))
351-
@eval begin
352-
zeta(s::Integer, z::Union{$T,Complex{$T}}) = $f(zeta(Int(s), f64(z)))
353-
zeta(s::Union{Float64,Complex128}, z::Union{$T,Complex{$T}}) = zeta(s, f64(z))
354-
zeta(s::Number, z::Union{$T,Complex{$T}}) = $f(zeta(f64(s), f64(z)))
355-
polygamma(m::Integer, z::Union{$T,Complex{$T}}) = $f(polygamma(Int(m), f64(z)))
356-
digamma(z::Union{$T,Complex{$T}}) = $f(digamma(f64(z)))
357-
trigamma(z::Union{$T,Complex{$T}}) = $f(trigamma(f64(z)))
358-
end
359-
end
360-
361-
zeta(s::Integer, z::Number) = zeta(Int(s), f64(z))
362-
zeta(s::Number, z::Number) = zeta(f64(s), f64(z))
363-
for f in (:digamma, :trigamma)
364-
@eval begin
365-
$f(z::Number) = $f(f64(z))
366-
end
367-
end
368-
polygamma(m::Integer, z::Number) = polygamma(m, f64(z))
354+
"""
355+
invdigamma(x)
369356
370-
# Inverse digamma function:
371-
# Implementation of fixed point algorithm described in
372-
# "Estimating a Dirichlet distribution" by Thomas P. Minka, 2000
357+
Compute the inverse [`digamma`](@ref) function of `x`.
358+
"""
373359
function invdigamma(y::Float64)
360+
# Implementation of fixed point algorithm described in
361+
# "Estimating a Dirichlet distribution" by Thomas P. Minka, 2000
362+
374363
# Closed form initial estimates
375364
if y >= -2.22
376365
x_old = exp(y) + 0.5
@@ -392,18 +381,16 @@ function invdigamma(y::Float64)
392381

393382
return x_new
394383
end
395-
invdigamma(x::Float32) = Float32(invdigamma(Float64(x)))
396384

397385
"""
398-
invdigamma(x)
386+
zeta(s)
399387
400-
Compute the inverse [`digamma`](@ref) function of `x`.
388+
Riemann zeta function ``\\zeta(s)``.
401389
"""
402-
invdigamma(x::Real) = invdigamma(Float64(x))
390+
function zeta(s::ComplexOrReal{Float64})
391+
# Riemann zeta function; algorithm is based on specializing the Hurwitz
392+
# zeta function above for z==1.
403393

404-
# Riemann zeta function; algorithm is based on specializing the Hurwitz
405-
# zeta function above for z==1.
406-
function zeta(s::Union{Float64,Complex{Float64}})
407394
# blows up to ±Inf, but get correct sign of imaginary zero
408395
s == 1 && return NaN + zero(s) * imag(s)
409396

@@ -448,23 +435,18 @@ function zeta(s::Union{Float64,Complex{Float64}})
448435
return ζ
449436
end
450437

451-
zeta(x::Integer) = zeta(Float64(x))
452-
zeta(x::Real) = oftype(float(x),zeta(Float64(x)))
453-
454-
"""
455-
zeta(s)
456-
457-
Riemann zeta function ``\\zeta(s)``.
458-
"""
459-
zeta(z::Complex) = oftype(float(z),zeta(Complex128(z)))
460-
461438
function zeta(x::BigFloat)
462439
z = BigFloat()
463440
ccall((:mpfr_zeta, :libmpfr), Int32, (Ptr{BigFloat}, Ptr{BigFloat}, Int32), &z, &x, ROUNDING_MODE[])
464441
return z
465442
end
466443

467-
function eta(z::Union{Float64,Complex{Float64}})
444+
"""
445+
eta(x)
446+
447+
Dirichlet eta function ``\\eta(s) = \\sum^\\infty_{n=1}(-1)^{n-1}/n^{s}``.
448+
"""
449+
function eta(z::ComplexOrReal{Float64})
468450
δz = 1 - z
469451
if abs(real(δz)) + abs(imag(δz)) < 7e-3 # Taylor expand around z==1
470452
return 0.6931471805599453094172321214581765 *
@@ -478,17 +460,82 @@ function eta(z::Union{Float64,Complex{Float64}})
478460
return -zeta(z) * expm1(0.6931471805599453094172321214581765*δz)
479461
end
480462
end
481-
eta(x::Integer) = eta(Float64(x))
482-
eta(x::Real) = oftype(float(x),eta(Float64(x)))
483-
484-
"""
485-
eta(x)
486-
487-
Dirichlet eta function ``\\eta(s) = \\sum^\\infty_{n=1}(-1)^{n-1}/n^{s}``.
488-
"""
489-
eta(z::Complex) = oftype(float(z),eta(Complex128(z)))
490463

491464
function eta(x::BigFloat)
492465
x == 1 && return big_ln2()
493466
return -zeta(x) * expm1(big_ln2()*(1-x))
494467
end
468+
469+
# Converting types that we can convert, and not ones we can not
470+
# Float16, and Float32 and their Complex equivalents can be converted to Float64
471+
# and results converted back.
472+
# Otherwise, we need to make things use their own `float` converting methods
473+
# and in those cases, we do not convert back either as we assume
474+
# they also implement their own versions of the functions, with the correct return types.
475+
# This is the case for BitIntegers (which become `Float64` when `float`ed).
476+
# Otherwise, if they do not implement their version of the functions we
477+
# manually throw a `MethodError`.
478+
# This case occurs, when calling `float` on a type does not change its type,
479+
# as it is already a `float`, and would have hit own method, if one had existed.
480+
481+
482+
# If we really cared about single precision, we could make a faster
483+
# Float32 version by truncating the Stirling series at a smaller cutoff.
484+
# and if we really cared about half precision, we could make a faster
485+
# Float16 version, by using a precomputed table look-up.
486+
487+
488+
for T in (Float16, Float32, Float64)
489+
@eval f64(x::Complex{$T}) = Complex128(x)
490+
@eval f64(x::$T) = Float64(x)
491+
end
492+
493+
494+
for f in (:digamma, :trigamma, :zeta, :eta, :invdigamma)
495+
@eval begin
496+
function $f(z::Union{ComplexOrReal{Float16}, ComplexOrReal{Float32}})
497+
oftype(z, $f(f64(z)))
498+
end
499+
500+
function $f(z::Number)
501+
x = float(z)
502+
typeof(x) === typeof(z) && throw(MethodError($f, (z,)))
503+
# There is nothing to fallback to, as this didn't change the argument types
504+
$f(x)
505+
end
506+
end
507+
end
508+
509+
510+
for T1 in (Float16, Float32, Float64), T2 in (Float16, Float32, Float64)
511+
(T1 == T2 == Float64) && continue # Avoid redefining base definition
512+
513+
@eval function zeta(s::ComplexOrReal{$T1}, z::ComplexOrReal{$T2})
514+
ζ = zeta(f64(s), f64(z))
515+
convert(promote_type(typeof(s), typeof(z)), ζ)
516+
end
517+
end
518+
519+
520+
function zeta(s::Number, z::Number)
521+
t = float(s)
522+
x = float(z)
523+
if typeof(t) === typeof(s) && typeof(x) === typeof(z)
524+
# There is nothing to fallback to, since this didn't work
525+
throw(MethodError(zeta,(s,z)))
526+
end
527+
zeta(t, x)
528+
end
529+
530+
531+
function polygamma(m::Integer, z::Union{ComplexOrReal{Float16}, ComplexOrReal{Float32}})
532+
oftype(z, polygamma(m, f64(z)))
533+
end
534+
535+
536+
function polygamma(m::Integer, z::Number)
537+
x = float(z)
538+
typeof(x) === typeof(z) && throw(MethodError(polygamma, (m,z)))
539+
# There is nothing to fallback to, since this didn't work
540+
polygamma(m, x)
541+
end

test/runtests.jl

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -454,3 +454,35 @@ end
454454
@test typeof(SF.erfc(a)) == BigFloat
455455
end
456456
end
457+
458+
@testset "Base Julia issue #17474" begin
459+
@test SF.f64(complex(1f0,1f0)) === complex(1.0, 1.0)
460+
@test SF.f64(1f0) === 1.0
461+
462+
@test typeof(SF.eta(big"2")) == BigFloat
463+
@test typeof(SF.zeta(big"2")) == BigFloat
464+
@test typeof(SF.digamma(big"2")) == BigFloat
465+
466+
@test_throws MethodError SF.trigamma(big"2")
467+
@test_throws MethodError SF.trigamma(big"2.0")
468+
@test_throws MethodError SF.invdigamma(big"2")
469+
@test_throws MethodError SF.invdigamma(big"2.0")
470+
471+
@test_throws MethodError SF.eta(Complex(big"2"))
472+
@test_throws MethodError SF.eta(Complex(big"2.0"))
473+
@test_throws MethodError SF.zeta(Complex(big"2"))
474+
@test_throws MethodError SF.zeta(Complex(big"2.0"))
475+
@test_throws MethodError SF.zeta(1.0,big"2")
476+
@test_throws MethodError SF.zeta(1.0,big"2.0")
477+
@test_throws MethodError SF.zeta(big"1.0",2.0)
478+
@test_throws MethodError SF.zeta(big"1",2.0)
479+
480+
481+
@test typeof(SF.polygamma(3, 0x2)) == Float64
482+
@test typeof(SF.polygamma(big"3", 2f0)) == Float32
483+
@test typeof(SF.zeta(1, 2.0)) == Float64
484+
@test typeof(SF.zeta(1, 2f0)) == Float64 # BitIntegers result in Float64 returns
485+
@test typeof(SF.zeta(2f0, complex(2f0,0f0))) == Complex{Float32}
486+
@test typeof(SF.zeta(complex(1,1), 2f0)) == Complex{Float64}
487+
@test typeof(SF.zeta(complex(1), 2.0)) == Complex{Float64}
488+
end

0 commit comments

Comments
 (0)