1
1
using Derive: Derive, @derive , AbstractArrayInterface
2
+ using TypeParameterAccessors: unspecify_type_parameters
2
3
3
4
# Some of the interface is inspired by:
4
5
# https://github.com/ITensor/ITensors.jl
@@ -138,7 +139,9 @@ function checked_indexin(x::AbstractUnitRange, y::AbstractUnitRange)
138
139
return findfirst (== (first (x)), y): findfirst (== (last (x)), y)
139
140
end
140
141
141
- Base. copy (a:: AbstractNamedDimsArray ) = nameddims (copy (dename (a)), nameddimsindices (a))
142
+ function Base. copy (a:: AbstractNamedDimsArray )
143
+ return nameddimsarraytype (a)(copy (dename (a)), nameddimsindices (a))
144
+ end
142
145
143
146
const NamedDimsIndices = Union{
144
147
AbstractNamedUnitRange{<: Integer },AbstractNamedArray{<: Integer }
@@ -174,28 +177,30 @@ using Base.Broadcast: Broadcasted, Style
174
177
struct NaiveOrderedSet{Values}
175
178
values:: Values
176
179
end
177
- Base. Tuple (s:: NaiveOrderedSet ) = s. values
178
- Base. length (s:: NaiveOrderedSet ) = length (Tuple (s))
179
- Base. axes (s:: NaiveOrderedSet ) = axes (Tuple (s))
180
- Base.:(== )(s1:: NaiveOrderedSet , s2:: NaiveOrderedSet ) = issetequal (Tuple (s1), Tuple (s2))
181
- Base. iterate (s:: NaiveOrderedSet , args... ) = iterate (Tuple (s), args... )
182
- Base. getindex (s:: NaiveOrderedSet , I:: Int ) = Tuple (s)[I]
183
- Base. invperm (s:: NaiveOrderedSet ) = NaiveOrderedSet (invperm (Tuple (s)))
180
+ Base. values (s:: NaiveOrderedSet ) = s. values
181
+ Base. Tuple (s:: NaiveOrderedSet ) = Tuple (values (s))
182
+ Base. length (s:: NaiveOrderedSet ) = length (values (s))
183
+ Base. axes (s:: NaiveOrderedSet ) = axes (values (s))
184
+ Base.:(== )(s1:: NaiveOrderedSet , s2:: NaiveOrderedSet ) = issetequal (values (s1), values (s2))
185
+ Base. iterate (s:: NaiveOrderedSet , args... ) = iterate (values (s), args... )
186
+ Base. getindex (s:: NaiveOrderedSet , I:: Int ) = values (s)[I]
187
+ Base. invperm (s:: NaiveOrderedSet ) = NaiveOrderedSet (invperm (values (s)))
184
188
Base. Broadcast. _axes (:: Broadcasted , axes:: NaiveOrderedSet ) = axes
185
189
Base. Broadcast. BroadcastStyle (:: Type{<:NaiveOrderedSet} ) = Style {NaiveOrderedSet} ()
186
190
Base. Broadcast. broadcastable (s:: NaiveOrderedSet ) = s
191
+ Base. to_shape (s:: NaiveOrderedSet ) = s
187
192
188
193
function Base. copy (
189
194
bc:: Broadcasted{Style{NaiveOrderedSet},<:Any,<:Any,<:Tuple{<:NaiveOrderedSet}}
190
195
)
191
- return NaiveOrderedSet (bc. f .(Tuple (only (bc. args))))
196
+ return NaiveOrderedSet (bc. f .(values (only (bc. args))))
192
197
end
193
198
# Multiple arguments not supported.
194
199
function Base. copy (bc:: Broadcasted{Style{NaiveOrderedSet}} )
195
200
return error (" This broadcasting expression of `NaiveOrderedSet` is not supported." )
196
201
end
197
202
function Base. map (f:: Function , s:: NaiveOrderedSet )
198
- return NaiveOrderedSet (map (f, Tuple (s)))
203
+ return NaiveOrderedSet (map (f, values (s)))
199
204
end
200
205
201
206
function Base. axes (a:: AbstractNamedDimsArray )
@@ -228,15 +233,36 @@ to_nameddimsaxes(dims) = map(to_nameddimsaxis, dims)
228
233
to_nameddimsaxis (ax:: NamedDimsAxis ) = ax
229
234
to_nameddimsaxis (I:: NamedDimsIndices ) = named (dename (only (axes (I))), I)
230
235
236
+ nameddimsarraytype (a:: AbstractNamedDimsArray ) = nameddimsarraytype (typeof (a))
237
+ nameddimsarraytype (a:: Type{<:AbstractNamedDimsArray} ) = unspecify_type_parameters (a)
238
+
239
+ function similar_nameddims (a:: AbstractNamedDimsArray , elt:: Type , inds)
240
+ ax = to_nameddimsaxes (inds)
241
+ return nameddimsarraytype (a)(similar (dename (a), elt, dename .(Tuple (ax))), name .(ax))
242
+ end
243
+
244
+ function similar_nameddims (a:: AbstractArray , elt:: Type , inds)
245
+ ax = to_nameddimsaxes (inds)
246
+ return nameddims (similar (a, elt, dename .(Tuple (ax))), name .(ax))
247
+ end
248
+
249
+ # Base.similar gets the eltype at compile time.
250
+ function Base. similar (a:: AbstractNamedDimsArray )
251
+ return similar (a, eltype (a))
252
+ end
253
+
231
254
function Base. similar (
232
255
a:: AbstractArray , elt:: Type , inds:: Tuple{NamedDimsIndices,Vararg{NamedDimsIndices}}
233
256
)
234
- ax = to_nameddimsaxes (inds)
235
- return nameddims (similar (unname (a), elt, dename .(ax)), name .(ax))
257
+ return similar_nameddims (a, elt, inds)
258
+ end
259
+
260
+ function Base. similar (a:: AbstractArray , elt:: Type , inds:: NaiveOrderedSet )
261
+ return similar_nameddims (a, elt, inds)
236
262
end
237
263
238
264
function setnameddimsindices (a:: AbstractNamedDimsArray , nameddimsindices)
239
- return nameddims (dename (a), nameddimsindices)
265
+ return nameddimsarraytype (a) (dename (a), nameddimsindices)
240
266
end
241
267
function replacenameddimsindices (f, a:: AbstractNamedDimsArray )
242
268
return setnameddimsindices (a, replace (f, nameddimsindices (a)))
@@ -530,7 +556,7 @@ function Base.setindex!(
530
556
Irest:: NamedViewIndex... ,
531
557
)
532
558
I = (I1, Irest... )
533
- setindex! (a, nameddims (value, I), I... )
559
+ setindex! (a, nameddimsarraytype (a) (value, I), I... )
534
560
return a
535
561
end
536
562
function Base. setindex! (
@@ -560,7 +586,7 @@ function aligndims(a::AbstractArray, dims)
560
586
" Dimension name mismatch $(nameddimsindices (a)) , $(new_nameddimsindices) ."
561
587
),
562
588
)
563
- return nameddims (permutedims (dename (a), perm), new_nameddimsindices)
589
+ return nameddimsarraytype (a) (permutedims (dename (a), perm), new_nameddimsindices)
564
590
end
565
591
566
592
function aligneddims (a:: AbstractArray , dims)
@@ -572,7 +598,7 @@ function aligneddims(a::AbstractArray, dims)
572
598
" Dimension name mismatch $(nameddimsindices (a)) , $(new_nameddimsindices) ."
573
599
),
574
600
)
575
- return nameddims (PermutedDimsArray (dename (a), perm), new_nameddimsindices)
601
+ return nameddimsarraytype (a) (PermutedDimsArray (dename (a), perm), new_nameddimsindices)
576
602
end
577
603
578
604
# Convenient constructors
634
660
for f in [:zeros , :ones ]
635
661
@eval begin
636
662
function Base. $f (
637
- elt:: Type{<:Number} , ax :: Tuple{NamedDimsIndices,Vararg{NamedDimsIndices}}
663
+ elt:: Type{<:Number} , inds :: Tuple{NamedDimsIndices,Vararg{NamedDimsIndices}}
638
664
)
639
665
ax = to_nameddimsaxes (inds)
640
666
a = $ f (elt, dename .(ax))
760
786
function Base. similar (bc:: Broadcasted{<:AbstractNamedDimsArrayStyle} , elt:: Type , ax)
761
787
nameddimsindices = name .(ax)
762
788
m′ = denamed (Mapped (bc), nameddimsindices)
789
+ # TODO : Store the wrapper type in `AbstractNamedDimsArrayStyle` and use that
790
+ # wrapper type rather than the generic `nameddims` constructor, which
791
+ # can lose information.
792
+ # Call it as `nameddimsarraytype(bc.style)`.
763
793
return nameddims (similar (m′, elt, dename .(Tuple (ax))), nameddimsindices)
764
794
end
765
795
0 commit comments