Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ ArrayInterface = "7.1"
ForwardDiff = "0.10.13"
GPUArraysCore = "0.1, 0.2"
GenericSchur = "0.5.3"
JET = "0.9, 0.10, 0.11"
LinearAlgebra = "1.10"
Pkg = "1"
PrecompileTools = "1"
Expand All @@ -42,12 +41,11 @@ julia = "1.10"
[extras]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Aqua", "ForwardDiff", "JET", "Pkg", "Test", "SafeTestsets", "StaticArrays", "Random"]
test = ["Aqua", "ForwardDiff", "Pkg", "Test", "SafeTestsets", "StaticArrays", "Random"]
63 changes: 61 additions & 2 deletions src/exp_generic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -130,24 +130,83 @@ end

"""
struct ExpMethodGeneric{T}
ExpMethodGeneric()=ExpMethodGeneric{Val{13}}();
ExpMethodGeneric()=ExpMethodGeneric{Val{13}}()
ExpMethodGeneric(k::Integer)=ExpMethodGeneric{Val{k}}()
ExpMethodGeneric(::Type{T}) where T = ExpMethodGeneric{Val{pade_order_for_type(T)}}()

Generic exponential implementation of the method `ExpMethodHigham2005`,
for any exp argument `x` for which the functions
`LinearAlgebra.opnorm`, `+`, `*`, `^`, and `/` (including addition with
UniformScaling objects) are defined. The type `T` is used to adjust the
number of terms used in the Pade approximants at compile time.

For high-precision types like `BigFloat`, the Padé order is automatically
selected based on the precision to achieve machine-precision accuracy.
You can also manually specify the order: `ExpMethodGeneric(k)` uses a
`(k,k)` Padé approximant. To automatically select based on element type,
use `ExpMethodGeneric(T)` where `T` is the element type.

See "The Scaling and Squaring Method for the Matrix Exponential Revisited"
by Higham, Nicholas J. in 2005 for algorithm details.
"""
struct ExpMethodGeneric{T} end
ExpMethodGeneric() = ExpMethodGeneric{Val(13)}();
ExpMethodGeneric() = ExpMethodGeneric{Val(13)}()
ExpMethodGeneric(k::Integer) = ExpMethodGeneric{Val{k}()}()

