Skip to content

Commit 5466fb8

Browse files
oxinaboxararslan
authored andcommitted
Improve type consistency of special functions
1 parent 71812c3 commit 5466fb8

File tree

2 files changed

+136
-59
lines changed

2 files changed

+136
-59
lines changed

src/gamma.jl

Lines changed: 104 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
using Base.Math: signflip, f16, f32, f64
44
using Base.MPFR: ROUNDING_MODE, big_ln2
55

6+
typealias ComplexOrReal{T} Union{T,Complex{T}}
7+
68
# Bernoulli numbers B_{2k}, using tabulated numerators and denominators from
79
# the online encyclopedia of integer sequences. (They actually have data
810
# up to k=249, but we stop here at k=20.) Used for generating the polygamma
@@ -16,7 +18,7 @@ using Base.MPFR: ROUNDING_MODE, big_ln2
1618
1719
Compute the digamma function of `x` (the logarithmic derivative of `gamma(x)`).
1820
"""
19-
function digamma(z::Union{Float64,Complex{Float64}})
21+
function digamma(z::ComplexOrReal{Float64})
2022
# Based on eq. (12), without looking at the accompanying source
2123
# code, of: K. S. Kölbig, "Programs for computing the logarithm of
2224
# the gamma function, and the digamma function, for complex
@@ -56,7 +58,7 @@ end
5658
5759
Compute the trigamma function of `x` (the logarithmic second derivative of `gamma(x)`).
5860
"""
59-
function trigamma(z::Union{Float64,Complex{Float64}})
61+
function trigamma(z::ComplexOrReal{Float64})
6062
# via the derivative of the Kölbig digamma formulation
6163
x = real(z)
6264
if x <= 0 # reflection formula
@@ -213,8 +215,7 @@ this definition is equivalent to the Hurwitz zeta function
213215
``\\sum_{k=0}^\\infty (k+z)^{-s}``. For ``z=1``, it yields
214216
the Riemann zeta function ``\\zeta(s)``.
215217
"""
216-
function zeta(s::Union{Int,Float64,Complex{Float64}},
217-
z::Union{Float64,Complex{Float64}})
218+
function zeta(s::ComplexOrReal{Float64}, z::ComplexOrReal{Float64})
218219
ζ = zero(promote_type(typeof(s), typeof(z)))
219220

220221
(z == 1 || z == 0) && return oftype(ζ, zeta(s))
@@ -263,7 +264,8 @@ function zeta(s::Union{Int,Float64,Complex{Float64}},
263264
minus_z = -z
264265
ζ += pow_oftype(ζ, minus_z, minus_s) # ν = 0 term
265266
if xf != z
266-
ζ += pow_oftype(ζ, z - nx, minus_s) # real(z - nx) > 0, so use correct branch cut
267+
ζ += pow_oftype(ζ, z - nx, minus_s)
268+
# real(z - nx) > 0, so use correct branch cut
267269
# otherwise, if xf==z, then the definition skips this term
268270
end
269271
# do loop in different order, depending on the sign of s,
@@ -316,10 +318,10 @@ end
316318
"""
317319
polygamma(m, x)
318320
319-
Compute the polygamma function of order `m` of argument `x` (the `(m+1)th` derivative of the
321+
Compute the polygamma function of order `m` of argument `x` (the `(m+1)`th derivative of the
320322
logarithm of `gamma(x)`)
321323
"""
322-
function polygamma(m::Integer, z::Union{Float64,Complex{Float64}})
324+
function polygamma(m::Integer, z::ComplexOrReal{Float64})
323325
m == 0 && return digamma(z)
324326
m == 1 && return trigamma(z)
325327

@@ -337,40 +339,25 @@ function polygamma(m::Integer, z::Union{Float64,Complex{Float64}})
337339
# constants. We throw a DomainError() since the definition is unclear.
338340
real(m) < 0 && throw(DomainError())
339341

340-
s = m+1
342+
s = Float64(m+1)
343+
# It is safe to convert any integer (including `BigInt`) to a float here
344+
# as underflow occurs before precision issues.
341345
if real(z) <= 0 # reflection formula
342346
(zeta(s, 1-z) + signflip(m, cotderiv(m,z))) * (-gamma(s))
343347
else
344348
signflip(m, zeta(s,z) * (-gamma(s)))
345349
end
346350
end
347351

348-
# If we really cared about single precision, we could make a faster
349-
# Float32 version by truncating the Stirling series at a smaller cutoff.
350-
for (f,T) in ((:f32,Float32),(:f16,Float16))
351-
@eval begin
352-
zeta(s::Integer, z::Union{$T,Complex{$T}}) = $f(zeta(Int(s), f64(z)))
353-
zeta(s::Union{Float64,Complex128}, z::Union{$T,Complex{$T}}) = zeta(s, f64(z))
354-
zeta(s::Number, z::Union{$T,Complex{$T}}) = $f(zeta(f64(s), f64(z)))
355-
polygamma(m::Integer, z::Union{$T,Complex{$T}}) = $f(polygamma(Int(m), f64(z)))
356-
digamma(z::Union{$T,Complex{$T}}) = $f(digamma(f64(z)))
357-
trigamma(z::Union{$T,Complex{$T}}) = $f(trigamma(f64(z)))
358-
end
359-
end
360-
361-
zeta(s::Integer, z::Number) = zeta(Int(s), f64(z))
362-
zeta(s::Number, z::Number) = zeta(f64(s), f64(z))
363-
for f in (:digamma, :trigamma)
364-
@eval begin
365-
$f(z::Number) = $f(f64(z))
366-
end
367-
end
368-
polygamma(m::Integer, z::Number) = polygamma(m, f64(z))
352+
"""
353+
invdigamma(x)
369354
370-
# Inverse digamma function:
371-
# Implementation of fixed point algorithm described in
372-
# "Estimating a Dirichlet distribution" by Thomas P. Minka, 2000
355+
Compute the inverse [`digamma`](@ref) function of `x`.
356+
"""
373357
function invdigamma(y::Float64)
358+
# Implementation of fixed point algorithm described in
359+
# "Estimating a Dirichlet distribution" by Thomas P. Minka, 2000
360+
374361
# Closed form initial estimates
375362
if y >= -2.22
376363
x_old = exp(y) + 0.5
@@ -392,18 +379,16 @@ function invdigamma(y::Float64)
392379

393380
return x_new
394381
end
395-
invdigamma(x::Float32) = Float32(invdigamma(Float64(x)))
396382

397383
"""
398-
invdigamma(x)
384+
zeta(s)
399385
400-
Compute the inverse [`digamma`](@ref) function of `x`.
386+
Riemann zeta function ``\\zeta(s)``.
401387
"""
402-
invdigamma(x::Real) = invdigamma(Float64(x))
388+
function zeta(s::ComplexOrReal{Float64})
389+
# Riemann zeta function; algorithm is based on specializing the Hurwitz
390+
# zeta function above for z==1.
403391

404-
# Riemann zeta function; algorithm is based on specializing the Hurwitz
405-
# zeta function above for z==1.
406-
function zeta(s::Union{Float64,Complex{Float64}})
407392
# blows up to ±Inf, but get correct sign of imaginary zero
408393
s == 1 && return NaN + zero(s) * imag(s)
409394

@@ -448,23 +433,18 @@ function zeta(s::Union{Float64,Complex{Float64}})
448433
return ζ
449434
end
450435

451-
zeta(x::Integer) = zeta(Float64(x))
452-
zeta(x::Real) = oftype(float(x),zeta(Float64(x)))
453-
454-
"""
455-
zeta(s)
456-
457-
Riemann zeta function ``\\zeta(s)``.
458-
"""
459-
zeta(z::Complex) = oftype(float(z),zeta(Complex128(z)))
460-
461436
function zeta(x::BigFloat)
462437
z = BigFloat()
463438
ccall((:mpfr_zeta, :libmpfr), Int32, (Ptr{BigFloat}, Ptr{BigFloat}, Int32), &z, &x, ROUNDING_MODE[])
464439
return z
465440
end
466441

467-
function eta(z::Union{Float64,Complex{Float64}})
442+
"""
443+
eta(x)
444+
445+
Dirichlet eta function ``\\eta(s) = \\sum^\\infty_{n=1}(-1)^{n-1}/n^{s}``.
446+
"""
447+
function eta(z::ComplexOrReal{Float64})
468448
δz = 1 - z
469449
if abs(real(δz)) + abs(imag(δz)) < 7e-3 # Taylor expand around z==1
470450
return 0.6931471805599453094172321214581765 *
@@ -478,17 +458,82 @@ function eta(z::Union{Float64,Complex{Float64}})
478458
return -zeta(z) * expm1(0.6931471805599453094172321214581765*δz)
479459
end
480460
end
481-
eta(x::Integer) = eta(Float64(x))
482-
eta(x::Real) = oftype(float(x),eta(Float64(x)))
483-
484-
"""
485-
eta(x)
486-
487-
Dirichlet eta function ``\\eta(s) = \\sum^\\infty_{n=1}(-1)^{n-1}/n^{s}``.
488-
"""
489-
eta(z::Complex) = oftype(float(z),eta(Complex128(z)))
490461

491462
function eta(x::BigFloat)
492463
x == 1 && return big_ln2()
493464
return -zeta(x) * expm1(big_ln2()*(1-x))
494465
end
466+
467+
# Converting types that we can convert, and not ones we can not
468+
# Float16, and Float32 and their Complex equivalents can be converted to Float64
469+
# and results converted back.
470+
# Otherwise, we need to make things use their own `float` converting methods
471+
# and in those cases, we do not convert back either as we assume
472+
# they also implement their own versions of the functions, with the correct return types.
473+
# This is the case for BitIntegers (which become `Float64` when `float`ed).
474+
# Otherwise, if they do not implement their version of the functions we
475+
# manually throw a `MethodError`.
476+
# This case occurs, when calling `float` on a type does not change its type,
477+
# as it is already a `float`, and would have hit own method, if one had existed.
478+
479+
480+
# If we really cared about single precision, we could make a faster
481+
# Float32 version by truncating the Stirling series at a smaller cutoff.
482+
# and if we really cared about half precision, we could make a faster
483+
# Float16 version, by using a precomputed table look-up.
484+
485+
486+
for T in (Float16, Float32, Float64)
487+
@eval f64(x::Complex{$T}) = Complex128(x)
488+
@eval f64(x::$T) = Float64(x)
489+
end
490+
491+
492+
for f in (:digamma, :trigamma, :zeta, :eta, :invdigamma)
493+
@eval begin
494+
function $f(z::Union{ComplexOrReal{Float16}, ComplexOrReal{Float32}})
495+
oftype(z, $f(f64(z)))
496+
end
497+
498+
function $f(z::Number)
499+
x = float(z)
500+
typeof(x) === typeof(z) && throw(MethodError($f, (z,)))
501+
# There is nothing to fallback to, as this didn't change the argument types
502+
$f(x)
503+
end
504+
end
505+
end
506+
507+
508+
for T1 in (Float16, Float32, Float64), T2 in (Float16, Float32, Float64)
509+
(T1 == T2 == Float64) && continue # Avoid redefining base definition
510+
511+
@eval function zeta(s::ComplexOrReal{$T1}, z::ComplexOrReal{$T2})
512+
ζ = zeta(f64(s), f64(z))
513+
convert(promote_type(typeof(s), typeof(z)), ζ)
514+
end
515+
end
516+
517+
518+
function zeta(s::Number, z::Number)
519+
t = float(s)
520+
x = float(z)
521+
if typeof(t) === typeof(s) && typeof(x) === typeof(z)
522+
# There is nothing to fallback to, since this didn't work
523+
throw(MethodError(zeta,(s,z)))
524+
end
525+
zeta(t, x)
526+
end
527+
528+
529+
function polygamma(m::Integer, z::Union{ComplexOrReal{Float16}, ComplexOrReal{Float32}})
530+
oftype(z, polygamma(m, f64(z)))
531+
end
532+
533+
534+
function polygamma(m::Integer, z::Number)
535+
x = float(z)
536+
typeof(x) === typeof(z) && throw(MethodError(polygamma, (m,z)))
537+
# There is nothing to fallback to, since this didn't work
538+
polygamma(m, x)
539+
end

test/runtests.jl

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -454,3 +454,35 @@ end
454454
@test typeof(SF.erfc(a)) == BigFloat
455455
end
456456
end
457+
458+
@testset "Base Julia issue #17474" begin
459+
@test SF.f64(complex(1f0,1f0)) == complex(1.0, 1.0)
460+
@test SF.f64(1f0) == 1.0
461+
462+
@test typeof(SF.eta(big"2")) == BigFloat
463+
@test typeof(SF.zeta(big"2")) == BigFloat
464+
@test typeof(SF.digamma(big"2")) == BigFloat
465+
466+
@test_throws MethodError SF.trigamma(big"2")
467+
@test_throws MethodError SF.trigamma(big"2.0")
468+
@test_throws MethodError SF.invdigamma(big"2")
469+
@test_throws MethodError SF.invdigamma(big"2.0")
470+
471+
@test_throws MethodError SF.eta(Complex(big"2"))
472+
@test_throws MethodError SF.eta(Complex(big"2.0"))
473+
@test_throws MethodError SF.zeta(Complex(big"2"))
474+
@test_throws MethodError SF.zeta(Complex(big"2.0"))
475+
@test_throws MethodError SF.zeta(1.0,big"2")
476+
@test_throws MethodError SF.zeta(1.0,big"2.0")
477+
@test_throws MethodError SF.zeta(big"1.0",2.0)
478+
@test_throws MethodError SF.zeta(big"1",2.0)
479+
480+
481+
@test typeof(SF.polygamma(3, 0x2)) == Float64
482+
@test typeof(SF.polygamma(big"3", 2f0)) == Float32
483+
@test typeof(SF.zeta(1, 2.0)) == Float64
484+
@test typeof(SF.zeta(1, 2f0)) == Float64 # BitIntegers result in Float64 returns
485+
@test typeof(SF.zeta(2f0, complex(2f0,0f0))) == Complex{Float32}
486+
@test typeof(SF.zeta(complex(1,1), 2f0)) == Complex{Float64}
487+
@test typeof(SF.zeta(complex(1), 2.0)) == Complex{Float64}
488+
end

0 commit comments

Comments
 (0)