@@ -140,7 +140,7 @@ function checked_indexin(x::AbstractUnitRange, y::AbstractUnitRange)
140140end
141141
142142function Base. copy (a:: AbstractNamedDimsArray )
143- return nameddimsarraytype (a )(copy (dename (a)), nameddimsindices (a))
143+ return constructorof ( typeof (a) )(copy (dename (a)), nameddimsindices (a))
144144end
145145
146146const NamedDimsIndices = Union{
@@ -181,9 +181,11 @@ Base.values(s::NaiveOrderedSet) = s.values
181181Base. Tuple (s:: NaiveOrderedSet ) = Tuple (values (s))
182182Base. length (s:: NaiveOrderedSet ) = length (values (s))
183183Base. axes (s:: NaiveOrderedSet ) = axes (values (s))
184+ Base. keys (s:: NaiveOrderedSet ) = Base. OneTo (length (s))
184185Base.:(== )(s1:: NaiveOrderedSet , s2:: NaiveOrderedSet ) = issetequal (values (s1), values (s2))
185186Base. iterate (s:: NaiveOrderedSet , args... ) = iterate (values (s), args... )
186187Base. getindex (s:: NaiveOrderedSet , I:: Int ) = values (s)[I]
188+ Base. get (s:: NaiveOrderedSet , I:: Integer , default) = get (values (s), I, default)
187189Base. invperm (s:: NaiveOrderedSet ) = NaiveOrderedSet (invperm (values (s)))
188190Base. Broadcast. _axes (:: Broadcasted , axes:: NaiveOrderedSet ) = axes
189191Base. Broadcast. BroadcastStyle (:: Type{<:NaiveOrderedSet} ) = Style {NaiveOrderedSet} ()
@@ -210,6 +212,10 @@ function Base.size(a::AbstractNamedDimsArray)
210212 return NaiveOrderedSet (map (named, size (dename (a)), nameddimsindices (a)))
211213end
212214
215+ function Base. length (a:: AbstractNamedDimsArray )
216+ return prod (size (a); init= 1 )
217+ end
218+
213219# Circumvent issue when ndims isn't known at compile time.
214220function Base. axes (a:: AbstractNamedDimsArray , d)
215221 return d <= ndims (a) ? axes (a)[d] : OneTo (1 )
@@ -233,17 +239,20 @@ to_nameddimsaxes(dims) = map(to_nameddimsaxis, dims)
233239to_nameddimsaxis (ax:: NamedDimsAxis ) = ax
234240to_nameddimsaxis (I:: NamedDimsIndices ) = named (dename (only (axes (I))), I)
235241
236- nameddimsarraytype (a:: AbstractNamedDimsArray ) = nameddimsarraytype (typeof (a))
237- nameddimsarraytype (a:: Type{<:AbstractNamedDimsArray} ) = unspecify_type_parameters (a)
242+ # Interface inspired by [ConstructionBase.constructorof](https://github.com/JuliaObjects/ConstructionBase.jl).
243+ constructorof (type:: Type{<:AbstractArray} ) = unspecify_type_parameters (type)
244+
245+ constructorof_nameddims (type:: Type{<:AbstractNamedDimsArray} ) = constructorof (type)
246+ constructorof_nameddims (type:: Type{<:AbstractArray} ) = NamedDimsArray
238247
239248function similar_nameddims (a:: AbstractNamedDimsArray , elt:: Type , inds)
240249 ax = to_nameddimsaxes (inds)
241- return nameddimsarraytype (a )(similar (dename (a), elt, dename .(Tuple (ax))), name .(ax))
250+ return constructorof ( typeof (a) )(similar (dename (a), elt, dename .(Tuple (ax))), name .(ax))
242251end
243252
244253function similar_nameddims (a:: AbstractArray , elt:: Type , inds)
245254 ax = to_nameddimsaxes (inds)
246- return nameddims (similar (a, elt, dename .(Tuple (ax))), name .(ax))
255+ return constructorof_nameddims ( typeof (a)) (similar (a, elt, dename .(Tuple (ax))), name .(ax))
247256end
248257
249258# Base.similar gets the eltype at compile time.
@@ -262,7 +271,7 @@ function Base.similar(a::AbstractArray, elt::Type, inds::NaiveOrderedSet)
262271end
263272
264273function setnameddimsindices (a:: AbstractNamedDimsArray , nameddimsindices)
265- return nameddimsarraytype (a )(dename (a), nameddimsindices)
274+ return constructorof ( typeof (a) )(dename (a), nameddimsindices)
266275end
267276function replacenameddimsindices (f, a:: AbstractNamedDimsArray )
268277 return setnameddimsindices (a, replace (f, nameddimsindices (a)))
@@ -419,10 +428,18 @@ function Base.setindex!(a::AbstractNamedDimsArray, value, I::CartesianIndex)
419428 setindex! (a, value, to_indices (a, (I,))... )
420429 return a
421430end
431+
432+ function flatten_namedinteger (i:: AbstractNamedInteger )
433+ if name (i) isa Union{AbstractNamedUnitRange,AbstractNamedArray}
434+ return name (i)[dename (i)]
435+ end
436+ return i
437+ end
438+
422439function Base. setindex! (
423440 a:: AbstractNamedDimsArray , value, I1:: AbstractNamedInteger , Irest:: AbstractNamedInteger...
424441)
425- I = ( I1, Irest... )
442+ I = flatten_namedinteger .(( I1, Irest... ) )
426443 # TODO : Check if this permuation should be inverted.
427444 perm = getperm (name .(nameddimsindices (a)), name .(I))
428445 # TODO : Throw a `NameMismatch` error.
@@ -510,7 +527,9 @@ function Base.view(a::AbstractNamedDimsArray, I1::NamedViewIndex, Irest::NamedVi
510527 subinds = map (nameddimsindices (a), I) do dimname, i
511528 return checked_indexin (dename (i), dename (dimname))
512529 end
513- return nameddims (view (dename (a), subinds... ), sub_nameddimsindices)
530+ return constructorof_nameddims (typeof (a))(
531+ view (dename (a), subinds... ), sub_nameddimsindices
532+ )
514533end
515534
516535function Base. getindex (
@@ -522,22 +541,22 @@ end
522541# Repeated definition of `Base.ViewIndex`.
523542const ViewIndex = Union{Real,AbstractArray}
524543
525- function nameddims_view (a:: AbstractArray , I... )
544+ function view_nameddims (a:: AbstractArray , I... )
526545 sub_dims = filter (dim -> ! (I[dim] isa Real), ntuple (identity, ndims (a)))
527546 sub_nameddimsindices = map (dim -> nameddimsindices (a, dim)[I[dim]], sub_dims)
528- return nameddims (view (dename (a), I... ), sub_nameddimsindices)
547+ return constructorof ( typeof (a)) (view (dename (a), I... ), sub_nameddimsindices)
529548end
530549
531550function Base. view (a:: AbstractNamedDimsArray , I:: ViewIndex... )
532- return nameddims_view (a, I... )
551+ return view_nameddims (a, I... )
533552end
534553
535- function nameddims_getindex (a:: AbstractArray , I... )
554+ function getindex_nameddims (a:: AbstractArray , I... )
536555 return copy (view (a, I... ))
537556end
538557
539558function Base. getindex (a:: AbstractNamedDimsArray , I:: ViewIndex... )
540- return nameddims_getindex (a, I... )
559+ return getindex_nameddims (a, I... )
541560end
542561
543562function Base. setindex! (
@@ -556,7 +575,7 @@ function Base.setindex!(
556575 Irest:: NamedViewIndex... ,
557576)
558577 I = (I1, Irest... )
559- setindex! (a, nameddimsarraytype (a )(value, I), I... )
578+ setindex! (a, constructorof ( typeof (a) )(value, I), I... )
560579 return a
561580end
562581function Base. setindex! (
@@ -580,13 +599,13 @@ end
580599function aligndims (a:: AbstractArray , dims)
581600 new_nameddimsindices = to_nameddimsindices (a, dims)
582601 # TODO : Check this permutation is correct (it may be the inverse of what we want).
583- perm = getperm (nameddimsindices (a), new_nameddimsindices)
602+ perm = Tuple ( getperm (nameddimsindices (a), new_nameddimsindices) )
584603 isperm (perm) || throw (
585604 NameMismatch (
586605 " Dimension name mismatch $(nameddimsindices (a)) , $(new_nameddimsindices) ."
587606 ),
588607 )
589- return nameddimsarraytype (a )(permutedims (dename (a), perm), new_nameddimsindices)
608+ return constructorof ( typeof (a) )(permutedims (dename (a), perm), new_nameddimsindices)
590609end
591610
592611function aligneddims (a:: AbstractArray , dims)
@@ -598,7 +617,9 @@ function aligneddims(a::AbstractArray, dims)
598617 " Dimension name mismatch $(nameddimsindices (a)) , $(new_nameddimsindices) ."
599618 ),
600619 )
601- return nameddimsarraytype (a)(PermutedDimsArray (dename (a), perm), new_nameddimsindices)
620+ return constructorof_nameddims (typeof (a))(
621+ PermutedDimsArray (dename (a), perm), new_nameddimsindices
622+ )
602623end
603624
604625# Convenient constructors
@@ -711,16 +732,17 @@ using Base.Broadcast:
711732 broadcasted,
712733 check_broadcast_shape,
713734 combine_axes
714- using MapBroadcast: Mapped, mapped
735+ using MapBroadcast: MapBroadcast, Mapped, mapped, tile
715736
716737abstract type AbstractNamedDimsArrayStyle{N} <: AbstractArrayStyle{N} end
717738
718- struct NamedDimsArrayStyle{N} <: AbstractNamedDimsArrayStyle{N} end
719- NamedDimsArrayStyle (:: Val{N} ) where {N} = NamedDimsArrayStyle {N} ()
720- NamedDimsArrayStyle {M} (:: Val{N} ) where {M,N} = NamedDimsArrayStyle {N} ()
739+ struct NamedDimsArrayStyle{N,NDA} <: AbstractNamedDimsArrayStyle{N} end
740+ NamedDimsArrayStyle (:: Val{N} ) where {N} = NamedDimsArrayStyle {N,NamedDimsArray} ()
741+ NamedDimsArrayStyle {M} (:: Val{N} ) where {M,N} = NamedDimsArrayStyle {N,NamedDimsArray} ()
742+ NamedDimsArrayStyle {M,NDA} (:: Val{N} ) where {M,N,NDA} = NamedDimsArrayStyle {N,NDA} ()
721743
722744function Broadcast. BroadcastStyle (arraytype:: Type{<:AbstractNamedDimsArray} )
723- return NamedDimsArrayStyle {ndims(arraytype)} ()
745+ return NamedDimsArrayStyle {ndims(arraytype),constructorof(arraytype) } ()
724746end
725747
726748function Broadcast. combine_axes (
@@ -762,6 +784,24 @@ function set_promote_shape(
762784 return named .(ax_promoted, name .(ax1))
763785end
764786
787+ # Handle operations like `ITensor() + ITensor(i, j)`.
788+ # TODO : Decide if this should be a general definition for `AbstractNamedDimsArray`,
789+ # or just for `AbstractITensor`.
790+ function set_promote_shape (
791+ ax1:: Tuple{} , ax2:: Tuple{AbstractNamedUnitRange,Vararg{AbstractNamedUnitRange}}
792+ )
793+ return ax2
794+ end
795+
796+ # Handle operations like `ITensor(i, j) + ITensor()`.
797+ # TODO : Decide if this should be a general definition for `AbstractNamedDimsArray`,
798+ # or just for `AbstractITensor`.
799+ function set_promote_shape (
800+ ax1:: Tuple{AbstractNamedUnitRange,Vararg{AbstractNamedUnitRange}} , ax2:: Tuple{}
801+ )
802+ return ax1
803+ end
804+
765805function Broadcast. check_broadcast_shape (ax1:: NaiveOrderedSet , ax2:: NaiveOrderedSet )
766806 return set_check_broadcast_shape (Tuple (ax1), Tuple (ax2))
767807end
@@ -775,6 +815,7 @@ function set_check_broadcast_shape(
775815 check_broadcast_shape (dename .(ax1), dename .(ax2_aligned))
776816 return nothing
777817end
818+ set_check_broadcast_shape (ax1:: Tuple{} , ax2:: Tuple{} ) = nothing
778819
779820# Dename and lazily permute the arguments using the reference
780821# dimension names.
@@ -783,19 +824,33 @@ function denamed(m::Mapped, nameddimsindices)
783824 return mapped (m. f, map (arg -> denamed (arg, nameddimsindices), m. args)... )
784825end
785826
827+ function nameddimsarraytype (style:: NamedDimsArrayStyle{<:Any,NDA} ) where {NDA}
828+ return NDA
829+ end
830+
831+ using FillArrays: Fill
832+
833+ function MapBroadcast. tile (a:: AbstractNamedDimsArray , ax)
834+ axes (a) == ax && return a
835+ if iszero (ndims (a))
836+ return constructorof (typeof (a))(Fill (a[], dename .(Tuple (ax))), name .(ax))
837+ end
838+ return error (" Not implemented." )
839+ end
840+
786841function Base. similar (bc:: Broadcasted{<:AbstractNamedDimsArrayStyle} , elt:: Type , ax)
787842 nameddimsindices = name .(ax)
788843 m′ = denamed (Mapped (bc), nameddimsindices)
789844 # TODO : Store the wrapper type in `AbstractNamedDimsArrayStyle` and use that
790845 # wrapper type rather than the generic `nameddims` constructor, which
791846 # can lose information.
792847 # Call it as `nameddimsarraytype(bc.style)`.
793- return nameddims (similar (m′, elt, dename .(Tuple (ax))), nameddimsindices)
848+ return nameddimsarraytype (bc. style)(
849+ similar (m′, elt, dename .(Tuple (ax))), nameddimsindices
850+ )
794851end
795852
796- function Base. copyto! (
797- dest:: AbstractArray{<:Any,N} , bc:: Broadcasted{<:AbstractNamedDimsArrayStyle{N}}
798- ) where {N}
853+ function Base. copyto! (dest:: AbstractArray , bc:: Broadcasted{<:AbstractNamedDimsArrayStyle} )
799854 return copyto! (dest, Mapped (bc))
800855end
801856
0 commit comments