Skip to content

Commit f631a03

Browse files
authored
Better definition of equality, hash (#124)
1 parent 2c43066 commit f631a03

File tree

7 files changed

+75
-42
lines changed

7 files changed

+75
-42
lines changed

Project.toml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "NamedDimsArrays"
22
uuid = "60cbd0c0-df58-4cb7-918c-6f5607b73fde"
33
authors = ["ITensor developers <support@itensor.org> and contributors"]
4-
version = "0.8.3"
4+
version = "0.8.4"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
@@ -19,16 +19,19 @@ TypeParameterAccessors = "7e5a90cf-f82e-492e-a09b-e3e26432c138"
1919
VectorInterface = "409d34a3-91d5-4945-b6ec-7529ddf182d8"
2020

2121
[weakdeps]
22+
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
2223
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
2324
GradedArrays = "bc96ca6e-b7c8-4bb6-888e-c93f838762c2"
2425
SparseArraysBase = "0d5efcca-f356-4864-8770-e1ed8d78f208"
2526

2627
[extensions]
28+
NamedDimsArraysAbstractTreesExt = "AbstractTrees"
2729
NamedDimsArraysBlockArraysExt = "BlockArrays"
2830
NamedDimsArraysGradedArraysExt = "GradedArrays"
2931
NamedDimsArraysSparseArraysBaseExt = "SparseArraysBase"
3032

3133
[compat]
34+
AbstractTrees = "0.4.5"
3235
Adapt = "4.1.1"
3336
ArrayLayouts = "1.11"
3437
BlockArrays = "1.3"
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
module NamedDimsArraysAbstractTreesExt
2+
3+
using AbstractTrees: AbstractTrees
4+
using NamedDimsArrays: AbstractNamedDimsArray, dimnames
5+
6+
# Only print the dimension names when printing with `AbstractTrees.print_tree`.
7+
function AbstractTrees.printnode(io::IO, a::AbstractNamedDimsArray)
8+
show(IOContext(io, :compact => true, :limit => true), dimnames(a))
9+
return nothing
10+
end
11+
12+
end

src/abstractnameddimsarray.jl

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -400,19 +400,31 @@ end
400400
# Base version ignores dimension names.
401401
# TODO: Use `mapreduce(isequal, &&, a1, a2)`?
402402
function Base.isequal(a1::AbstractNamedDimsArray, a2::AbstractNamedDimsArray)
403-
return all(eachindex(a1, a2)) do I
404-
isequal(a1[I], a2[I])
405-
end
403+
issetequal(inds(a1), inds(a2)) || return false
404+
return isequal(unname(a1), unnamed(a2, inds(a1)))
406405
end
407406

408407
# Base version ignores dimension names.
409408
# TODO: Use `mapreduce(==, &&, a1, a2)`?
410409
# TODO: Handle `missing` values properly.
411410
function Base.:(==)(a1::AbstractNamedDimsArray, a2::AbstractNamedDimsArray)
412411
issetequal(inds(a1), inds(a2)) || return false
413-
return all(eachindex(a1, a2)) do I
414-
a1[I] == a2[I]
412+
return unname(a1) == unnamed(a2, inds(a1))
413+
end
414+
415+
# Generalization of `Base.sort` to Tuples for Julia v1.10 compatibility.
416+
# TODO: Remove when we drop support for Julia v1.10.
417+
_sort(x; kwargs...) = sort(x; kwargs...)
418+
_sort(x::NTuple{N}; kwargs...) where {N} = NTuple{N}(sort(collect(x); kwargs...))
419+
420+
function Base.hash(a::AbstractNamedDimsArray, h::UInt64)
421+
h = hash(:NamedDimsArray, h)
422+
a′ = aligneddims(a, _sort(dimnames(a)))
423+
h = hash(dename(a′), h)
424+
for i in inds(a′)
425+
h = hash(i, h)
415426
end
427+
return h
416428
end
417429

418430
# Indexing.
@@ -674,6 +686,7 @@ function aligndims(a::AbstractArray, dims)
674686
return constructorof(typeof(a))(permutedims(dename(a), perm), new_inds)
675687
end
676688

689+
using DerivableInterfaces: permuteddims
677690
function aligneddims(a::AbstractArray, dims)
678691
new_inds = to_inds(a, dims)
679692
perm = getperm(inds(a), new_inds)
@@ -683,7 +696,7 @@ function aligneddims(a::AbstractArray, dims)
683696
),
684697
)
685698
return constructorof_nameddims(typeof(a))(
686-
PermutedDimsArray(dename(a), perm), new_inds
699+
permuteddims(dename(a), perm), new_inds
687700
)
688701
end
689702

