@@ -118,7 +118,29 @@ function Base.similar(
118
118
return similar (promote_type (A, B), sz)
119
119
end
120
120
121
- Base. collect (a:: KroneckerArray ) = kron (a. a, a. b)
121
+ function flatten (t:: Tuple{Tuple,Tuple,Vararg{Tuple}} )
122
+ return (t[1 ]. .. , flatten (Base. tail (t))... )
123
+ end
124
+ function flatten (t:: Tuple{Tuple} )
125
+ return t[1 ]
126
+ end
127
+ flatten (:: Tuple{} ) = ()
128
+ function interleave (x:: Tuple , y:: Tuple )
129
+ length (x) == length (y) || throw (ArgumentError (" Tuples must have the same length." ))
130
+ xy = ntuple (i -> (x[i], y[i]), length (x))
131
+ return flatten (xy)
132
+ end
133
+ function kron_nd (a:: AbstractArray{<:Any,N} , b:: AbstractArray{<:Any,N} ) where {N}
134
+ a′ = reshape (a, interleave (size (a), ntuple (one, N)))
135
+ b′ = reshape (b, interleave (ntuple (one, N), size (b)))
136
+ c′ = permutedims (a′ .* b′, reverse (ntuple (identity, 2 N)))
137
+ sz = ntuple (i -> size (a, i) * size (b, i), N)
138
+ return permutedims (reshape (c′, sz), reverse (ntuple (identity, N)))
139
+ end
140
+ kron_nd (a:: AbstractMatrix , b:: AbstractMatrix ) = kron (a, b)
141
+ kron_nd (a:: AbstractVector , b:: AbstractVector ) = kron (a, b)
142
+
143
+ Base. collect (a:: KroneckerArray ) = kron_nd (a. a, a. b)
122
144
123
145
function Base. Array {T,N} (a:: KroneckerArray{S,N} ) where {T,S,N}
124
146
return convert (Array{T,N}, collect (a))
@@ -150,10 +172,18 @@ function Base.show(io::IO, a::KroneckerArray)
150
172
return nothing
151
173
end
152
174
153
- ⊗ (a:: AbstractVecOrMat , b:: AbstractVecOrMat ) = KroneckerArray (a, b)
175
+ ⊗ (a:: AbstractArray , b:: AbstractArray ) = KroneckerArray (a, b)
154
176
⊗ (a:: Number , b:: Number ) = a * b
155
- ⊗ (a:: Number , b:: AbstractVecOrMat ) = a * b
156
- ⊗ (a:: AbstractVecOrMat , b:: Number ) = a * b
177
+ ⊗ (a:: Number , b:: AbstractArray ) = a * b
178
+ ⊗ (a:: AbstractArray , b:: Number ) = a * b
179
+
180
+ function Base. getindex (a:: KroneckerArray , i:: Integer )
181
+ return a[CartesianIndices (a)[i]]
182
+ end
183
+
184
+ function Base. getindex (a:: KroneckerArray{<:Any,N} , I:: Vararg{Integer,N} ) where {N}
185
+ return error (" Not implemented." )
186
+ end
157
187
158
188
function Base. getindex (a:: KroneckerMatrix , i1:: Integer , i2:: Integer )
159
189
GPUArraysCore. assertscalar (" getindex" )
@@ -162,22 +192,21 @@ function Base.getindex(a::KroneckerMatrix, i1::Integer, i2::Integer)
162
192
k, l = size (a. b)
163
193
return a. a[cld (i1, k), cld (i2, l)] * a. b[(i1 - 1 ) % k + 1 , (i2 - 1 ) % l + 1 ]
164
194
end
165
- function Base. getindex (a:: KroneckerMatrix , i:: Integer )
166
- return a[CartesianIndices (a)[i]]
167
- end
168
195
169
196
function Base. getindex (a:: KroneckerVector , i:: Integer )
170
197
GPUArraysCore. assertscalar (" getindex" )
171
198
k = length (a. b)
172
199
return a. a[cld (i, k)] * a. b[(i - 1 ) % k + 1 ]
173
200
end
174
201
175
- function Base. getindex (a:: KroneckerVector , i:: CartesianProduct )
176
- return a. a[i. a] ⊗ a. b[i. b]
177
- end
178
- function Base. getindex (a:: KroneckerMatrix , i :: CartesianProduct , j :: CartesianProduct )
179
- return a. a[i . a, j . a ] ⊗ a. b[i . b, j . b ]
202
+ # # function Base.getindex(a::KroneckerVector, i::CartesianProduct)
203
+ # # return a.a[i.a] ⊗ a.b[i.b]
204
+ # # end
205
+ function Base. getindex (a:: KroneckerArray{<:Any,N} , I :: Vararg{ CartesianProduct,N} ) where {N}
206
+ return a. a[map (Base . Fix2 (getfield, :a ), I) ... ] ⊗ a. b[map (Base . Fix2 (getfield, :b ), I) ... ]
180
207
end
208
+ # Fix ambigiuity error.
209
+ Base. getindex (a:: KroneckerArray{<:Any,0} ) = a. a[] * a. b[]
181
210
182
211
function Base.:(== )(a:: KroneckerArray , b:: KroneckerArray )
183
212
return a. a == b. a && a. b == b. b
@@ -220,7 +249,7 @@ using LinearAlgebra:
220
249
svd,
221
250
svdvals,
222
251
tr
223
- diagonal (a:: AbstractVecOrMat ) = Diagonal (a)
252
+ diagonal (a:: AbstractArray ) = Diagonal (a)
224
253
function diagonal (a:: KroneckerArray )
225
254
return Diagonal (a. a) ⊗ Diagonal (a. b)
226
255
end
@@ -275,10 +304,10 @@ end
275
304
function Base.:* (a:: KroneckerQ , b:: KroneckerQ )
276
305
return (a. a * b. a) ⊗ (a. b * b. b)
277
306
end
278
- function Base.:* (a:: KroneckerQ , b:: KroneckerMatrix )
307
+ function Base.:* (a:: KroneckerQ , b:: KroneckerArray )
279
308
return (a. a * b. a) ⊗ (a. b * b. b)
280
309
end
281
- function Base.:* (a:: KroneckerMatrix , b:: KroneckerQ )
310
+ function Base.:* (a:: KroneckerArray , b:: KroneckerQ )
282
311
return (a. a * b. a) ⊗ (a. b * b. b)
283
312
end
284
313
function Base. adjoint (a:: KroneckerQ )
0 commit comments