diff --git a/Project.toml b/Project.toml index dd614606..c73f4e7c 100644 --- a/Project.toml +++ b/Project.toml @@ -27,7 +27,6 @@ ArrayInterface = "7.1" ForwardDiff = "0.10.13" GPUArraysCore = "0.1, 0.2" GenericSchur = "0.5.3" -JET = "0.9, 0.10, 0.11" LinearAlgebra = "1.10" Pkg = "1" PrecompileTools = "1" @@ -42,7 +41,6 @@ julia = "1.10" [extras] Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" -JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" @@ -50,4 +48,4 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Aqua", "ForwardDiff", "JET", "Pkg", "Test", "SafeTestsets", "StaticArrays", "Random"] +test = ["Aqua", "ForwardDiff", "Pkg", "Test", "SafeTestsets", "StaticArrays", "Random"] diff --git a/src/exp_generic.jl b/src/exp_generic.jl index b1837432..11d2bcc1 100644 --- a/src/exp_generic.jl +++ b/src/exp_generic.jl @@ -130,7 +130,9 @@ end """ struct ExpMethodGeneric{T} - ExpMethodGeneric()=ExpMethodGeneric{Val{13}}(); + ExpMethodGeneric()=ExpMethodGeneric{Val{13}}() + ExpMethodGeneric(k::Integer)=ExpMethodGeneric{Val{k}}() + ExpMethodGeneric(::Type{T}) where T = ExpMethodGeneric{Val{pade_order_for_type(T)}}() Generic exponential implementation of the method `ExpMethodHigham2005`, for any exp argument `x` for which the functions @@ -138,16 +140,73 @@ for any exp argument `x` for which the functions UniformScaling objects) are defined. The type `T` is used to adjust the number of terms used in the Pade approximants at compile time. +For high-precision types like `BigFloat`, the Padé order is automatically +selected based on the precision to achieve machine-precision accuracy. +You can also manually specify the order: `ExpMethodGeneric(k)` uses a +`(k,k)` Padé approximant. To automatically select based on element type, +use `ExpMethodGeneric(T)` where `T` is the element type. + See "The Scaling and Squaring Method for the Matrix Exponential Revisited" by Higham, Nicholas J. in 2005 for algorithm details. """ struct ExpMethodGeneric{T} end -ExpMethodGeneric() = ExpMethodGeneric{Val(13)}(); +ExpMethodGeneric() = ExpMethodGeneric{Val(13)}() +ExpMethodGeneric(k::Integer) = ExpMethodGeneric{Val{k}()}() + +""" + pade_order_for_type(::Type{T}) where T + +Compute the minimum Padé order k required for machine-precision accuracy +for a given floating-point type T. The (k,k) Padé approximant for exp(x) +has error bounded by (x/2)^(2k+1) / (2k+1)! for |x| ≤ 1. +""" +function pade_order_for_type(::Type{T}) where {T} + # Get precision in bits + p = _precision_bits(T) + # For standard Float64, use the optimized k=13 + p <= 64 && return 13 + # For higher precision, compute required k + # We need: (1/2)^(2k+1) / (2k+1)! < 2^(-p) + # Adding a small buffer for safety + target = big(1) // big(2)^(p + 10) + for k in 13:500 + bound = (big(1) // 2)^(2k + 1) / factorial(big(2k + 1)) + if bound < target + return k + end + end + return 500 # fallback for extremely high precision +end + +_precision_bits(::Type{Float16}) = 11 +_precision_bits(::Type{Float32}) = 24 +_precision_bits(::Type{Float64}) = 53 +_precision_bits(::Type{BigFloat}) = precision(BigFloat) +_precision_bits(::Type{Complex{T}}) where {T} = _precision_bits(T) +_precision_bits(::Type{T}) where {T <: AbstractFloat} = precision(T) +# Fallback for other numeric types (integers, etc.) +_precision_bits(::Type{T}) where {T <: Number} = 53 + +ExpMethodGeneric(::Type{T}) where {T} = ExpMethodGeneric{Val{pade_order_for_type(T)}()}() + +# Extract the element type from various input types +_eltype(x::Number) = typeof(x) +_eltype(x::AbstractArray{T}) where {T} = T + +# Determine if the default k=13 is sufficient for the given element type +_needs_higher_order(::Type{T}) where {T} = _precision_bits(T) > 64 function exponential!( x, method::ExpMethodGeneric{Vk}, cache = alloc_mem(x, method) ) where {Vk} + # For high-precision types with default k=13, automatically use higher order + T = _eltype(x) + if Vk === Val{13}() && _needs_higher_order(T) + k = pade_order_for_type(T) + return exponential!(x, ExpMethodGeneric{Val{k}()}(), cache) + end + nx = opnorm(x, 1) if isnan(nx) || nx > 4611686018427387904 # i.e. 2^62 since it would cause overflow in 2^s # This should (hopefully) ensure that the result is Inf or NaN depending on diff --git a/test/basictests.jl b/test/basictests.jl index c1ae2f88..0212e4eb 100644 --- a/test/basictests.jl +++ b/test/basictests.jl @@ -126,6 +126,34 @@ end @test vec(res) ≈ exp.(ts) end +@testset "Issue 44 - BigFloat precision" begin + using ExponentialUtilities: pade_order_for_type + + # Test that pade_order_for_type returns reasonable values + @test pade_order_for_type(Float64) == 13 + @test pade_order_for_type(BigFloat) >= 24 # depends on precision + + # Test scalar BigFloat accuracy + _x = range(-10, stop = 10, length = 50) + bigfloat_x = big.(_x) + max_err = maximum( + abs, + (x -> exp_generic(x) / exp(x) - 1).(bigfloat_x) + ) / eps(BigFloat) + @test max_err < 100 # Should be within ~100 eps + + # Test with different precisions + for bits in [128, 256, 512] + setprecision(bits) do + x = big"0.5" + result = exp_generic(x) + exact = exp(x) + rel_err = abs(result - exact) / exact + @test rel_err < 100 * eps(BigFloat) + end + end +end + @testset "naive_matmul" begin A = Matrix(reshape((1.0:(23.0^2)) ./ 700, (23, 23))) @test exp_generic(A) ≈ exp(A) diff --git a/test/jet/Project.toml b/test/jet/Project.toml new file mode 100644 index 00000000..0cf8f6e9 --- /dev/null +++ b/test/jet/Project.toml @@ -0,0 +1,6 @@ +[deps] +JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" + +[compat] +JET = "0.9, 0.10, 0.11" diff --git a/test/jet/jet.jl b/test/jet/jet.jl new file mode 100644 index 00000000..fcc3a428 --- /dev/null +++ b/test/jet/jet.jl @@ -0,0 +1,41 @@ +using ExponentialUtilities, JET, Test + +@testset "JET static analysis" begin + # Test key entry points for type stability and correctness + # Using report_call to check for runtime errors + + @testset "expv" begin + rep = JET.report_call(expv, (Float64, Matrix{Float64}, Vector{Float64})) + @test length(JET.get_reports(rep)) == 0 + end + + @testset "arnoldi" begin + rep = JET.report_call(arnoldi, (Matrix{Float64}, Vector{Float64})) + @test length(JET.get_reports(rep)) == 0 + end + + @testset "phi" begin + rep = JET.report_call(phi, (Matrix{Float64}, Int)) + @test length(JET.get_reports(rep)) == 0 + end + + @testset "exponential!" begin + rep = JET.report_call(ExponentialUtilities.exponential!, (Matrix{Float64},)) + @test length(JET.get_reports(rep)) == 0 + end + + @testset "phiv" begin + rep = JET.report_call(phiv, (Float64, Matrix{Float64}, Vector{Float64}, Int)) + @test length(JET.get_reports(rep)) == 0 + end + + @testset "kiops" begin + rep = JET.report_call(kiops, (Float64, Matrix{Float64}, Vector{Float64})) + @test length(JET.get_reports(rep)) == 0 + end + + @testset "expv_timestep" begin + rep = JET.report_call(expv_timestep, (Float64, Matrix{Float64}, Vector{Float64})) + @test length(JET.get_reports(rep)) == 0 + end +end diff --git a/test/qa.jl b/test/qa.jl index 6dba9775..cfa2c442 100644 --- a/test/qa.jl +++ b/test/qa.jl @@ -1,4 +1,4 @@ -using ExponentialUtilities, Aqua, JET +using ExponentialUtilities, Aqua @testset "Aqua" begin Aqua.find_persistent_tasks_deps(ExponentialUtilities) Aqua.test_ambiguities(ExponentialUtilities, recursive = false) @@ -12,43 +12,3 @@ using ExponentialUtilities, Aqua, JET Aqua.test_unbound_args(ExponentialUtilities) Aqua.test_undefined_exports(ExponentialUtilities) end - -@testset "JET static analysis" begin - # Test key entry points for type stability and correctness - # Using report_call to check for runtime errors - - @testset "expv" begin - rep = JET.report_call(expv, (Float64, Matrix{Float64}, Vector{Float64})) - @test length(JET.get_reports(rep)) == 0 - end - - @testset "arnoldi" begin - rep = JET.report_call(arnoldi, (Matrix{Float64}, Vector{Float64})) - @test length(JET.get_reports(rep)) == 0 - end - - @testset "phi" begin - rep = JET.report_call(phi, (Matrix{Float64}, Int)) - @test length(JET.get_reports(rep)) == 0 - end - - @testset "exponential!" begin - rep = JET.report_call(ExponentialUtilities.exponential!, (Matrix{Float64},)) - @test length(JET.get_reports(rep)) == 0 - end - - @testset "phiv" begin - rep = JET.report_call(phiv, (Float64, Matrix{Float64}, Vector{Float64}, Int)) - @test length(JET.get_reports(rep)) == 0 - end - - @testset "kiops" begin - rep = JET.report_call(kiops, (Float64, Matrix{Float64}, Vector{Float64})) - @test length(JET.get_reports(rep)) == 0 - end - - @testset "expv_timestep" begin - rep = JET.report_call(expv_timestep, (Float64, Matrix{Float64}, Vector{Float64})) - @test length(JET.get_reports(rep)) == 0 - end -end diff --git a/test/runtests.jl b/test/runtests.jl index ce7c7fd4..68fc7ea6 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -10,6 +10,12 @@ function activate_gpu_env() return Pkg.instantiate() end +function activate_jet_env() + Pkg.activate("jet") + Pkg.develop(PackageSpec(path = dirname(@__DIR__))) + return Pkg.instantiate() +end + @time begin if GROUP == "All" || GROUP == "Core" @time @safetestset "Quality Assurance" include("qa.jl") @@ -20,4 +26,9 @@ end activate_gpu_env() @time @safetestset "GPU Tests" include("gpu/gputests.jl") end + + if GROUP == "JET" + activate_jet_env() + @time @safetestset "JET Tests" include("jet/jet.jl") + end end