src/naiveorderedset.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,10 @@ Base.Broadcast.BroadcastStyle(s1::AbstractArrayStyle, s2::Style{NaiveOrderedSet}
2424
Base.Broadcast.broadcastable(s::NaiveOrderedSet) = s
2525
Base.to_shape(s::NaiveOrderedSet) = s
2626

27+
# Needed for functionality such as `CartesianIndices(::AbstractNamedDimsArray)`,
28+
# `pairs(::AbstractNamedDimsArray)`, etc.
29+
Base.CartesianIndices(s::NaiveOrderedSet) = CartesianIndices(values(s))
30+
2731
function Base.copy(
2832
bc::Broadcasted{Style{NaiveOrderedSet}, <:Any, <:Any, <:Tuple{<:NaiveOrderedSet}}
2933
)

test/Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
[deps]
2+
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
23
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
34
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
45
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
@@ -18,6 +19,7 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1819
VectorInterface = "409d34a3-91d5-4945-b6ec-7529ddf182d8"
1920

2021
[compat]
22+
AbstractTrees = "0.4.5"
2123
Adapt = "4"
2224
Aqua = "0.8.9"
2325
BlockArrays = "1"

test/test_abstracttreesext.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
using AbstractTrees: printnode
2+
using NamedDimsArrays: nameddims
3+
using Test: @test, @testset
4+
5+
@testset "AbstractTreesExt" begin
6+
a = randn(3, 4)
7+
na = nameddims(a, ("i", "j"))
8+
@test sprint(printnode, na) == """("i", "j")"""
9+
end

test/test_basics.jl

Lines changed: 25 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,39 +1,9 @@
11
using Combinatorics: Combinatorics
2-
using NamedDimsArrays:
3-
NamedDimsArrays,
4-
AbstractNamedDimsArray,
5-
AbstractNamedDimsMatrix,
6-
NaiveOrderedSet,
7-
Name,
8-
NameMismatch,
9-
NamedDimsCartesianIndex,
10-
NamedDimsCartesianIndices,
11-
NamedDimsArray,
12-
NamedDimsMatrix,
13-
aligndims,
14-
aligneddims,
15-
apply,
16-
dename,
17-
denamed,
18-
dim,
19-
dimnames,
20-
dims,
21-
fusednames,
22-
isnamed,
23-
mapinds,
24-
name,
25-
named,
26-
nameddims,
27-
inds,
28-
namedoneto,
29-
operator,
30-
product,
31-
replaceinds,
32-
setinds,
33-
state,
34-
unname,
35-
unnamed,
36-
@names
2+
using NamedDimsArrays: NamedDimsArrays, AbstractNamedDimsArray, AbstractNamedDimsMatrix,
3+
NaiveOrderedSet, Name, NameMismatch, NamedDimsCartesianIndex, NamedDimsCartesianIndices,
4+
NamedDimsArray, NamedDimsMatrix, aligndims, aligneddims, apply, dename, denamed, dim,
5+
dimnames, dims, fusednames, isnamed, mapinds, name, named, nameddims, inds, namedoneto,
6+
operator, product, replaceinds, setinds, state, unname, unnamed, @names
377
using Test: @test, @test_throws, @testset
388
using VectorInterface: scalartype
399

@@ -74,6 +44,26 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
7444
@test dims(na, ("j", "i")) == (2, 1)
7545
@test na[1, 1] == a[1, 1]
7646

47+
# equals (==)/isequal
48+
a = randn(elt, 3, 4)
49+
na = nameddims(a, ("i", "j"))
50+
@test na == na
51+
@test na == aligndims(na, ("j", "i"))
52+
@test isequal(na, na)
53+
@test isequal(na, aligndims(na, ("j", "i")))
54+
@test hash(na) == hash(aligndims(na, ("j", "i")))
55+
# Regression test that NamedDimsArrays
56+
# with different names are not equal (as opposed to
57+
# erroring).
58+
@test na nameddims(a, ("j", "k"))
59+
@test !isequal(na, nameddims(a, ("j", "k")))
60+
@test hash(na) hash(nameddims(a, ("j", "k")))
61+
62+
a = randn(elt, 2, 2)
63+
na = nameddims(a, ("i", "j"))
64+
@test CartesianIndices(na) == CartesianIndices(a)
65+
@test collect(pairs(na)) == (CartesianIndices(a) .=> a)
66+
7767
@test_throws ErrorException NamedDimsArray(randn(4), namedoneto.((2, 2), ("i", "j")))
7868
@test_throws ErrorException NamedDimsArray(randn(2, 2), namedoneto.((2, 3), ("i", "j")))
7969

0 commit comments

Comments
 (0)