Skip to content

Commit b21c44e

Browse files
authored
Redesign to support slicing (#10)
1 parent a8caa48 commit b21c44e

19 files changed

+928
-300
lines changed

Project.toml

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
name = "NamedDimsArrays"
22
uuid = "60cbd0c0-df58-4cb7-918c-6f5607b73fde"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.2.0"
4+
version = "0.3.0"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
8+
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
89
BroadcastMapConversion = "4a4adec5-520f-4750-bb37-d5e66b4ddeb2"
910
Derive = "a07dfc7f-7d04-4eb5-84cc-a97f051f655a"
1011
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
@@ -13,8 +14,16 @@ SimpleTraits = "699a6c99-e7fa-54fc-8d76-47d257e15c1d"
1314
TensorAlgebra = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
1415
TypeParameterAccessors = "7e5a90cf-f82e-492e-a09b-e3e26432c138"
1516

17+
[weakdeps]
18+
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
19+
20+
[extensions]
21+
NamedDimsArraysBlockArraysExt = "BlockArrays"
22+
1623
[compat]
1724
Adapt = "4.1.1"
25+
ArrayLayouts = "1.11.0"
26+
BlockArrays = "1.3.0"
1827
BroadcastMapConversion = "0.1.2"
1928
Derive = "0.3.6"
2029
LinearAlgebra = "1.10"

README.md

Lines changed: 48 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -32,33 +32,67 @@ julia> Pkg.add("NamedDimsArrays")
3232
## Examples
3333

3434
````julia
35-
using NamedDimsArrays: aligndims, dename, dimnames, named
35+
using NamedDimsArrays: aligndims, dimnames, named, nameddimsindices, namedoneto, unname
3636
using TensorAlgebra: contract
37+
using Test: @test
3738

3839
# Named dimensions
39-
i = named(2, "i")
40-
j = named(2, "j")
41-
k = named(2, "k")
40+
i = namedoneto(2, "i")
41+
j = namedoneto(2, "j")
42+
k = namedoneto(2, "k")
4243

4344
# Arrays with named dimensions
44-
na1 = randn(i, j)
45-
na2 = randn(j, k)
45+
a1 = randn(i, j)
46+
a2 = randn(j, k)
4647

47-
@show dimnames(na1) == ("i", "j")
48+
@test dimnames(a1) == ("i", "j")
49+
@test nameddimsindices(a1) == (i, j)
50+
@test axes(a1) == (named(1:2, i), named(1:2, j))
51+
@test size(a1) == (named(2, i), named(2, j))
4852

4953
# Indexing
50-
@show na1[j => 2, i => 1] == na1[1, 2]
54+
@test a1[j => 2, i => 1] == a1[1, 2]
55+
@test a1[j[2], i[1]] == a1[1, 2]
5156

5257
# Tensor contraction
53-
na_dest = contract(na1, na2)
58+
a_dest = contract(a1, a2)
5459

55-
@show issetequal(dimnames(na_dest), ("i", "k"))
56-
# `dename` removes the names and returns an `Array`
57-
@show dename(na_dest, (i, k)) dename(na1) * dename(na2)
60+
@test issetequal(nameddimsindices(a_dest), (i, k))
61+
# `unname` removes the names and returns an `Array`
62+
@test unname(a_dest, (i, k)) unname(a1, (i, j)) * unname(a2, (j, k))
5863

5964
# Permute dimensions (like `ITensors.permute`)
60-
na1 = aligndims(na1, (j, i))
61-
@show na1[i => 1, j => 2] == na1[2, 1]
65+
a1′ = aligndims(a1, (j, i))
66+
@test a1′[i => 1, j => 2] == a1[i => 1, j => 2]
67+
@test a1′[i[1], j[2]] == a1[i[1], j[2]]
68+
69+
# Contiguous slicing
70+
b1 = a1[i => 1:2, j => 1:1]
71+
@test b1 == a1[i[1:2], j[1:1]]
72+
73+
b2 = a2[j => 1:1, k => 1:2]
74+
@test b2 == a2[j[1:1], k[1:2]]
75+
76+
@test nameddimsindices(b1) == (i[1:2], j[1:1])
77+
@test nameddimsindices(b2) == (j[1:1], k[1:2])
78+
79+
b_dest = contract(b1, b2)
80+
81+
@test issetequal(nameddimsindices(b_dest), (i, k))
82+
83+
# Non-contiguous slicing
84+
c1 = a1[i[[2, 1]], j[[2, 1]]]
85+
@test nameddimsindices(c1) == (i[[2, 1]], j[[2, 1]])
86+
@test unname(c1, (i[[2, 1]], j[[2, 1]])) == unname(a1, (i, j))[[2, 1], [2, 1]]
87+
@test c1[i[2], j[1]] == a1[i[2], j[1]]
88+
@test c1[2, 1] == a1[1, 2]
89+
90+
a1[i[[2, 1]], j[[2, 1]]] = [22 21; 12 11]
91+
@test a1[i[1], j[1]] == 11
92+
93+
x = randn(i[1:2], j[2:2])
94+
a1[i[1:2], j[2:2]] = x
95+
@test a1[i[1], j[2]] == x[i[1], j[2]]
6296
````
6397

6498
---

TODO.md

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,33 @@
1+
- Define `@align`/`@aligned` such that:
2+
```julia
3+
i = namedoneto(2, "i")
4+
j = namedoneto(2, "j")
5+
a = randn(i, j)
6+
@align a[j, i]
7+
@aligned a[j, i]
8+
```
9+
aligns the dimensions (currently `a[j, i]` doesn't align the dimensions).
10+
It could be written in terms of `align_getindex`/`align_view`.
111
- `svd`, `eigen` (including tensor versions)
2-
- `reshape`, `vec`
3-
- `swapdimnames`
4-
- `mapdimnames(f, a::AbstractNamedDimsArray)` (rename `replacedimnames(f, a)` to `mapdimnames(f, a)`, or have both?)
12+
- `reshape`, `vec`, including fused dimension names.
13+
- Dimension name set logic, i.e. `setdiffnameddimsindices(a::AbstractNamedDimsArray, b::AbstractNamedDimsArray)`, etc.
14+
- `swapnameddimsindices` (written in terms of `mapnameddimsindices`/`replacenameddimsindices`).
15+
- `mapnameddimsindices(f, a::AbstractNamedDimsArray)` (rename `replacenameddimsindices(f, a)` to `mapnameddimsindices(f, a)`, or have both?)
516
- `cat` (define `CatName` as a combination of the input names?).
617
- `canonize`/`flatten_array_wrappers` (https://github.com/mcabbott/NamedPlus.jl/blob/v0.0.5/src/permute.jl#L207)
7-
- `nameddims(PermutedDimsArray(a, perm), dimnames)` -> `nameddims(a, dimnames[invperm(perm)])`
8-
- `nameddims(transpose(a), dimnames)` -> `nameddims(a, reverse(dimnames))`
9-
- `Transpose(nameddims(a, dimnames))` -> `nameddims(a, reverse(dimnames))`
18+
- `nameddims(PermutedDimsArray(a, perm), nameddimsindices)` -> `nameddims(a, nameddimsindices[invperm(perm)])`
19+
- `nameddims(transpose(a), nameddimsindices)` -> `nameddims(a, reverse(nameddimsindices))`
20+
- `Transpose(nameddims(a, nameddimsindices))` -> `nameddims(a, reverse(nameddimsindices))`
1021
- etc.
1122
- `MappedName(old_name, name)`, acts like `Name(name)` but keeps track of the old name.
12-
- `namedmap(a, ::Pair...)`: `namedmap(named(randn(2, 2, 2, 2), i, j, k, l), i => k, j => l)`
23+
- `nameddimsmap(a, ::Pair...)`: `namedmap(named(randn(2, 2, 2, 2), i, j, k, l), i => k, j => l)`
1324
represents that the names map back and forth to each other for the sake of `transpose`,
1425
`tr`, `eigen`, etc. Operators are generally `namedmap(named(randn(2, 2), i, i'), i => i')`.
1526
- `prime(:i) = PrimedName(:i)`, `prime(:i, 2) = PrimedName(:i, 2)`, `prime(prime(:i)) = PrimedName(:i, 2)`,
1627
`Name(:i)' = prime(:i)`, etc.
17-
- `transpose`/`adjoint` based on `swapdimnames` and `MappedName(old_name, new_name)`.
28+
- Also `prime(f, a::AbstractNamedDimsArray)` where `f` is a filter function to determine
29+
which dimensions to filter.
30+
- `transpose`/`adjoint` based on `swapnameddimsindices` and `MappedName(old_name, new_name)`.
1831
- `adjoint` could make use of a lazy `ConjArray`.
1932
- `transpose(a, dimname1 => dimname1′, dimname2 => dimname2′)` like `https://github.com/mcabbott/NamedPlus.jl`.
2033
- Same as `replacedims(a, dimname1 => dimname1′, dimname1′ => dimname1, dimname2 => dimname2′, dimname2′ => dimname2)`.
@@ -23,4 +36,5 @@
2336
- Slicing: `nameddims(a, "i", "j")[1:2, 1:2] = nameddims(a[1:2, 1:2], Name(named(1:2, "i")), Name(named(1:2, "j")))`, i.e.
2437
the parent gets sliced and the new dimensions names are the named slice.
2538
- Should `NamedDimsArray` store the named axes rather than just the dimension names?
26-
- Should `NamedDimsArray` have special axes types so that `axes(nameddims(a, "i", "j")) == axes(nameddims(a', "j", "i"))`?
39+
- Should `NamedDimsArray` have special axes types so that `axes(nameddims(a, "i", "j")) == axes(nameddims(a', "j", "i"))`,
40+
i.e. equality is based on `issetequal` and not dependent on the ordering of the dimensions?

examples/README.jl

Lines changed: 48 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -37,30 +37,64 @@ julia> Pkg.add("NamedDimsArrays")
3737

3838
# ## Examples
3939

40-
using NamedDimsArrays: aligndims, dename, dimnames, named
40+
using NamedDimsArrays: aligndims, dimnames, named, nameddimsindices, namedoneto, unname
4141
using TensorAlgebra: contract
42+
using Test: @test
4243

4344
## Named dimensions
44-
i = named(2, "i")
45-
j = named(2, "j")
46-
k = named(2, "k")
45+
i = namedoneto(2, "i")
46+
j = namedoneto(2, "j")
47+
k = namedoneto(2, "k")
4748

4849
## Arrays with named dimensions
49-
na1 = randn(i, j)
50-
na2 = randn(j, k)
50+
a1 = randn(i, j)
51+
a2 = randn(j, k)
5152

52-
@show dimnames(na1) == ("i", "j")
53+
@test dimnames(a1) == ("i", "j")
54+
@test nameddimsindices(a1) == (i, j)
55+
@test axes(a1) == (named(1:2, i), named(1:2, j))
56+
@test size(a1) == (named(2, i), named(2, j))
5357

5458
## Indexing
55-
@show na1[j => 2, i => 1] == na1[1, 2]
59+
@test a1[j => 2, i => 1] == a1[1, 2]
60+
@test a1[j[2], i[1]] == a1[1, 2]
5661

5762
## Tensor contraction
58-
na_dest = contract(na1, na2)
63+
a_dest = contract(a1, a2)
5964

60-
@show issetequal(dimnames(na_dest), ("i", "k"))
61-
## `dename` removes the names and returns an `Array`
62-
@show dename(na_dest, (i, k)) dename(na1) * dename(na2)
65+
@test issetequal(nameddimsindices(a_dest), (i, k))
66+
## `unname` removes the names and returns an `Array`
67+
@test unname(a_dest, (i, k)) unname(a1, (i, j)) * unname(a2, (j, k))
6368

6469
## Permute dimensions (like `ITensors.permute`)
65-
na1 = aligndims(na1, (j, i))
66-
@show na1[i => 1, j => 2] == na1[2, 1]
70+
a1′ = aligndims(a1, (j, i))
71+
@test a1′[i => 1, j => 2] == a1[i => 1, j => 2]
72+
@test a1′[i[1], j[2]] == a1[i[1], j[2]]
73+
74+
## Contiguous slicing
75+
b1 = a1[i => 1:2, j => 1:1]
76+
@test b1 == a1[i[1:2], j[1:1]]
77+
78+
b2 = a2[j => 1:1, k => 1:2]
79+
@test b2 == a2[j[1:1], k[1:2]]
80+
81+
@test nameddimsindices(b1) == (i[1:2], j[1:1])
82+
@test nameddimsindices(b2) == (j[1:1], k[1:2])
83+
84+
b_dest = contract(b1, b2)
85+
86+
@test issetequal(nameddimsindices(b_dest), (i, k))
87+
88+
## Non-contiguous slicing
89+
c1 = a1[i[[2, 1]], j[[2, 1]]]
90+
@test nameddimsindices(c1) == (i[[2, 1]], j[[2, 1]])
91+
@test unname(c1, (i[[2, 1]], j[[2, 1]])) == unname(a1, (i, j))[[2, 1], [2, 1]]
92+
@test c1[i[2], j[1]] == a1[i[2], j[1]]
93+
@test c1[2, 1] == a1[1, 2]
94+
95+
a1[i[[2, 1]], j[[2, 1]]] = [22 21; 12 11]
96+
@test a1[i[1], j[1]] == 11
97+
98+
x = randn(i[1:2], j[2:2])
99+
a1[i[1:2], j[2:2]] = x
100+
@test a1[i[1], j[2]] == x[i[1], j[2]]
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
module NamedDimsArraysBlockArraysExt
2+
using ArrayLayouts: ArrayLayouts
3+
using BlockArrays: Block, BlockRange
4+
using NamedDimsArrays:
5+
AbstractNamedDimsArray,
6+
AbstractNamedUnitRange,
7+
named_getindex,
8+
nameddims_getindex,
9+
nameddims_view
10+
11+
function Base.getindex(r::AbstractNamedUnitRange{<:Integer}, I::Block{1})
12+
# TODO: Use `Derive.@interface NamedArrayInterface() r[I]` instead.
13+
return named_getindex(r, I)
14+
end
15+
16+
function Base.getindex(r::AbstractNamedUnitRange{<:Integer}, I::BlockRange{1})
17+
# TODO: Use `Derive.@interface NamedArrayInterface() r[I]` instead.
18+
return named_getindex(r, I)
19+
end
20+
21+
const BlockIndex{N} = Union{Block{N},BlockRange{N},AbstractVector{<:Block{N}}}
22+
23+
function Base.view(a::AbstractNamedDimsArray, I1::Block{1}, Irest::BlockIndex{1}...)
24+
# TODO: Use `Derive.@interface NamedDimsArrayInterface() r[I]` instead.
25+
return nameddims_view(a, I1, Irest...)
26+
end
27+
28+
function Base.view(a::AbstractNamedDimsArray, I::Block)
29+
# TODO: Use `Derive.@interface NamedDimsArrayInterface() r[I]` instead.
30+
return nameddims_view(a, Tuple(I)...)
31+
end
32+
33+
function Base.view(a::AbstractNamedDimsArray, I1::BlockIndex{1}, Irest::BlockIndex{1}...)
34+
# TODO: Use `Derive.@interface NamedDimsArrayInterface() r[I]` instead.
35+
return nameddims_view(a, I1, Irest...)
36+
end
37+
38+
# Fix ambiguity error.
39+
function Base.getindex(
40+
a::AbstractNamedDimsArray, I1::BlockRange{1}, Irest::BlockRange{1}...
41+
)
42+
return ArrayLayouts.layout_getindex(a, I1, Irest...)
43+
end
44+
45+
end

src/NamedDimsArrays.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ include("isnamed.jl")
44
include("randname.jl")
55
include("abstractnamedinteger.jl")
66
include("namedinteger.jl")
7+
include("abstractnamedarray.jl")
8+
include("namedarray.jl")
79
include("abstractnamedunitrange.jl")
810
include("namedunitrange.jl")
911
include("abstractnameddimsarray.jl")

src/abstractnamedarray.jl

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
using TypeParameterAccessors: unspecify_type_parameters
2+
3+
abstract type AbstractNamedArray{T,N,Value<:AbstractArray,Name} <: AbstractArray{T,N} end
4+
5+
const AbstractNamedVector{T,Value<:AbstractVector,Name} = AbstractNamedArray{T,1,Value,Name}
6+
const AbstractNamedMatrix{T,Value<:AbstractVector,Name} = AbstractNamedArray{T,2,Value,Name}
7+
8+
# Minimal interface.
9+
dename(a::AbstractNamedArray) = throw(MethodError(dename, Tuple{typeof(a)}))
10+
name(a::AbstractNamedArray) = throw(MethodError(name, Tuple{typeof(a)}))
11+
12+
# This can be customized to output different named integer types,
13+
# such as `namedarray(a::AbstractArray, name::IndexName) = Index(a, name)`.
14+
namedarray(a::AbstractArray, name) = NamedArray(a, name)
15+
16+
# Shorthand.
17+
named(a::AbstractArray, name) = namedarray(a, name)
18+
19+
# Derived interface.
20+
# TODO: Use `Accessors.@set`?
21+
setname(a::AbstractNamedArray, name) = namedarray(dename(a), name)
22+
23+
# TODO: Use `TypeParameterAccessors`.
24+
denametype(::Type{<:AbstractNamedArray{<:Any,<:Any,Value}}) where {Value} = Value
25+
nametype(::Type{<:AbstractNamedArray{<:Any,<:Any,<:Any,Name}}) where {Name} = Name
26+
27+
# Traits.
28+
isnamed(::Type{<:AbstractNamedArray}) = true
29+
30+
# TODO: Should they also have the same base type?
31+
function Base.:(==)(a1::AbstractNamedArray, a2::AbstractNamedArray)
32+
return name(a1) == name(a2) && dename(a1) == dename(a2)
33+
end
34+
function Base.hash(a::AbstractNamedArray, h::UInt)
35+
h = hash(Symbol(unspecify_type_parameters(typeof(a))), h)
36+
# TODO: Double check how this is handling blocking/sector information.
37+
h = hash(dename(a), h)
38+
return hash(name(a), h)
39+
end
40+
41+
named_getindex(a::AbstractArray, I...) = named(getindex(dename(a), I...), name(a))
42+
43+
# Array funcionality.
44+
Base.size(a::AbstractNamedArray) = map(s -> named(s, name(a)), size(dename(a)))
45+
Base.axes(a::AbstractNamedArray) = map(s -> named(s, name(a)), axes(dename(a)))
46+
Base.eachindex(a::AbstractNamedArray) = eachindex(dename(a))
47+
function Base.getindex(a::AbstractNamedArray{<:Any,N}, I::Vararg{Int,N}) where {N}
48+
return named_getindex(a, I...)
49+
end
50+
function Base.getindex(a::AbstractNamedArray, I::Int)
51+
return named_getindex(a, I)
52+
end
53+
Base.isempty(a::AbstractNamedArray) = isempty(dename(a))
54+
55+
## function Base.AbstractArray{Int}(a::AbstractNamedArray)
56+
## return AbstractArray{Int}(dename(a))
57+
## end
58+
##
59+
## Base.iterate(a::AbstractNamedArray) = isempty(a) ? nothing : (first(a), first(a))
60+
## function Base.iterate(a::AbstractNamedArray, i)
61+
## i == last(a) && return nothing
62+
## next = named(dename(i) + dename(step(a)), name(a))
63+
## return (next, next)
64+
## end
65+
66+
function randname(ang::AbstractRNG, a::AbstractNamedArray)
67+
return named(dename(a), randname(name(a)))
68+
end
69+
70+
function Base.show(io::IO, a::AbstractNamedArray)
71+
print(io, "named(", dename(a), ", ", repr(name(a)), ")")
72+
return nothing
73+
end
74+
function Base.show(io::IO, mime::MIME"text/plain", a::AbstractNamedArray)
75+
print(io, "named(\n")
76+
show(io, mime, dename(a))
77+
print(io, ",\n ", repr(name(a)), ")")
78+
return nothing
79+
end

0 commit comments

Comments
 (0)