Skip to content

Commit acd2746

Browse files
committed
Expect array of arrays for now
1 parent 2150d93 commit acd2746

File tree

1 file changed

+7
-16
lines changed

1 file changed

+7
-16
lines changed

NDTensors/src/lib/NestedPermutedDimsArrays/src/NestedPermutedDimsArrays.jl

Lines changed: 7 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -52,25 +52,16 @@ Base.@constprop :aggressive function NestedPermutedDimsArray(
5252
throw(ArgumentError(string(perm, " is not a valid permutation of dimensions 1:", N)))
5353
iperm = invperm(perm)
5454
return NestedPermutedDimsArray{
55-
maybe_permuteddimsarraytype(T, perm),N,(perm...,),(iperm...,),typeof(data)
55+
PermutedDimsArray{eltype(T),N,(perm...,),(iperm...,),T},
56+
N,
57+
(perm...,),
58+
(iperm...,),
59+
typeof(data),
5660
}(
5761
data
5862
)
5963
end
6064

61-
# Ideally would use `Base.promote_op(maybe_permuteddimsarraytype, type, perm)`
62-
# but it doesn't handle `perm` properly.
63-
function maybe_permuteddimsarraytype(type::Type{<:AbstractArray}, perm)
64-
return PermutedDimsArray{eltype(type),ndims(type),perm,invperm(perm),type}
65-
end
66-
maybe_permuteddimsarraytype(type::Type, perm) = type
67-
68-
function maybe_permuteddimsarray(A::AbstractArray, perm)
69-
return PermutedDimsArray(A, perm)
70-
end
71-
# By default, assume scalar and don't permute.
72-
maybe_permuteddimsarray(x, perm) = x
73-
7465
Base.parent(A::NestedPermutedDimsArray) = A.parent
7566
function Base.size(A::NestedPermutedDimsArray{T,N,perm}) where {T,N,perm}
7667
return genperm(size(parent(A)), perm)
@@ -109,14 +100,14 @@ end
109100
A::NestedPermutedDimsArray{T,N,perm,iperm}, I::Vararg{Int,N}
110101
) where {T,N,perm,iperm}
111102
@boundscheck checkbounds(A, I...)
112-
@inbounds val = maybe_permuteddimsarray(getindex(A.parent, genperm(I, iperm)...), perm)
103+
@inbounds val = PermutedDimsArray(getindex(A.parent, genperm(I, iperm)...), perm)
113104
return val
114105
end
115106
@inline function Base.setindex!(
116107
A::NestedPermutedDimsArray{T,N,perm,iperm}, val, I::Vararg{Int,N}
117108
) where {T,N,perm,iperm}
118109
@boundscheck checkbounds(A, I...)
119-
@inbounds setindex!(A.parent, maybe_permuteddimsarray(val, perm), genperm(I, iperm)...)
110+
@inbounds setindex!(A.parent, PermutedDimsArray(val, perm), genperm(I, iperm)...)
120111
return val
121112
end
122113

0 commit comments

Comments
 (0)