Skip to content

Commit 4d462c2

Browse files
committed
New methods to deal with the numerical precision
1 parent d6e0c48 commit 4d462c2

File tree

5 files changed

+252
-0
lines changed

5 files changed

+252
-0
lines changed

NEWS.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,15 @@ This page describes the most important changes in `TypeUtils`. The format is bas
44
[Keep a Changelog](https://keepachangelog.com/en/1.1.0/), and this project adheres to
55
[Semantic Versioning](https://semver.org).
66

7+
## Unreleased
8+
9+
## Added
10+
11+
- `get_precision(x)` yields the numerical precision of `x`.
12+
13+
- `adapt_precision(T, x)` yields a version of `x` with numerical precision `T`.
14+
15+
716
## Version 1.8.0 (2025-06-19)
817

918
### Added

ext/TypeUtilsUnitfulExt.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,4 +26,12 @@ end
2626
# Extend unitless (only needed for values).
2727
TypeUtils.unitless(x::AbstractQuantity) = ustrip(x)
2828

29+
TypeUtils.get_precision(::Type{<:AbstractQuantity{T}}) where {T} = get_precision(T)
30+
31+
TypeUtils.adapt_precision(::Type{T}, x::Quantity{T,D,U}) where {T<:Precision,D,U} = x
32+
TypeUtils.adapt_precision(::Type{T}, x::Quantity{S,D,U}) where {T<:Precision,D,U,S} =
33+
Quantity{adapt_precision(T, S), D, U}(x)
34+
TypeUtils.adapt_precision(::Type{T}, ::Type{Quantity{S,D,U}}) where {T<:Precision,S,D,U} =
35+
Quantity{adapt_precision(T, S), D, U}
36+
2937
end # module

src/TypeUtils.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,11 @@ export
55
ArrayAxes,
66
ArrayAxis,
77
ArrayShape,
8+
Precision,
89
RelaxedArrayShape,
910
AbstractTypeStableFunction,
1011
TypeStableFunction,
12+
adapt_precision,
1113
as,
1214
as_array_axes,
1315
as_array_axis,
@@ -25,6 +27,7 @@ export
2527
destructure!,
2628
destructure,
2729
floating_point_type,
30+
get_precision,
2831
is_signed,
2932
nearest,
3033
new_array,
@@ -56,6 +59,7 @@ include("methods.jl")
5659
include("numbers.jl")
5760
include("arrays.jl")
5861
include("funcs.jl")
62+
include("precision.jl")
5963
include("structs.jl")
6064

6165
function __init__()

src/types.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,8 @@ with [`adapt_precision`](@ref).
150150
"""
151151
const Precision = Union{Float16,Float32,Float64,BigFloat}
152152

153+
const default_precision = Float64
154+
153155
"""
154156
c = TypeUtils.Converter(f, T::Type)
155157

test/runtests.jl

Lines changed: 229 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,48 @@ using OffsetArrays
2323
using Test
2424
using Base: OneTo
2525

26+
"""
27+
x ≗ y
28+
29+
yields whether `x` and `y` have the same element types, the same axes, and the same values
30+
(in the sense of `isequal`). It can be seen as a shortcut for:
31+
32+
eltype(x) == eltype(y) && axes(x) == axes(y) && all(isequal, x, y)
33+
34+
"""
35+
(x::Any, y::Any) = false
36+
(x::T, y::T) where {T} = isequal(x, y)
37+
function (x::AbstractArray{T,N}, y::AbstractArray{T,N}) where {T,N}
38+
axes(x) == axes(y) || return false
39+
@inbounds for i in eachindex(x, y)
40+
isequal(x[i], y[i]) || return false
41+
end
42+
return true
43+
end
44+
45+
# Check for approximate equality for numbers, arrays, and factorizations.
46+
(x::Any, y::Any; kwds...) = isequal(x, y)
47+
for type in (Number, AbstractArray, LinearAlgebra.AbstractQ,)
48+
@eval (x::$type, y::$type; kwds...) = isapprox(x, y; kwds...)
49+
end
50+
function (x::Factorization, y::Factorization; kwds...)
51+
x === y && return true
52+
parameterless(typeof(x)) == parameterless(typeof(y)) || return false
53+
axes(x) == axes(y) || return false
54+
# NOTE Must compare properties, not fields and do not use `hasproperty` which only
55+
# appears in Julia 1.2.
56+
ppt_x = propertynames(x)
57+
ppt_y = propertynames(y)
58+
for key in ppt_x
59+
key ppt_y || return false
60+
end
61+
for key in ppt_y
62+
key ppt_x || return false
63+
(getproperty(x, key), getproperty(y, key); kwds...) || return false
64+
end
65+
return true
66+
end
67+
2668
struct TestUnitRange{T<:Real} <: AbstractUnitRange{T}
2769
length::Int
2870
start::T
@@ -801,7 +843,194 @@ same_value_and_type(x::T, y::T) where {T} = (x === y) || (x == y)
801843
@test_throws InexactError scale!(Val(1), copyto!(B, A), alpha)
802844
@test_throws InexactError scale!(Val(2), copyto!(B, A), alpha)
803845
@test_throws InexactError scale!(Val(3), copyto!(B, A), alpha)
846+
end
847+
848+
@testset "get_precision" begin
849+
850+
@test @inferred(get_precision(Symbol )) == AbstractFloat
851+
@test @inferred(get_precision(String )) == AbstractFloat
852+
@test @inferred(get_precision(Bool )) == AbstractFloat
853+
@test @inferred(get_precision(UInt8 )) == AbstractFloat
854+
@test @inferred(get_precision(UInt16 )) == AbstractFloat
855+
@test @inferred(get_precision(UInt32 )) == AbstractFloat
856+
@test @inferred(get_precision(UInt64 )) == AbstractFloat
857+
@test @inferred(get_precision(UInt128 )) == AbstractFloat
858+
@test @inferred(get_precision(Int8 )) == AbstractFloat
859+
@test @inferred(get_precision(Int16 )) == AbstractFloat
860+
@test @inferred(get_precision(Int32 )) == AbstractFloat
861+
@test @inferred(get_precision(Int64 )) == AbstractFloat
862+
@test @inferred(get_precision(Int128 )) == AbstractFloat
863+
@test @inferred(get_precision(BigInt )) == AbstractFloat
864+
@test @inferred(get_precision(Float16 )) == Float16
865+
@test @inferred(get_precision(Float32 )) == Float32
866+
@test @inferred(get_precision(Float64 )) == Float64
867+
@test @inferred(get_precision(BigFloat )) == BigFloat
868+
@test @inferred(get_precision(Complex{Int} )) == AbstractFloat
869+
@test @inferred(get_precision(Complex{Float16} )) == Float16
870+
@test @inferred(get_precision(Complex{Float32} )) == Float32
871+
@test @inferred(get_precision(Complex{Float64} )) == Float64
872+
@test @inferred(get_precision(Complex{BigFloat})) == BigFloat
873+
874+
@test @inferred(get_precision(:symbol )) == AbstractFloat
875+
@test @inferred(get_precision("string" )) == AbstractFloat
876+
@test @inferred(get_precision(true )) == AbstractFloat
877+
@test @inferred(get_precision(false )) == AbstractFloat
878+
@test @inferred(get_precision(3 )) == AbstractFloat
879+
@test @inferred(get_precision(0x03 )) == AbstractFloat
880+
@test @inferred(get_precision(3 )) == AbstractFloat
881+
@test @inferred(get_precision(big(3) )) == AbstractFloat
882+
@test @inferred(get_precision(3//2 )) == AbstractFloat
883+
@test @inferred(get_precision(π )) == AbstractFloat
884+
@test @inferred(get_precision(Float16(3) )) == Float16
885+
@test @inferred(get_precision(3.0f0 )) == Float32
886+
@test @inferred(get_precision(3.0 )) == Float64
887+
@test @inferred(get_precision(big(3.0) )) == BigFloat
888+
@test @inferred(get_precision(1 + 2im )) == AbstractFloat
889+
@test @inferred(get_precision(1.0f0 + 2.0f0im )) == Float32
890+
@test @inferred(get_precision(Complex{Float32} )) == Float32
891+
@test @inferred(get_precision(1.0 + 2.0im )) == Float64
892+
893+
A = ones(Bool, 2,3,4)
894+
@test @inferred(get_precision(A)) == AbstractFloat
895+
@test @inferred(get_precision(typeof(A))) == AbstractFloat
896+
897+
A = ComplexF32.([9+1im 2-3im 1; 0 7 1; 0 0 4])
898+
@test @inferred(get_precision(A)) == Float32
899+
@test @inferred(get_precision(typeof(A))) == Float32
900+
B = adjoint(A)
901+
@test @inferred(get_precision(B)) == Float32
902+
@test @inferred(get_precision(typeof(B))) == Float32
903+
B = Diagonal(A)
904+
@test @inferred(get_precision(B)) == Float32
905+
@test @inferred(get_precision(typeof(B))) == Float32
906+
B = Hermitian(A)
907+
@test @inferred(get_precision(B)) == Float32
908+
@test @inferred(get_precision(typeof(B))) == Float32
909+
B = qr(A)
910+
@test @inferred(get_precision(B)) == Float32
911+
@test @inferred(get_precision(typeof(B))) == Float32
912+
913+
end
804914

915+
@testset "adapt_precision($T, x)" for T in (AbstractFloat, Float16, Float32, Float64, BigFloat)
916+
if isconcretetype(T)
917+
@test T <: Precision
918+
end
919+
920+
S = isconcretetype(T) ? T : TypeUtils.default_precision
921+
922+
@test @inferred(adapt_precision(T, Symbol )) === Symbol
923+
@test @inferred(adapt_precision(T, String )) === String
924+
@test @inferred(adapt_precision(T, Bool )) === S
925+
@test @inferred(adapt_precision(T, UInt8 )) === S
926+
@test @inferred(adapt_precision(T, UInt16 )) === S
927+
@test @inferred(adapt_precision(T, UInt32 )) === S
928+
@test @inferred(adapt_precision(T, UInt64 )) === S
929+
@test @inferred(adapt_precision(T, UInt128 )) === S
930+
@test @inferred(adapt_precision(T, Int8 )) === S
931+
@test @inferred(adapt_precision(T, Int16 )) === S
932+
@test @inferred(adapt_precision(T, Int32 )) === S
933+
@test @inferred(adapt_precision(T, Int64 )) === S
934+
@test @inferred(adapt_precision(T, Int128 )) === S
935+
@test @inferred(adapt_precision(T, BigInt )) === S
936+
@test @inferred(adapt_precision(T, Float16 )) === S
937+
@test @inferred(adapt_precision(T, Float32 )) === S
938+
@test @inferred(adapt_precision(T, Float64 )) === S
939+
940+
str = "string"
941+
@test @inferred(adapt_precision(T, :symbol )) === :symbol
942+
@test @inferred(adapt_precision(T, str )) === str # same object
943+
@test @inferred(adapt_precision(T, true )) one(S)
944+
@test @inferred(adapt_precision(T, false )) zero(S)
945+
@test @inferred(adapt_precision(T, 0x03 )) S(3)
946+
@test @inferred(adapt_precision(T, 3 )) S(3)
947+
@test @inferred(adapt_precision(T, big(3) )) S(3)
948+
@test @inferred(adapt_precision(T, 3//2 )) S(3//2)
949+
@test @inferred(adapt_precision(T, π )) S(π)
950+
@test @inferred(adapt_precision(T, Float16(3))) S(3)
951+
@test @inferred(adapt_precision(T, 3.0f0 )) S(3)
952+
@test @inferred(adapt_precision(T, 3.0 )) S(3)
953+
@test @inferred(adapt_precision(T, big(3.0) )) S(3)
954+
955+
A = reshape(-3:20, 2,3,4)
956+
B = @inferred adapt_precision(T, A)
957+
@test eltype(B) === S
958+
@test axes(B) == axes(A)
959+
@test B == A
960+
961+
A = [9+1im 2-3im 1; 0 7 1; 0 0 4] # eltype(A) = Complex{Int}
962+
B = @inferred adapt_precision(T, A)
963+
@test eltype(B) === (eltype(A) <: Complex ? Complex{S} : S)
964+
@test axes(B) == axes(A)
965+
@test B == A
966+
967+
C = adjoint(A)
968+
B = @inferred adapt_precision(T, C)
969+
if real(eltype(C)) == T
970+
@test B === C # must be same object
971+
else
972+
@test typeof(B) <: Adjoint
973+
@test eltype(B) === (eltype(C) <: Complex ? Complex{S} : S)
974+
@test axes(B) == axes(C)
975+
@test B == C
976+
end
977+
978+
C = transpose(A)
979+
B = @inferred adapt_precision(T, C)
980+
if real(eltype(C)) == T
981+
@test B === C # must be same object
982+
else
983+
@test typeof(B) <: Transpose
984+
@test eltype(B) === (eltype(C) <: Complex ? Complex{S} : S)
985+
@test axes(B) == axes(C)
986+
@test B == C
987+
end
988+
989+
C = Diagonal(A)
990+
B = @inferred adapt_precision(T, C)
991+
if real(eltype(C)) == T
992+
@test B === C # must be same object
993+
else
994+
@test typeof(B) <: Diagonal
995+
@test eltype(B) === (eltype(C) <: Complex ? Complex{S} : S)
996+
@test axes(B) == axes(C)
997+
@test B == C
998+
end
999+
1000+
C = Hermitian(A)
1001+
B = @inferred adapt_precision(T, C)
1002+
if real(eltype(C)) == T
1003+
@test B === C # must be same object
1004+
else
1005+
@test typeof(B) <: Hermitian
1006+
@test eltype(B) === (eltype(C) <: Complex ? Complex{S} : S)
1007+
@test axes(B) == axes(C)
1008+
@test B == C
1009+
end
1010+
1011+
C = qr(A)
1012+
B = @inferred adapt_precision(T, C)
1013+
if real(eltype(C)) == T
1014+
@test B === C # must be same object
1015+
else
1016+
@test typeof(B) <: Factorization
1017+
@test eltype(B) === (eltype(C) <: Complex ? Complex{S} : S)
1018+
@test axes(B) == axes(C)
1019+
if T != Float16 && T != BigFloat
1020+
@test B C
1021+
end
1022+
end
1023+
1024+
C = svd(A)
1025+
B = @inferred adapt_precision(T, C)
1026+
if real(eltype(C)) == T
1027+
@test B === C # must be same object
1028+
else
1029+
@test typeof(B) <: Factorization
1030+
@test eltype(B) === (eltype(C) <: Complex ? Complex{S} : S)
1031+
@test axes(B) == axes(C)
1032+
@test B C
1033+
end
8051034
end
8061035

8071036
@testset "LinearAlgebra" begin

0 commit comments

Comments
 (0)