Skip to content

Commit 1fd398f

Browse files
authored
Check shapes in NamedDimsArray constructor (#42)
1 parent 49ac174 commit 1fd398f

File tree

6 files changed

+25
-2
lines changed

6 files changed

+25
-2
lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
name = "NamedDimsArrays"
22
uuid = "60cbd0c0-df58-4cb7-918c-6f5607b73fde"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.4.4"
4+
version = "0.4.5"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
88
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
9+
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
910
DerivableInterfaces = "6c5e35bf-e59e-4898-b73c-732dcc4ba65f"
1011
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
1112
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
@@ -27,6 +28,7 @@ NamedDimsArraysGradedUnitRangesExt = "GradedUnitRanges"
2728
Adapt = "4.1.1"
2829
ArrayLayouts = "1.11.0"
2930
BlockArrays = "1.3.0"
31+
Compat = "4.16.0"
3032
DerivableInterfaces = "0.3.7"
3133
FillArrays = "1.13.0"
3234
GradedUnitRanges = "0.1.3"

src/NamedDimsArrays.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
module NamedDimsArrays
22

33
export NamedDimsArray, aligndims, named, nameddimsarray
4+
using Compat: @compat
5+
@compat public to_nameddimsindices
46

57
include("isnamed.jl")
68
include("randname.jl")

src/abstractnameddimsarray.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,12 @@ function to_nameddimsindices(a::AbstractArray, dims)
5959
return to_nameddimsindices(a, axes(a), dims)
6060
end
6161
function to_nameddimsindices(a::AbstractArray, axes, dims)
62-
return map((axis, dim) -> to_dimname(a, axis, dim), axes, dims)
62+
length(axes) == length(dims) || error("Number of dimensions don't match.")
63+
nameddimsindices = map((axis, dim) -> to_dimname(a, axis, dim), axes, dims)
64+
if any(size(a) .≠ length.(dename.(nameddimsindices)))
65+
error("Input dimensions don't match.")
66+
end
67+
return nameddimsindices
6368
end
6469
function to_dimname(a::AbstractArray, axis, dim::AbstractNamedArray)
6570
# TODO: Check `axis` and `dim` have the same shape?

src/nameddimsarray.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ struct NamedDimsArray{T,N,Parent<:AbstractArray{T,N},DimNames} <:
66
parent::Parent
77
nameddimsindices::DimNames
88
function NamedDimsArray(parent::AbstractArray, dims)
9+
# This checks the shapes of the inputs.
910
nameddimsindices = to_nameddimsindices(parent, dims)
1011
return new{eltype(parent),ndims(parent),typeof(parent),typeof(nameddimsindices)}(
1112
parent, nameddimsindices

test/test_basics.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,14 @@ using Test: @test, @test_throws, @testset
6868
@test dims(na, ("j", "i")) == (2, 1)
6969
@test na[1, 1] == a[1, 1]
7070

71+
@test_throws ErrorException NamedDimsArray(randn(4), namedoneto.((2, 2), ("i", "j")))
72+
@test_throws ErrorException NamedDimsArray(randn(2, 2), namedoneto.((2, 3), ("i", "j")))
73+
7174
a = randn(elt, 3, 4)
7275
na = nameddimsarray(a, ("i", "j"))
76+
i = namedoneto(3, "i")
77+
j = namedoneto(4, "j")
78+
ai, aj = axes(na)
7379
for na′ in (
7480
similar(na, Float32, (j, i)),
7581
similar(na, Float32, NaiveOrderedSet((j, i))),
@@ -87,6 +93,9 @@ using Test: @test, @test_throws, @testset
8793

8894
a = randn(elt, 3, 4)
8995
na = nameddimsarray(a, ("i", "j"))
96+
i = namedoneto(3, "i")
97+
j = namedoneto(4, "j")
98+
ai, aj = axes(na)
9099
for na′ in (
91100
similar(na, (j, i)),
92101
similar(na, NaiveOrderedSet((j, i))),

test/test_exports.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,9 @@ using NamedDimsArrays: NamedDimsArrays
22
using Test: @test, @testset
33
@testset "Test exports" begin
44
exports = [:NamedDimsArrays, :NamedDimsArray, :aligndims, :named, :nameddimsarray]
5+
publics = [:to_nameddimsindices]
6+
if VERSION v"1.11-"
7+
exports = [exports; publics]
8+
end
59
@test issetequal(names(NamedDimsArrays), exports)
610
end

0 commit comments

Comments
 (0)