Skip to content

faster exp* #136

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
244 changes: 81 additions & 163 deletions src/math/elementary/explog.jl
Original file line number Diff line number Diff line change
@@ -1,95 +1,69 @@
function exp(a::DoubleFloat{T}) where {T<:IEEEFloat}
isnan(a) && return a
isinf(a) && return(signbit(a) ? zero(DoubleFloat{T}) : a)

if iszero(HI(a))
return one(DoubleFloat{T})
elseif isone(abs(HI(a))) && iszero(LO(a))
if HI(a) >= zero(T)
return DoubleFloat{T}(2.718281828459045, 1.4456468917292502e-16)
else # isone(-HI(a)) && iszero(LO(a))
return DoubleFloat{T}(0.36787944117144233, -1.2428753672788363e-17)
for FT in (DoubleFloat{Float16}, DoubleFloat{Float32})
for func in (:exp2, :exp, :exp10, :expm1, :log2, :log, :log10, :log1p)
@eval ($func)(a::$FT) = $FT(($func)(Float64(a)))
end
elseif abs(HI(a)) >= 709.0
if (HI(a) <= -709.0)
return zero(DoubleFloat{T})
else # HI(a) >= 709.0
return inf(DoubleFloat{T})
end
end

return calc_exp(a)
end

function exp_taylor(a::DoubleFloat{T}) where {T<:IEEEFloat}
x = a
x2 = x*x
x3 = x*x2
x4 = x2*x2
x5 = x2*x3
x10 = x5*x5
x15 = x5*x10
x20 = x10*x10
x25 = x10*x15

z = x + inv_fact[2]*x2 + inv_fact[3]*x3 + inv_fact[4]*x4
z2 = x5 * (inv_fact[5] + x*inv_fact[6] + x2*inv_fact[7] +
x3*inv_fact[8] + x4*inv_fact[9])
z3 = x10 * (inv_fact[10] + x*inv_fact[11] + x2*inv_fact[12] +
x3*inv_fact[13] + x4*inv_fact[14])
z4 = x15 * (inv_fact[15] + x*inv_fact[16] + x2*inv_fact[17] +
x3*inv_fact[18] + x4*inv_fact[19])
z5 = x20 * (inv_fact[20] + x*inv_fact[21] + x2*inv_fact[22] +
x3*inv_fact[23] + x4*inv_fact[24])
z6 = x25 * (inv_fact[25] + x*inv_fact[26] + x2*inv_fact[27])

((((z6+z5)+z4)+z3)+z2)+z + one(DoubleFloat{T})
end


@inline exp_zero_half(a::DoubleFloat{T}) where {T<:IEEEFloat} = exp_taylor(a)

@inline function exp_half_one(a::DoubleFloat{T}) where {T<:IEEEFloat}
z = mul_by_half(a)
z = exp_zero_half(z)
z = square(z)
return z
exp(a::DoubleFloat{Float64}) = exp2(a*Double64(1.4426950408889634, 2.0355273740931033e-17))
exp10(a::DoubleFloat{Float64}) = exp2(a*Double64(3.321928094887362, 1.661617516973592e-16))
function exp2(a::DoubleFloat{Float64})
abshi = abs(HI(a))
isnan(a) && return a
iszero(abshi) && return one(Double64)
if abshi > 1023.5
return (HI(a) < 0) ? zero(Double64) : inf(Double64)
end
return calc_exp2(a)
end


function mul_by_half(r::DoubleFloat{T}) where {T<:IEEEFloat}
frhi, xphi = frexp(HI(r))
frlo, xplo = frexp(LO(r))
xphi -= 1
xplo -= 1
hi = ldexp(frhi, xphi)
lo = ldexp(frlo, xplo)
return DoubleFloat{T}(hi, lo)
@inline function exthorner(x, p::Tuple)
hi, lo = p[end], zero(x)
for i in length(p)-1:-1:1
pi = p[i]
prod = hi*x
err1 = fma(hi, x, -prod)
hi, err2 = two_sum(pi,prod)
lo = fma(lo, x, err1 + err2)
end
return hi, lo
end

