diff --git a/src/wrappers.jl b/src/wrappers.jl index 937328b..e1644dd 100644 --- a/src/wrappers.jl +++ b/src/wrappers.jl @@ -161,3 +161,20 @@ for type in (:Adjoint, :PermutedDimsArray, :ReshapedArray, :SubArray, :Transpose end end end + +using LinearAlgebra: Diagonal +@interface ::AbstractArrayInterface storedvalues(D::Diagonal) = LinearAlgebra.diag(D) +@interface ::AbstractArrayInterface eachstoredindex(D::Diagonal) = + LinearAlgebra.diagind(D, IndexCartesian()) +@interface ::AbstractArrayInterface isstored(D::Diagonal, i::Int, j::Int) = + i == j && Base.checkbounds(Bool, D, i, j) +@interface ::AbstractArrayInterface function getstoredindex(D::Diagonal, i::Int, j::Int) + return D.diag[i] +end +@interface ::AbstractArrayInterface function getunstoredindex(D::Diagonal, i::Int, j::Int) + return zero(eltype(D)) +end +@interface ::AbstractArrayInterface function setstoredindex!(D::Diagonal, v, i::Int, j::Int) + D.diag[i] = v + return D +end diff --git a/test/basics/test_diagonal.jl b/test/basics/test_diagonal.jl new file mode 100644 index 0000000..ada6f0c --- /dev/null +++ b/test/basics/test_diagonal.jl @@ -0,0 +1,25 @@ +using LinearAlgebra: Diagonal, diagind +using SparseArraysBase: + eachstoredindex, + getstoredindex, + getunstoredindex, + setstoredindex!, + isstored, + storedlength, + storedpairs, + storedvalues + +using Test: @test, @testset + +elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) + +@testset "Diagonal{$T}" for T in elts + L = 4 + D = Diagonal(rand(T, 4)) + @test storedlength(D) == 4 + @test eachstoredindex(D) == diagind(D, IndexCartesian()) + @test isstored(D, 2, 2) + @test getstoredindex(D, 2, 2) == D[2, 2] + @test !isstored(D, 2, 1) + @test getunstoredindex(D, 2, 2) == zero(T) +end