Skip to content

Commit 6bab762

Browse files
authored
Diag for OneElement returns a OneElement (#383)
1 parent 05b76ad commit 6bab762

File tree

2 files changed

+18
-0
lines changed

2 files changed

+18
-0
lines changed

src/oneelement.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -392,6 +392,14 @@ function triu(A::OneElementMatrix, k::Integer=0)
392392
OneElement(nzband < k ? zero(A.val) : A.val, A.ind, axes(A))
393393
end
394394

395+
# diag
396+
function diag(O::OneElementMatrix, k::Integer=0)
397+
Base.require_one_based_indexing(O)
398+
len = length(diagind(O, k))
399+
ind = O.ind[2] - O.ind[1] == k ? (k >= 0 ? O.ind[2] - k : O.ind[1] + k) : len + 1
400+
OneElement(getindex_value(O), ind, len)
401+
end
402+
395403
# broadcast
396404

397405
function broadcasted(::DefaultArrayStyle{N}, ::typeof(conj), r::OneElement{<:Any,N}) where {N}

test/runtests.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2699,6 +2699,16 @@ end
26992699
B = OneElement(2, (1, 2), (Base.IdentityUnitRange(1:1), Base.IdentityUnitRange(2:2)))
27002700
@test repr(B) == "OneElement(2, (1, 2), (Base.IdentityUnitRange(1:1), Base.IdentityUnitRange(2:2)))"
27012701
end
2702+
2703+
@testset "diag" begin
2704+
@testset for sz in [(0,0), (0,1), (1,0), (1,1), (4,4), (4,6), (6,3)], ind in CartesianIndices(sz)
2705+
O = OneElement(4, Tuple(ind), sz)
2706+
@testset for k in -maximum(sz):maximum(sz)
2707+
@test diag(O, k) == diag(Array(O), k)
2708+
@test diag(O, k) isa OneElement{Int,1}
2709+
end
2710+
end
2711+
end
27022712
end
27032713

27042714
@testset "repeat" begin

0 commit comments

Comments
 (0)