function mul_by_two(r::DoubleFloat{T}) where {T<:IEEEFloat}
frhi, xphi = frexp(HI(r))
frlo, xplo = frexp(LO(r))
xphi += 1
xplo += 1
hi = ldexp(frhi, xphi)
lo = ldexp(frlo, xplo)
return DoubleFloat{T}(hi, lo)
end
const coefs = Tuple(log(big(2))^n/factorial(big(n)) for n in 1:10)
const coefs_hi = Float64.(coefs)
const coefs_lo = Float64.(coefs .- coefs_hi)

function mul_pow2(r::DoubleFloat{T}, n::Int) where {T<:IEEEFloat}
frhi, xphi = frexp(HI(r))
frlo, xplo = frexp(LO(r))
xphi += n
xplo += n
hi = ldexp(frhi, xphi)
lo = ldexp(frlo, xplo)
return DoubleFloat{T}(hi, lo)
function exp_kernel(x::Float64)
hi, lo = exthorner(x, coefs_hi)
lo2 = evalpoly(x, coefs_lo)
hix = hi*x
return Double64(hix, fma(lo, x, fma(lo2, x, fma(hi, x, -hix))))
end

function mul_pwr2(r::DoubleFloat{T}, n::Real) where {T<:IEEEFloat}
m = 2.0^n
return DoubleFloat{T}(HI(r)*m, LO(r)*m)
function _make_exp_table(size, n=1)
t_array = zeros(Double64, 16);
for j in 1:size
val = 2.0^(BigFloat(j-1)/(16*n))
t_array[j] = val
end
return Tuple(t_array)
end
const T1 = _make_exp_table(16)
const T2 = _make_exp_table(16, 16)

function calc_exp2(a::Double64)
x = HI(a)
N = round(Int, 256*x)
k = N>>8
j1 = T1[(N&255)>>4 + 1]
j2 = T2[N&15 + 1]
r = fma((-1/256), N, x)
poly = exp_kernel(r)
poly_lo = exp_kernel(LO(a))
e2k = exp2(k)
lo_part = fma(poly, poly_lo, poly_lo) + poly
ans = fma(j1*j2, lo_part, j1*j2)
return e2k*ans
end

function Base.:(^)(r::DoubleFloat{T}, n::Int) where {T<:IEEEFloat}
Expand Down Expand Up @@ -136,73 +110,16 @@ function Base.:(^)(r::Int, n::DoubleFloat{T}) where {T<:IEEEFloat}
end
end

function calc_exp(a::DoubleFloat{T}) where {T<:IEEEFloat}
is_neg = signbit(HI(a))
xabs = is_neg ? -a : a
xintpart = modf(xabs)[2]
xintpart = xintpart.hi + xintpart.lo
xint = Int64(xintpart)
xfrac = xabs - T(xint)

if 0 < xint <= 64
zint = exp_int[xint]
elseif xint === zero(Int64)
zint = zero(DoubleFloat{T})
else
dv, rm = divrem(xint, 64)
zint = exp_int[64]^dv
if rm > 0
zint = zint * exp_int[rm]
end
end

# exp(xfrac)
if HI(xfrac) < 0.5
zfrac = exp_zero_half(xfrac)
elseif HI(xfrac) > 0.5
zfrac = exp_half_one(xfrac)
else
if LO(xfrac) == 0.0
zfrac = DoubleFloat{T}(1.6487212707001282, -4.731568479435833e-17)
elseif signbit(LO(xfrac))
zfrac = exp_zero_half(xfrac)
else
zfrac = exp_half_one(xfrac)
end
end

z = HI(zint) == zero(T) ? zfrac : zint * zfrac
if is_neg
z = inv(z)
end

