Skip to content

Commit 90ef3ce

Browse files
authored
Specialize sqrt and cbrt (#379)
* Sqrt for real matrices * Add cbrt * Tests for negative and complex values * Tests for zeros * Compare with dense * Test cbrt only on recent julia versions
1 parent 5b31642 commit 90ef3ce

File tree

2 files changed

+55
-0
lines changed

2 files changed

+55
-0
lines changed

src/FillArrays.jl

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -374,6 +374,32 @@ fillsimilar(a::Ones{T}, axes...) where T = Ones{T}(axes...)
374374
fillsimilar(a::Zeros{T}, axes...) where T = Zeros{T}(axes...)
375375
fillsimilar(a::AbstractFill, axes...) = Fill(getindex_value(a), axes...)
376376

377+
# functions
378+
function Base.sqrt(a::AbstractFillMatrix{<:Union{Real, Complex}})
379+
Base.require_one_based_indexing(a)
380+
size(a,1) == size(a,2) || throw(DimensionMismatch("matrix is not square: dimensions are $(size(a))"))
381+
_sqrt(a)
382+
end
383+
_sqrt(a::AbstractZerosMatrix) = float(a)
384+
function _sqrt(a::AbstractFillMatrix)
385+
n = size(a,1)
386+
n == 0 && return float(a)
387+
v = getindex_value(a)
388+
Fill((v/n), axes(a))
389+
end
390+
function Base.cbrt(a::AbstractFillMatrix{<:Real})
391+
Base.require_one_based_indexing(a)
392+
size(a,1) == size(a,2) || throw(DimensionMismatch("matrix is not square: dimensions are $(size(a))"))
393+
_cbrt(a)
394+
end
395+
_cbrt(a::AbstractZerosMatrix) = float(a)
396+
function _cbrt(a::AbstractFillMatrix)
397+
n = size(a,1)
398+
n == 0 && return float(a)
399+
v = getindex_value(a)
400+
Fill(cbrt(v)/cbrt(n)^2, axes(a))
401+
end
402+
377403
struct RectDiagonal{T,V<:AbstractVector{T},Axes<:Tuple{Vararg{AbstractUnitRange,2}}} <: AbstractMatrix{T}
378404
diag::V
379405
axes::Axes

test/runtests.jl

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2959,3 +2959,32 @@ end
29592959
@test triu(Z, 2) === Z
29602960
@test tril(Z, 2) === Z
29612961
end
2962+
2963+
@testset "sqrt/cbrt" begin
2964+
F = Fill(4, 4, 4)
2965+
A = Array(F)
2966+
@test sqrt(F) sqrt(A) rtol=3e-8
2967+
@test sqrt(F)^2 F
2968+
F = Fill(4+4im, 4, 4)
2969+
A = Array(F)
2970+
@test sqrt(F) sqrt(A) rtol=1e-8
2971+
@test sqrt(F)^2 F
2972+
F = Fill(-4, 4, 4)
2973+
A = Array(F)
2974+
if VERSION >= v"1.11.0-rc3"
2975+
@test cbrt(F) cbrt(A) rtol=1e-5
2976+
end
2977+
@test cbrt(F)^3 F
2978+
2979+
# avoid overflow
2980+
F = Fill(4, typemax(Int), typemax(Int))
2981+
@test sqrt(F)^2 F
2982+
@test cbrt(F)^3 F
2983+
2984+
# zeros
2985+
F = Zeros(4, 4)
2986+
A = Array(F)
2987+
@test sqrt(F) sqrt(A) atol=1e-14
2988+
@test sqrt(F)^2 == F
2989+
@test cbrt(F)^3 == F
2990+
end

0 commit comments

Comments
 (0)