Skip to content

Commit e7df543

Browse files
authored
Better permute[d]dims (#36)
1 parent 3118883 commit e7df543

File tree

4 files changed

+34
-2
lines changed

4 files changed

+34
-2
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "DiagonalArrays"
22
uuid = "74fd4be6-21e2-4f6f-823a-4360d37c7a77"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.3.12"
4+
version = "0.3.13"
55

66
[deps]
77
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
@@ -12,7 +12,7 @@ SparseArraysBase = "0d5efcca-f356-4864-8770-e1ed8d78f208"
1212

1313
[compat]
1414
ArrayLayouts = "1.10.4"
15-
DerivableInterfaces = "0.5"
15+
DerivableInterfaces = "0.5.5"
1616
FillArrays = "1.13.0"
1717
LinearAlgebra = "1.10.0"
1818
SparseArraysBase = "0.7.2"

src/diagonalarray/diagonalarray.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,3 +165,20 @@ end
165165

166166
# DiagonalArrays interface.
167167
diagview(a::DiagonalArray) = a.diag
168+
169+
# Special case for permutedims that is friendlier for immutable storage.
170+
function Base.permutedims(a::DiagonalArray, perm)
171+
((ndims(a) == length(perm)) && isperm(perm)) ||
172+
throw(ArgumentError("Not a valid permutation"))
173+
ax_perm = ntuple(d -> axes(a)[perm[d]], ndims(a))
174+
# Unlike `permutedims(::Diagonal, perm)`, we copy here.
175+
return DiagonalArray(copy(diagview(a)), ax_perm)
176+
end
177+
178+
function DerivableInterfaces.permuteddims(a::DiagonalArray, perm)
179+
((ndims(a) == length(perm)) && isperm(perm)) ||
180+
throw(ArgumentError("Not a valid permutation"))
181+
ax_perm = ntuple(d -> axes(a)[perm[d]], ndims(a))
182+
# Unlike `permutedims(::Diagonal, perm)`, we copy here.
183+
return DiagonalArray(diagview(a), ax_perm)
184+
end

test/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
[deps]
22
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
3+
DerivableInterfaces = "6c5e35bf-e59e-4898-b73c-732dcc4ba65f"
34
DiagonalArrays = "74fd4be6-21e2-4f6f-823a-4360d37c7a77"
45
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
56
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

test/test_basics.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using Test: @test, @testset, @test_broken, @inferred
2+
using DerivableInterfaces: permuteddims
23
using DiagonalArrays:
34
DiagonalArrays,
45
Delta,
@@ -102,6 +103,19 @@ using LinearAlgebra: Diagonal
102103
eltype(DiagonalArray{elt,2}(undef, Base.OneTo(2), Base.OneTo(2)))
103104
eltype(DiagonalArray{elt,2}(undef, (Base.OneTo(2), Base.OneTo(2))))
104105
end
106+
@testset "permutedims" begin
107+
a = DiagonalArray(randn(elt, 2), (2, 3, 4))
108+
b = permutedims(a, (3, 1, 2))
109+
@test diagview(b) == diagview(a)
110+
@test diagview(b) diagview(a)
111+
@test size(b) === (4, 2, 3)
112+
end
113+
@testset "DerivableInterfaces.permuteddims" begin
114+
a = DiagonalArray(randn(elt, 2), (2, 3, 4))
115+
b = permuteddims(a, (3, 1, 2))
116+
@test diagview(b) diagview(a)
117+
@test size(b) === (4, 2, 3)
118+
end
105119
@testset "Matrix multiplication" begin
106120
a1 = DiagonalArray{elt}(undef, (2, 3))
107121
a1[1, 1] = 11

0 commit comments

Comments
 (0)