Skip to content

Commit a2a98ad

Browse files
authored
Fix some incorrect dimension permutations in getindex/setindex! (#26)
1 parent 29c7d50 commit a2a98ad

File tree

4 files changed

+23
-9
lines changed

4 files changed

+23
-9
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "NamedDimsArrays"
22
uuid = "60cbd0c0-df58-4cb7-918c-6f5607b73fde"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.3.10"
4+
version = "0.3.11"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

src/abstractnameddimsarray.jl

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -400,8 +400,7 @@ function Base.getindex(
400400
a::AbstractNamedDimsArray, I1::AbstractNamedInteger, Irest::AbstractNamedInteger...
401401
)
402402
I = (I1, Irest...)
403-
# TODO: Check if this permuation should be inverted.
404-
perm = getperm(name.(nameddimsindices(a)), name.(I))
403+
perm = getperm(name.(I), name.(nameddimsindices(a)))
405404
# TODO: Throw a `NameMismatch` error.
406405
@assert isperm(perm)
407406
I = map(p -> I[p], perm)
@@ -446,8 +445,7 @@ function Base.setindex!(
446445
a::AbstractNamedDimsArray, value, I1::AbstractNamedInteger, Irest::AbstractNamedInteger...
447446
)
448447
I = flatten_namedinteger.((I1, Irest...))
449-
# TODO: Check if this permuation should be inverted.
450-
perm = getperm(name.(nameddimsindices(a)), name.(I))
448+
perm = getperm(name.(I), name.(nameddimsindices(a)))
451449
# TODO: Throw a `NameMismatch` error.
452450
@assert isperm(perm)
453451
I = map(p -> I[p], perm)
@@ -523,8 +521,7 @@ end
523521

524522
function Base.view(a::AbstractNamedDimsArray, I1::NamedViewIndex, Irest::NamedViewIndex...)
525523
I = (I1, Irest...)
526-
# TODO: Check if this permuation should be inverted.
527-
perm = getperm(name.(nameddimsindices(a)), name.(I))
524+
perm = getperm(name.(I), name.(nameddimsindices(a)))
528525
# TODO: Throw a `NameMismatch` error.
529526
@assert isperm(perm)
530527
I = map(p -> I[p], perm)
@@ -604,7 +601,6 @@ end
604601

605602
function aligndims(a::AbstractArray, dims)
606603
new_nameddimsindices = to_nameddimsindices(a, dims)
607-
# TODO: Check this permutation is correct (it may be the inverse of what we want).
608604
perm = Tuple(getperm(nameddimsindices(a), new_nameddimsindices))
609605
isperm(perm) || throw(
610606
NameMismatch(
@@ -616,7 +612,6 @@ end
616612

617613
function aligneddims(a::AbstractArray, dims)
618614
new_nameddimsindices = to_nameddimsindices(a, dims)
619-
# TODO: Check this permutation is correct (it may be the inverse of what we want).
620615
perm = getperm(nameddimsindices(a), new_nameddimsindices)
621616
isperm(perm) || throw(
622617
NameMismatch(

test/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
33
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
44
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
55
BlockSparseArrays = "2c9a651f-6452-4ace-a6ac-809f4280fbb4"
6+
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
67
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
78
NamedDimsArrays = "60cbd0c0-df58-4cb7-918c-6f5607b73fde"
89
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"

test/basics/test_basics.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
using Combinatorics: Combinatorics
12
using NamedDimsArrays:
23
NamedDimsArrays,
34
AbstractNamedDimsArray,
@@ -198,6 +199,23 @@ using Test: @test, @test_throws, @testset
198199
c = nameddims(Array{elt}(undef, 2, 3), (:i, :j))
199200
c .= a .+ 2 .* b
200201
@test dename(c, (:i, :j)) dename(a, (:i, :j)) + 2 * dename(b, (:i, :j))
202+
203+
# Regression test for proper permutations.
204+
a = nameddims(randn(elt, 2, 3, 4), (:i, :j, :k))
205+
I = (:i => 2, :j => 3, :k => 4)
206+
for I′ in Combinatorics.permutations(I)
207+
@test a[I′...] == a[2, 3, 4]
208+
a′ = copy(a)
209+
a′[I′...] = zero(Bool)
210+
@test iszero(a′[2, 3, 4])
211+
end
212+
I = (:i => 2, :j => 2:3, :k => 4)
213+
for I′ in Combinatorics.permutations(I)
214+
@test a[I′...] == a[2, 2:3, 4]
215+
## TODO: This is broken, investigate.
216+
## a′[I′...] = zeros(Bool, 2)
217+
## @test iszero(a′[2, 2:3, 4])
218+
end
201219
end
202220
@testset "Shorthand constructors (eltype=$elt)" for elt in (
203221
Float32, ComplexF32, Float64, ComplexF64

0 commit comments

Comments
 (0)