From aa5902a5b65e5367076ba1c2aac9b7e902649215 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Wed, 22 Jan 2025 22:09:31 +0100 Subject: [PATCH] Revert "Add generic fallback to all scalar functions" This reverts commit 05dbeac95b7cb5104d0f0d15e37af4dedc6327c8. --- src/NaNMath.jl | 17 ++--------------- test/runtests.jl | 15 --------------- 2 files changed, 2 insertions(+), 30 deletions(-) diff --git a/src/NaNMath.jl b/src/NaNMath.jl index fe6fa26..ef6937c 100644 --- a/src/NaNMath.jl +++ b/src/NaNMath.jl @@ -9,20 +9,9 @@ for f in (:sin, :cos, :tan, :asin, :acos, :acosh, :atanh, :log, :log2, :log10, Base.@assume_effects :total ($f)(x::Float64) = ccall(($(string(f)),libm), Float64, (Float64,), x) Base.@assume_effects :total ($f)(x::Float32) = ccall(($(string(f,"f")),libm), Float32, (Float32,), x) ($f)(x::Real) = ($f)(float(x)) - if $f !== :lgamma - ($f)(x) = (Base.$f)(x) - end end end -for f in (:sqrt,) - @eval ($f)(x) = (Base.$f)(x) -end - -for f in (:max, :min) - @eval ($f)(x, y) = (Base.$f)(x, y) -end - # Would be more efficient to remove the domain check in Base.sqrt(), # but this doesn't seem easy to do. Base.@assume_effects :nothrow sqrt(x::T) where {T<:Union{Float16, Float32, Float64}} = x < 0.0 ? T(NaN) : Base.sqrt(x) @@ -34,13 +23,11 @@ Base.@assume_effects :total pow(x::Float64, y::Float64) = ccall((:pow,libm), Fl Base.@assume_effects :total pow(x::Float32, y::Float32) = ccall((:powf,libm), Float32, (Float32,Float32), x, y) # We `promote` first before converting to floating pointing numbers to ensure that # e.g. `pow(::Float32, ::Int)` ends up calling `pow(::Float32, ::Float32)` -pow(x::Real, y::Real) = pow(promote(x, y)...) -pow(x::T, y::T) where {T<:Real} = pow(float(x), float(y)) -pow(x, y) = ^(x, y) +pow(x::Number, y::Number) = pow(promote(x, y)...) +pow(x::T, y::T) where {T<:Number} = pow(float(x), float(y)) # The following combinations are safe, so we can fall back to ^ pow(x::Number, y::Integer) = x^y -pow(x::Real, y::Integer) = x^y pow(x::Complex, y::Complex) = x^y """ diff --git a/test/runtests.jl b/test/runtests.jl index 5cfbac5..ebe3af6 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -216,18 +216,3 @@ end @test NaNMath.argmin(exp,x) === -1.0 @test NaNMath.argmax(exp,x) === 3.0 end - -# Test forwarding -x = 1 + 2im -for f in (:sin, :cos, :tan, :asin, :acos, :acosh, :atanh, :log, :log2, :log10, - :log1p, :sqrt) - @test @eval (NaNMath.$f)(x) == $f(x) -end - -struct A end -Base.isless(::A, ::A) = false -y = A() -for f in (:max, :min) - @test @eval (NaNMath.$f)(y, y) == $f(y, y) -end -@test NaNMath.pow(x, x) == ^(x, x)