"""
pade_order_for_type(::Type{T}) where T

Compute the minimum Padé order k required for machine-precision accuracy
for a given floating-point type T. The (k,k) Padé approximant for exp(x)
has error bounded by (x/2)^(2k+1) / (2k+1)! for |x| ≤ 1.
"""
function pade_order_for_type(::Type{T}) where {T}
# Get precision in bits
p = _precision_bits(T)
# For standard Float64, use the optimized k=13
p <= 64 && return 13
# For higher precision, compute required k
# We need: (1/2)^(2k+1) / (2k+1)! < 2^(-p)
# Adding a small buffer for safety
target = big(1) // big(2)^(p + 10)
for k in 13:500
bound = (big(1) // 2)^(2k + 1) / factorial(big(2k + 1))
if bound < target
return k
end
end
return 500 # fallback for extremely high precision
end

_precision_bits(::Type{Float16}) = 11
_precision_bits(::Type{Float32}) = 24
_precision_bits(::Type{Float64}) = 53
_precision_bits(::Type{BigFloat}) = precision(BigFloat)
_precision_bits(::Type{Complex{T}}) where {T} = _precision_bits(T)
_precision_bits(::Type{T}) where {T <: AbstractFloat} = precision(T)
# Fallback for other numeric types (integers, etc.)
_precision_bits(::Type{T}) where {T <: Number} = 53

ExpMethodGeneric(::Type{T}) where {T} = ExpMethodGeneric{Val{pade_order_for_type(T)}()}()

# Extract the element type from various input types
_eltype(x::Number) = typeof(x)
_eltype(x::AbstractArray{T}) where {T} = T

# Determine if the default k=13 is sufficient for the given element type
_needs_higher_order(::Type{T}) where {T} = _precision_bits(T) > 64

function exponential!(
x, method::ExpMethodGeneric{Vk},
cache = alloc_mem(x, method)
) where {Vk}
# For high-precision types with default k=13, automatically use higher order
T = _eltype(x)
if Vk === Val{13}() && _needs_higher_order(T)
k = pade_order_for_type(T)
return exponential!(x, ExpMethodGeneric{Val{k}()}(), cache)
end

nx = opnorm(x, 1)
if isnan(nx) || nx > 4611686018427387904 # i.e. 2^62 since it would cause overflow in 2^s
# This should (hopefully) ensure that the result is Inf or NaN depending on
Expand Down
28 changes: 28 additions & 0 deletions test/basictests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,34 @@ end
@test vec(res) ≈ exp.(ts)
end

@testset "Issue 44 - BigFloat precision" begin
using ExponentialUtilities: pade_order_for_type

# Test that pade_order_for_type returns reasonable values
@test pade_order_for_type(Float64) == 13
@test pade_order_for_type(BigFloat) >= 24 # depends on precision

# Test scalar BigFloat accuracy
_x = range(-10, stop = 10, length = 50)
bigfloat_x = big.(_x)
max_err = maximum(
abs,
(x -> exp_generic(x) / exp(x) - 1).(bigfloat_x)
) / eps(BigFloat)
@test max_err < 100 # Should be within ~100 eps

# Test with different precisions
for bits in [128, 256, 512]
setprecision(bits) do
x = big"0.5"
result = exp_generic(x)
exact = exp(x)
rel_err = abs(result - exact) / exact
@test rel_err < 100 * eps(BigFloat)
end
end
end

@testset "naive_matmul" begin
A = Matrix(reshape((1.0:(23.0^2)) ./ 700, (23, 23)))
@test exp_generic(A) ≈ exp(A)
Expand Down
6 changes: 6 additions & 0 deletions test/jet/Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
[deps]
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[compat]
JET = "0.9, 0.10, 0.11"
41 changes: 41 additions & 0 deletions test/jet/jet.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
using ExponentialUtilities, JET, Test

@testset "JET static analysis" begin
# Test key entry points for type stability and correctness
# Using report_call to check for runtime errors

@testset "expv" begin
rep = JET.report_call(expv, (Float64, Matrix{Float64}, Vector{Float64}))
@test length(JET.get_reports(rep)) == 0
end

@testset "arnoldi" begin
rep = JET.report_call(arnoldi, (Matrix{Float64}, Vector{Float64}))
@test length(JET.get_reports(rep)) == 0
end

@testset "phi" begin
rep = JET.report_call(phi, (Matrix{Float64}, Int))
@test length(JET.get_reports(rep)) == 0
end

@testset "exponential!" begin
rep = JET.report_call(ExponentialUtilities.exponential!, (Matrix{Float64},))
@test length(JET.get_reports(rep)) == 0
end

@testset "phiv" begin
rep = JET.report_call(phiv, (Float64, Matrix{Float64}, Vector{Float64}, Int))
@test length(JET.get_reports(rep)) == 0
end

@testset "kiops" begin
rep = JET.report_call(kiops, (Float64, Matrix{Float64}, Vector{Float64}))
@test length(JET.get_reports(rep)) == 0
end

@testset "expv_timestep" begin
rep = JET.report_call(expv_timestep, (Float64, Matrix{Float64}, Vector{Float64}))
@test length(JET.get_reports(rep)) == 0
end
end
42 changes: 1 addition & 41 deletions test/qa.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using ExponentialUtilities, Aqua, JET
using ExponentialUtilities, Aqua
@testset "Aqua" begin
Aqua.find_persistent_tasks_deps(ExponentialUtilities)
Aqua.test_ambiguities(ExponentialUtilities, recursive = false)
Expand All @@ -12,43 +12,3 @@ using ExponentialUtilities, Aqua, JET
Aqua.test_unbound_args(ExponentialUtilities)
Aqua.test_undefined_exports(ExponentialUtilities)
end

@testset "JET static analysis" begin
# Test key entry points for type stability and correctness
# Using report_call to check for runtime errors

@testset "expv" begin
rep = JET.report_call(expv, (Float64, Matrix{Float64}, Vector{Float64}))
@test length(JET.get_reports(rep)) == 0
end

@testset "arnoldi" begin
rep = JET.report_call(arnoldi, (Matrix{Float64}, Vector{Float64}))
@test length(JET.get_reports(rep)) == 0
end

@testset "phi" begin
rep = JET.report_call(phi, (Matrix{Float64}, Int))
@test length(JET.get_reports(rep)) == 0
end

@testset "exponential!" begin
rep = JET.report_call(ExponentialUtilities.exponential!, (Matrix{Float64},))
@test length(JET.get_reports(rep)) == 0
end

@testset "phiv" begin
rep = JET.report_call(phiv, (Float64, Matrix{Float64}, Vector{Float64}, Int))
@test length(JET.get_reports(rep)) == 0
end

@testset "kiops" begin
rep = JET.report_call(kiops, (Float64, Matrix{Float64}, Vector{Float64}))
@test length(JET.get_reports(rep)) == 0
end

@testset "expv_timestep" begin
rep = JET.report_call(expv_timestep, (Float64, Matrix{Float64}, Vector{Float64}))
@test length(JET.get_reports(rep)) == 0
end
end
11 changes: 11 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,12 @@ function activate_gpu_env()
return Pkg.instantiate()
end

function activate_jet_env()
Pkg.activate("jet")
Pkg.develop(PackageSpec(path = dirname(@__DIR__)))
return Pkg.instantiate()
end

@time begin
if GROUP == "All" || GROUP == "Core"
@time @safetestset "Quality Assurance" include("qa.jl")
Expand All @@ -20,4 +26,9 @@ end
activate_gpu_env()
@time @safetestset "GPU Tests" include("gpu/gputests.jl")
end

if GROUP == "JET"
activate_jet_env()
@time @safetestset "JET Tests" include("jet/jet.jl")
end
end
Loading