11using Derive: Derive, @derive , AbstractArrayInterface
2+ using TypeParameterAccessors: unspecify_type_parameters
23
34# Some of the interface is inspired by:
45# https://github.com/ITensor/ITensors.jl
@@ -138,7 +139,9 @@ function checked_indexin(x::AbstractUnitRange, y::AbstractUnitRange)
138139 return findfirst (== (first (x)), y): findfirst (== (last (x)), y)
139140end
140141
141- Base. copy (a:: AbstractNamedDimsArray ) = nameddims (copy (dename (a)), nameddimsindices (a))
142+ function Base. copy (a:: AbstractNamedDimsArray )
143+ return nameddimsarraytype (a)(copy (dename (a)), nameddimsindices (a))
144+ end
142145
143146const NamedDimsIndices = Union{
144147 AbstractNamedUnitRange{<: Integer },AbstractNamedArray{<: Integer }
@@ -174,28 +177,30 @@ using Base.Broadcast: Broadcasted, Style
174177struct NaiveOrderedSet{Values}
175178 values:: Values
176179end
177- Base. Tuple (s:: NaiveOrderedSet ) = s. values
178- Base. length (s:: NaiveOrderedSet ) = length (Tuple (s))
179- Base. axes (s:: NaiveOrderedSet ) = axes (Tuple (s))
180- Base.:(== )(s1:: NaiveOrderedSet , s2:: NaiveOrderedSet ) = issetequal (Tuple (s1), Tuple (s2))
181- Base. iterate (s:: NaiveOrderedSet , args... ) = iterate (Tuple (s), args... )
182- Base. getindex (s:: NaiveOrderedSet , I:: Int ) = Tuple (s)[I]
183- Base. invperm (s:: NaiveOrderedSet ) = NaiveOrderedSet (invperm (Tuple (s)))
180+ Base. values (s:: NaiveOrderedSet ) = s. values
181+ Base. Tuple (s:: NaiveOrderedSet ) = Tuple (values (s))
182+ Base. length (s:: NaiveOrderedSet ) = length (values (s))
183+ Base. axes (s:: NaiveOrderedSet ) = axes (values (s))
184+ Base.:(== )(s1:: NaiveOrderedSet , s2:: NaiveOrderedSet ) = issetequal (values (s1), values (s2))
185+ Base. iterate (s:: NaiveOrderedSet , args... ) = iterate (values (s), args... )
186+ Base. getindex (s:: NaiveOrderedSet , I:: Int ) = values (s)[I]
187+ Base. invperm (s:: NaiveOrderedSet ) = NaiveOrderedSet (invperm (values (s)))
184188Base. Broadcast. _axes (:: Broadcasted , axes:: NaiveOrderedSet ) = axes
185189Base. Broadcast. BroadcastStyle (:: Type{<:NaiveOrderedSet} ) = Style {NaiveOrderedSet} ()
186190Base. Broadcast. broadcastable (s:: NaiveOrderedSet ) = s
191+ Base. to_shape (s:: NaiveOrderedSet ) = s
187192
188193function Base. copy (
189194 bc:: Broadcasted{Style{NaiveOrderedSet},<:Any,<:Any,<:Tuple{<:NaiveOrderedSet}}
190195)
191- return NaiveOrderedSet (bc. f .(Tuple (only (bc. args))))
196+ return NaiveOrderedSet (bc. f .(values (only (bc. args))))
192197end
193198# Multiple arguments not supported.
194199function Base. copy (bc:: Broadcasted{Style{NaiveOrderedSet}} )
195200 return error (" This broadcasting expression of `NaiveOrderedSet` is not supported." )
196201end
197202function Base. map (f:: Function , s:: NaiveOrderedSet )
198- return NaiveOrderedSet (map (f, Tuple (s)))
203+ return NaiveOrderedSet (map (f, values (s)))
199204end
200205
201206function Base. axes (a:: AbstractNamedDimsArray )
@@ -228,15 +233,36 @@ to_nameddimsaxes(dims) = map(to_nameddimsaxis, dims)
228233to_nameddimsaxis (ax:: NamedDimsAxis ) = ax
229234to_nameddimsaxis (I:: NamedDimsIndices ) = named (dename (only (axes (I))), I)
230235
236+ nameddimsarraytype (a:: AbstractNamedDimsArray ) = nameddimsarraytype (typeof (a))
237+ nameddimsarraytype (a:: Type{<:AbstractNamedDimsArray} ) = unspecify_type_parameters (a)
238+
239+ function similar_nameddims (a:: AbstractNamedDimsArray , elt:: Type , inds)
240+ ax = to_nameddimsaxes (inds)
241+ return nameddimsarraytype (a)(similar (dename (a), elt, dename .(Tuple (ax))), name .(ax))
242+ end
243+
244+ function similar_nameddims (a:: AbstractArray , elt:: Type , inds)
245+ ax = to_nameddimsaxes (inds)
246+ return nameddims (similar (a, elt, dename .(Tuple (ax))), name .(ax))
247+ end
248+
249+ # Base.similar gets the eltype at compile time.
250+ function Base. similar (a:: AbstractNamedDimsArray )
251+ return similar (a, eltype (a))
252+ end
253+
231254function Base. similar (
232255 a:: AbstractArray , elt:: Type , inds:: Tuple{NamedDimsIndices,Vararg{NamedDimsIndices}}
233256)
234- ax = to_nameddimsaxes (inds)
235- return nameddims (similar (unname (a), elt, dename .(ax)), name .(ax))
257+ return similar_nameddims (a, elt, inds)
258+ end
259+
260+ function Base. similar (a:: AbstractArray , elt:: Type , inds:: NaiveOrderedSet )
261+ return similar_nameddims (a, elt, inds)
236262end
237263
238264function setnameddimsindices (a:: AbstractNamedDimsArray , nameddimsindices)
239- return nameddims (dename (a), nameddimsindices)
265+ return nameddimsarraytype (a) (dename (a), nameddimsindices)
240266end
241267function replacenameddimsindices (f, a:: AbstractNamedDimsArray )
242268 return setnameddimsindices (a, replace (f, nameddimsindices (a)))
@@ -530,7 +556,7 @@ function Base.setindex!(
530556 Irest:: NamedViewIndex... ,
531557)
532558 I = (I1, Irest... )
533- setindex! (a, nameddims (value, I), I... )
559+ setindex! (a, nameddimsarraytype (a) (value, I), I... )
534560 return a
535561end
536562function Base. setindex! (
@@ -560,7 +586,7 @@ function aligndims(a::AbstractArray, dims)
560586 " Dimension name mismatch $(nameddimsindices (a)) , $(new_nameddimsindices) ."
561587 ),
562588 )
563- return nameddims (permutedims (dename (a), perm), new_nameddimsindices)
589+ return nameddimsarraytype (a) (permutedims (dename (a), perm), new_nameddimsindices)
564590end
565591
566592function aligneddims (a:: AbstractArray , dims)
@@ -572,7 +598,7 @@ function aligneddims(a::AbstractArray, dims)
572598 " Dimension name mismatch $(nameddimsindices (a)) , $(new_nameddimsindices) ."
573599 ),
574600 )
575- return nameddims (PermutedDimsArray (dename (a), perm), new_nameddimsindices)
601+ return nameddimsarraytype (a) (PermutedDimsArray (dename (a), perm), new_nameddimsindices)
576602end
577603
578604# Convenient constructors
634660for f in [:zeros , :ones ]
635661 @eval begin
636662 function Base. $f (
637- elt:: Type{<:Number} , ax :: Tuple{NamedDimsIndices,Vararg{NamedDimsIndices}}
663+ elt:: Type{<:Number} , inds :: Tuple{NamedDimsIndices,Vararg{NamedDimsIndices}}
638664 )
639665 ax = to_nameddimsaxes (inds)
640666 a = $ f (elt, dename .(ax))
760786function Base. similar (bc:: Broadcasted{<:AbstractNamedDimsArrayStyle} , elt:: Type , ax)
761787 nameddimsindices = name .(ax)
762788 m′ = denamed (Mapped (bc), nameddimsindices)
789+ # TODO : Store the wrapper type in `AbstractNamedDimsArrayStyle` and use that
790+ # wrapper type rather than the generic `nameddims` constructor, which
791+ # can lose information.
792+ # Call it as `nameddimsarraytype(bc.style)`.
763793 return nameddims (similar (m′, elt, dename .(Tuple (ax))), nameddimsindices)
764794end
765795
0 commit comments