return z
end

function expm1(a::DoubleFloat{T}) where {T<:IEEEFloat}
function expm1(a::Double64)
isnan(a) && return a
isinf(a) && return(signbit(a) ? zero(DoubleFloat{T}) : a)
u = exp(a)
# temp fix of if (u == one(DoubleFloat{T}))
if (isone(u.hi))
a
elseif (u-1.0 == -one(DoubleFloat{T}))
-one(DoubleFloat{T})
else
a*(u-1.0) / log(u)
abshi = abs(HI(a))
if abshi < .5
u = a*Double64(1.4426950408889634, 2.0355273740931033e-17)
a = exp_kernel(HI(u))
return fma(a, LO(u), a)
end
end

function exp2(a::DoubleFloat{T}) where {T<:IEEEFloat}
isnan(a) && return a
isinf(a) && return(signbit(a) ? zero(DoubleFloat{T}) : a)
return DoubleFloat{T}(2)^a
end

function exp10(a::DoubleFloat{T}) where {T<:IEEEFloat}
isnan(a) && return a
isinf(a) && return(signbit(a) ? zero(DoubleFloat{T}) : a)
return DoubleFloat{T}(10)^a
return exp(a)-1
end

#=
Expand Down Expand Up @@ -233,46 +150,47 @@ end
return u
end
=#
function mul_by_two(r::DoubleFloat{T}) where {T<:IEEEFloat}
frhi, xphi = frexp(HI(r))
frlo, xplo = frexp(LO(r))
xphi += 1
xplo += 1
hi = ldexp(frhi, xphi)
lo = ldexp(frlo, xplo)
return DoubleFloat{T}(hi, lo)
end

function log(x::DoubleFloat{T}) where {T<:IEEEFloat}
function log(x::Double64)
isnan(x) && return x
isinf(x) && !signbit(x) && return x
x === zero(DoubleFloat{T}) && return neginf(DoubleFloat{T})
y = DoubleFloat(log(HI(x)), zero(T))
x === zero(Double64) && return neginf(Double64)
y = Double64(log(HI(x)), 0.0)
z = exp(y)
adj = (z - x) / (z + x)
adj = mul_by_two(adj)
y = y - adj
return y
end


function log1p(x::DoubleFloat{T}) where {T<:IEEEFloat}
function log1p(x::Double64)
isnan(x) && return x
isinf(x) && !signbit(x) && return
isinf(x) && !signbit(x) && return
u = 1.0 + x
if u == one(DoubleFloat{T})
if u == one(Double64)
x
else
log(u)*x/(u-1.0)
end
end

logten(::Type{DoubleFloat{Float64}}) = Double64(2.302585092994046, -2.1707562233822494e-16)
logtwo(::Type{DoubleFloat{Float64}}) = Double64(0.6931471805599453, 2.3190468138462996e-17)
logtwo(::Type{DoubleFloat{Float32}}) = Double32(0.6931472, -1.9046542e-9)
logten(::Type{DoubleFloat{Float32}}) = Double32(2.3025851, -3.1975436e-8)
logtwo(::Type{DoubleFloat{Float16}}) = Double16(0.6934, -0.0002122)
logten(::Type{DoubleFloat{Float16}}) = Double16(2.303, -0.0001493)

function log2(x::DoubleFloat{T}) where {T<:IEEEFloat}
function log2(x::Double64)
isnan(x) && return x
isinf(x) && !signbit(x) && return x
log(x) / logtwo(DoubleFloat{T})
log(x) / Double64(0.6931471805599453, 2.3190468138462996e-17)
end

function log10(x::DoubleFloat{T}) where {T<:IEEEFloat}
function log10(x::Double64)
isnan(x) && return x
isinf(x) && !signbit(x) && return x
log(x) / logten(DoubleFloat{T})
log(x) / Double64(2.302585092994046, -2.1707562233822494e-16)
end