diff --git a/src/DualNumbers.jl b/src/DualNumbers.jl index 53e8c53..bd03d11 100644 --- a/src/DualNumbers.jl +++ b/src/DualNumbers.jl @@ -3,6 +3,7 @@ module DualNumbers using SpecialFunctions import NaNMath import Calculus +import Random include("dual.jl") diff --git a/src/dual.jl b/src/dual.jl index c82e8c0..3312126 100644 --- a/src/dual.jl +++ b/src/dual.jl @@ -21,7 +21,7 @@ const DualComplex64 = Dual{ComplexF16} Base.convert(::Type{Dual{T}}, z::Dual{T}) where {T<:ReComp} = z Base.convert(::Type{Dual{T}}, z::Dual) where {T<:ReComp} = Dual{T}(convert(T, value(z)), convert(T, epsilon(z))) Base.convert(::Type{Dual{T}}, x::Number) where {T<:ReComp} = Dual{T}(convert(T, x), convert(T, 0)) -Base.convert(::Type{T}, z::Dual) where {T<:ReComp} = (epsilon(z)==0 ? convert(T, value(z)) : throw(InexactError())) +Base.convert(::Type{T}, z::Dual) where {T<:ReComp} = (iszero(epsilon(z)) ? convert(T, value(z)) : throw(InexactError())) Base.promote_rule(::Type{Dual{T}}, ::Type{Dual{S}}) where {T<:ReComp,S<:ReComp} = Dual{promote_type(T, S)} Base.promote_rule(::Type{Dual{T}}, ::Type{S}) where {T<:ReComp,S<:ReComp} = Dual{promote_type(T, S)} @@ -175,7 +175,7 @@ Base.isless(z::Dual{<:Real},w::Dual{<:Real}) = value(z) < value(w) Base.isless(z::Real,w::Dual{<:Real}) = z < value(w) Base.isless(z::Dual{<:Real},w::Real) = value(z) < w -Base.hash(z::Dual) = (x = hash(value(z)); epsilon(z)==0 ? x : bitmix(x,hash(epsilon(z)))) +Base.hash(z::Dual) = (x = hash(value(z)); iszero(epsilon(z)) ? x : bitmix(x,hash(epsilon(z)))) Base.float(z::Union{Dual{T}, Dual{Complex{T}}}) where {T<:AbstractFloat} = z Base.complex(z::Dual{<:Complex}) = z @@ -189,6 +189,21 @@ Base.ceil( ::Type{T}, z::Dual) where {T<:Real} = ceil( T, value(z)) Base.trunc(::Type{T}, z::Dual) where {T<:Real} = trunc(T, value(z)) Base.round(::Type{T}, z::Dual) where {T<:Real} = round(T, value(z)) +Base.zero(::Type{Dual{T}}) where {T} = Dual(zero(T), zero(T)) +Base.zero(x::Dual{T}) where {T} = zero(typeof(x)) +Base.iszero(z::Dual{T}) where {T} = iszero(value(z)) + +Base.one(::Type{Dual{T}}) where {T} = Dual(one(T), zero(T)) +Base.one(::Dual{T}) where {T} = one(Dual{T}) +Base.isone(z::Dual{T}) where {T} = isone(value(z)) + +Base.rand(r::Random.AbstractRNG, ::Random.SamplerType{Dual{T}}) where {T} = Dual{T}(rand(r, T), rand(r, T)) +Base.randn(r::Random.AbstractRNG, ::Type{Dual{T}}) where {T} = Dual{T}(randn(r, T), randn(r, T)) + +Base.rtoldefault(::Type{Dual{T}}) where {T} = Base.rtoldefault(T) +Base.copysign(x::Dual, y::Dual) = Dual(copysign(value(x), value(y)), + copysign(epsilon(x), epsilon(y))) + for op in (:real, :imag, :conj, :float, :complex) @eval Base.$op(z::Dual) = Dual($op(value(z)), $op(epsilon(z))) end @@ -201,8 +216,8 @@ Base.abs(z::Dual{<:Real}) = z ≥ 0 ? z : -z Base.angle(z::Dual{<:Real}) = z ≥ 0 ? zero(z) : one(z)*π function Base.angle(z::Dual{Complex{T}}) where T<:Real - if z == 0 - if imag(epsilon(z)) == 0 + if iszero(z) + if iszero(imag(epsilon(z))) Dual(zero(T), zero(T)) else Dual(zero(T), convert(T, Inf)) @@ -212,8 +227,8 @@ function Base.angle(z::Dual{Complex{T}}) where T<:Real end end -Base.flipsign(x::Dual,y::Dual) = y == 0 ? flipsign(x, epsilon(y)) : flipsign(x, value(y)) -Base.flipsign(x, y::Dual) = y == 0 ? flipsign(x, epsilon(y)) : flipsign(x, value(y)) +Base.flipsign(x::Dual,y::Dual) = iszero(y) ? flipsign(x, epsilon(y)) : flipsign(x, value(y)) +Base.flipsign(x, y::Dual) = iszero(y) ? flipsign(x, epsilon(y)) : flipsign(x, value(y)) Base.flipsign(x::Dual, y) = dual(flipsign(value(x), y), flipsign(epsilon(x), y)) # algebraic definitions @@ -233,7 +248,7 @@ Base.:-(z::Number, w::Dual) = Dual(z-value(w), -epsilon(w)) Base.:-(z::Dual, w::Number) = Dual(value(z)-w, epsilon(z)) # avoid ambiguous definition with Bool*Number -Base.:*(x::Bool, z::Dual) = ifelse(x, z, ifelse(signbit(real(value(z)))==0, zero(z), -zero(z))) +Base.:*(x::Bool, z::Dual) = ifelse(x, z, ifelse(iszero(signbit(real(value(z)))), zero(z), -zero(z))) Base.:*(x::Dual, z::Bool) = z*x Base.:*(z::Dual, w::Dual) = Dual(value(z)*value(w), epsilon(z)*value(w)+value(z)*epsilon(w)) @@ -246,7 +261,7 @@ Base.:/(z::Dual, x::Number) = Dual(value(z)/x, epsilon(z)/x) for f in [:(Base.:^), :(NaNMath.pow)] @eval function ($f)(z::Dual, w::Dual) - if epsilon(w) == 0.0 + if iszero(epsilon(w)) return $f(z, value(w)) end val = $f(value(z), value(w)) diff --git a/test/automatic_differentiation_test.jl b/test/automatic_differentiation_test.jl index b6126be..f785591 100644 --- a/test/automatic_differentiation_test.jl +++ b/test/automatic_differentiation_test.jl @@ -127,6 +127,7 @@ z = Dual(1.0+1.0im,cis(π/4)) z = Dual(1.0+1.0im,cis(π/2)) @test abs(z) ≡ sqrt(2) + 1/sqrt(2)*ɛ + # tests vectorized methods const zv = dual.(collect(1.0:10.0), ones(10)) @@ -173,3 +174,12 @@ end @test value(3) == 3 @test epsilon(44.0) ≈ 0.0 + +@test one(Dual{Float64}) == Dual(1,0) +@test isone(Dual(1,0)) +@test isone(one(rand(Dual{Float64}))) + +@test zero(Dual{Float64}) == Dual(0,0) +@test iszero(Dual(0,0)) + +@test copysign(Dual(-1,-2), Dual(2,3)) == Dual(1,2)