@@ -118,7 +118,29 @@ function Base.similar(
118118 return similar (promote_type (A, B), sz)
119119end
120120
121- Base. collect (a:: KroneckerArray ) = kron (a. a, a. b)
121+ function flatten (t:: Tuple{Tuple,Tuple,Vararg{Tuple}} )
122+ return (t[1 ]. .. , flatten (Base. tail (t))... )
123+ end
124+ function flatten (t:: Tuple{Tuple} )
125+ return t[1 ]
126+ end
127+ flatten (:: Tuple{} ) = ()
128+ function interleave (x:: Tuple , y:: Tuple )
129+ length (x) == length (y) || throw (ArgumentError (" Tuples must have the same length." ))
130+ xy = ntuple (i -> (x[i], y[i]), length (x))
131+ return flatten (xy)
132+ end
133+ function kron_nd (a:: AbstractArray{<:Any,N} , b:: AbstractArray{<:Any,N} ) where {N}
134+ a′ = reshape (a, interleave (size (a), ntuple (one, N)))
135+ b′ = reshape (b, interleave (ntuple (one, N), size (b)))
136+ c′ = permutedims (a′ .* b′, reverse (ntuple (identity, 2 N)))
137+ sz = ntuple (i -> size (a, i) * size (b, i), N)
138+ return permutedims (reshape (c′, sz), reverse (ntuple (identity, N)))
139+ end
140+ kron_nd (a:: AbstractMatrix , b:: AbstractMatrix ) = kron (a, b)
141+ kron_nd (a:: AbstractVector , b:: AbstractVector ) = kron (a, b)
142+
143+ Base. collect (a:: KroneckerArray ) = kron_nd (a. a, a. b)
122144
123145function Base. Array {T,N} (a:: KroneckerArray{S,N} ) where {T,S,N}
124146 return convert (Array{T,N}, collect (a))
@@ -150,10 +172,18 @@ function Base.show(io::IO, a::KroneckerArray)
150172 return nothing
151173end
152174
153- ⊗ (a:: AbstractVecOrMat , b:: AbstractVecOrMat ) = KroneckerArray (a, b)
175+ ⊗ (a:: AbstractArray , b:: AbstractArray ) = KroneckerArray (a, b)
154176⊗ (a:: Number , b:: Number ) = a * b
155- ⊗ (a:: Number , b:: AbstractVecOrMat ) = a * b
156- ⊗ (a:: AbstractVecOrMat , b:: Number ) = a * b
177+ ⊗ (a:: Number , b:: AbstractArray ) = a * b
178+ ⊗ (a:: AbstractArray , b:: Number ) = a * b
179+
180+ function Base. getindex (a:: KroneckerArray , i:: Integer )
181+ return a[CartesianIndices (a)[i]]
182+ end
183+
184+ function Base. getindex (a:: KroneckerArray{<:Any,N} , I:: Vararg{Integer,N} ) where {N}
185+ return error (" Not implemented." )
186+ end
157187
158188function Base. getindex (a:: KroneckerMatrix , i1:: Integer , i2:: Integer )
159189 GPUArraysCore. assertscalar (" getindex" )
@@ -162,22 +192,21 @@ function Base.getindex(a::KroneckerMatrix, i1::Integer, i2::Integer)
162192 k, l = size (a. b)
163193 return a. a[cld (i1, k), cld (i2, l)] * a. b[(i1 - 1 ) % k + 1 , (i2 - 1 ) % l + 1 ]
164194end
165- function Base. getindex (a:: KroneckerMatrix , i:: Integer )
166- return a[CartesianIndices (a)[i]]
167- end
168195
169196function Base. getindex (a:: KroneckerVector , i:: Integer )
170197 GPUArraysCore. assertscalar (" getindex" )
171198 k = length (a. b)
172199 return a. a[cld (i, k)] * a. b[(i - 1 ) % k + 1 ]
173200end
174201
175- function Base. getindex (a:: KroneckerVector , i:: CartesianProduct )
176- return a. a[i. a] ⊗ a. b[i. b]
177- end
178- function Base. getindex (a:: KroneckerMatrix , i :: CartesianProduct , j :: CartesianProduct )
179- return a. a[i . a, j . a ] ⊗ a. b[i . b, j . b ]
202+ # # function Base.getindex(a::KroneckerVector, i::CartesianProduct)
203+ # # return a.a[i.a] ⊗ a.b[i.b]
204+ # # end
205+ function Base. getindex (a:: KroneckerArray{<:Any,N} , I :: Vararg{ CartesianProduct,N} ) where {N}
206+ return a. a[map (Base . Fix2 (getfield, :a ), I) ... ] ⊗ a. b[map (Base . Fix2 (getfield, :b ), I) ... ]
180207end
208+ # Fix ambigiuity error.
209+ Base. getindex (a:: KroneckerArray{<:Any,0} ) = a. a[] * a. b[]
181210
182211function Base.:(== )(a:: KroneckerArray , b:: KroneckerArray )
183212 return a. a == b. a && a. b == b. b
@@ -220,7 +249,7 @@ using LinearAlgebra:
220249 svd,
221250 svdvals,
222251 tr
223- diagonal (a:: AbstractVecOrMat ) = Diagonal (a)
252+ diagonal (a:: AbstractArray ) = Diagonal (a)
224253function diagonal (a:: KroneckerArray )
225254 return Diagonal (a. a) ⊗ Diagonal (a. b)
226255end
@@ -275,10 +304,10 @@ end
275304function Base.:* (a:: KroneckerQ , b:: KroneckerQ )
276305 return (a. a * b. a) ⊗ (a. b * b. b)
277306end
278- function Base.:* (a:: KroneckerQ , b:: KroneckerMatrix )
307+ function Base.:* (a:: KroneckerQ , b:: KroneckerArray )
279308 return (a. a * b. a) ⊗ (a. b * b. b)
280309end
281- function Base.:* (a:: KroneckerMatrix , b:: KroneckerQ )
310+ function Base.:* (a:: KroneckerArray , b:: KroneckerQ )
282311 return (a. a * b. a) ⊗ (a. b * b. b)
283312end
284313function Base. adjoint (a:: KroneckerQ )
0 commit comments