diff --git a/Project.toml b/Project.toml index ae4403d..9c70da8 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "DiagonalArrays" uuid = "74fd4be6-21e2-4f6f-823a-4360d37c7a77" authors = ["ITensor developers and contributors"] -version = "0.3.12" +version = "0.3.13" [deps] ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a" @@ -12,7 +12,7 @@ SparseArraysBase = "0d5efcca-f356-4864-8770-e1ed8d78f208" [compat] ArrayLayouts = "1.10.4" -DerivableInterfaces = "0.5" +DerivableInterfaces = "0.5.5" FillArrays = "1.13.0" LinearAlgebra = "1.10.0" SparseArraysBase = "0.7.2" diff --git a/src/diagonalarray/diagonalarray.jl b/src/diagonalarray/diagonalarray.jl index 1741923..f3518c5 100644 --- a/src/diagonalarray/diagonalarray.jl +++ b/src/diagonalarray/diagonalarray.jl @@ -165,3 +165,20 @@ end # DiagonalArrays interface. diagview(a::DiagonalArray) = a.diag + +# Special case for permutedims that is friendlier for immutable storage. +function Base.permutedims(a::DiagonalArray, perm) + ((ndims(a) == length(perm)) && isperm(perm)) || + throw(ArgumentError("Not a valid permutation")) + ax_perm = ntuple(d -> axes(a)[perm[d]], ndims(a)) + # Unlike `permutedims(::Diagonal, perm)`, we copy here. + return DiagonalArray(copy(diagview(a)), ax_perm) +end + +function DerivableInterfaces.permuteddims(a::DiagonalArray, perm) + ((ndims(a) == length(perm)) && isperm(perm)) || + throw(ArgumentError("Not a valid permutation")) + ax_perm = ntuple(d -> axes(a)[perm[d]], ndims(a)) + # Unlike `permutedims(::Diagonal, perm)`, we copy here. + return DiagonalArray(diagview(a), ax_perm) +end diff --git a/test/Project.toml b/test/Project.toml index 7ca368d..ebd0a29 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,5 +1,6 @@ [deps] Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" +DerivableInterfaces = "6c5e35bf-e59e-4898-b73c-732dcc4ba65f" DiagonalArrays = "74fd4be6-21e2-4f6f-823a-4360d37c7a77" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" diff --git a/test/test_basics.jl b/test/test_basics.jl index 52263f0..8ff3ca8 100644 --- a/test/test_basics.jl +++ b/test/test_basics.jl @@ -1,4 +1,5 @@ using Test: @test, @testset, @test_broken, @inferred +using DerivableInterfaces: permuteddims using DiagonalArrays: DiagonalArrays, Delta, @@ -102,6 +103,19 @@ using LinearAlgebra: Diagonal eltype(DiagonalArray{elt,2}(undef, Base.OneTo(2), Base.OneTo(2))) ≡ eltype(DiagonalArray{elt,2}(undef, (Base.OneTo(2), Base.OneTo(2)))) end + @testset "permutedims" begin + a = DiagonalArray(randn(elt, 2), (2, 3, 4)) + b = permutedims(a, (3, 1, 2)) + @test diagview(b) == diagview(a) + @test diagview(b) ≢ diagview(a) + @test size(b) === (4, 2, 3) + end + @testset "DerivableInterfaces.permuteddims" begin + a = DiagonalArray(randn(elt, 2), (2, 3, 4)) + b = permuteddims(a, (3, 1, 2)) + @test diagview(b) ≡ diagview(a) + @test size(b) === (4, 2, 3) + end @testset "Matrix multiplication" begin a1 = DiagonalArray{elt}(undef, (2, 3)) a1[1, 1] = 11