@@ -24,17 +24,17 @@ arg2(a::KroneckerArray) = a.b
2424
2525using Adapt: Adapt, adapt
2626_adapt (to, a:: AbstractArray ) = adapt (to, a)
27- Adapt. adapt_structure (to, a:: KroneckerArray ) = _adapt (to, a . a) ⊗ _adapt (to, a . b )
27+ Adapt. adapt_structure (to, a:: KroneckerArray ) = _adapt (to, arg1 (a)) ⊗ _adapt (to, arg2 (a) )
2828
2929# Allows extra customization, like for `FillArrays.Eye`.
3030_copy (a:: AbstractArray ) = copy (a)
3131
3232function Base. copy (a:: KroneckerArray )
33- return _copy (a . a) ⊗ _copy (a . b )
33+ return _copy (arg1 (a)) ⊗ _copy (arg2 (a) )
3434end
3535function Base. copyto! (dest:: KroneckerArray , src:: KroneckerArray )
36- copyto! (dest. a, src. a )
37- copyto! (dest. b, src. b )
36+ copyto! (arg1 ( dest), arg1 ( src) )
37+ copyto! (arg2 ( dest), arg2 ( src) )
3838 return dest
3939end
4040
@@ -53,8 +53,7 @@ function Base.similar(
5353 CartesianProductUnitRange{<: Integer },Vararg{CartesianProductUnitRange{<: Integer }}
5454 },
5555)
56- return _similar (a, elt, map (ax -> ax. product. a, axs)) ⊗
57- _similar (a, elt, map (ax -> ax. product. b, axs))
56+ return _similar (a, elt, map (arg1, axs)) ⊗ _similar (a, elt, map (arg2, axs))
5857end
5958function Base. similar (
6059 a:: KroneckerArray ,
@@ -63,26 +62,23 @@ function Base.similar(
6362 CartesianProductUnitRange{<: Integer },Vararg{CartesianProductUnitRange{<: Integer }}
6463 },
6564)
66- return _similar (a. a, elt, map (ax -> ax. product. a, axs)) ⊗
67- _similar (a. b, elt, map (ax -> ax. product. b, axs))
65+ return _similar (arg1 (a), elt, map (arg1, axs)) ⊗ _similar (arg2 (a), elt, map (arg2, axs))
6866end
6967function Base. similar (
7068 arrayt:: Type{<:AbstractArray} ,
7169 axs:: Tuple {
7270 CartesianProductUnitRange{<: Integer },Vararg{CartesianProductUnitRange{<: Integer }}
7371 },
7472)
75- return _similar (arrayt, map (ax -> ax. product. a, axs)) ⊗
76- _similar (arrayt, map (ax -> ax. product. b, axs))
73+ return _similar (arrayt, map (arg1, axs)) ⊗ _similar (arrayt, map (arg2, axs))
7774end
7875function Base. similar (
7976 arrayt:: Type{<:KroneckerArray{<:Any,<:Any,A,B}} ,
8077 axs:: Tuple {
8178 CartesianProductUnitRange{<: Integer },Vararg{CartesianProductUnitRange{<: Integer }}
8279 },
8380) where {A,B}
84- return _similar (A, map (ax -> ax. product. a, axs)) ⊗
85- _similar (B, map (ax -> ax. product. b, axs))
81+ return _similar (A, map (arg1, axs)) ⊗ _similar (B, map (arg2, axs))
8682end
8783function Base. similar (
8884 :: Type{<:KroneckerArray{<:Any,<:Any,A,B}} , sz:: Tuple{Int,Vararg{Int}}
@@ -115,39 +111,41 @@ kron_nd(a::AbstractMatrix, b::AbstractMatrix) = kron(a, b)
115111kron_nd (a:: AbstractVector , b:: AbstractVector ) = kron (a, b)
116112
117113# Eagerly collect arguments to make more general on GPU.
118- Base. collect (a:: KroneckerArray ) = kron_nd (collect (a . a) , collect (a . b ))
114+ Base. collect (a:: KroneckerArray ) = kron_nd (collect (arg1 (a)) , collect (arg2 (a) ))
119115
120116Base. zero (a:: KroneckerArray ) = zero (arg1 (a)) ⊗ zero (arg2 (a))
121117
122118function Base. Array {T,N} (a:: KroneckerArray{S,N} ) where {T,S,N}
123119 return convert (Array{T,N}, collect (a))
124120end
125121
126- Base. size (a:: KroneckerArray ) = ntuple (dim -> size (a. a, dim) * size (a. b, dim), ndims (a))
122+ function Base. size (a:: KroneckerArray )
123+ return ntuple (dim -> size (arg1 (a), dim) * size (arg2 (a), dim), ndims (a))
124+ end
127125
128126function Base. axes (a:: KroneckerArray )
129127 return ntuple (ndims (a)) do dim
130128 return CartesianProductUnitRange (
131- axes (a . a , dim) × axes (a . b , dim), Base. OneTo (size (a, dim))
129+ axes (arg1 (a) , dim) × axes (arg2 (a) , dim), Base. OneTo (size (a, dim))
132130 )
133131 end
134132end
135133
136- arguments (a:: KroneckerArray ) = (a . a, a . b )
134+ arguments (a:: KroneckerArray ) = (arg1 (a), arg2 (a) )
137135arguments (a:: KroneckerArray , n:: Int ) = arguments (a)[n]
138136argument_types (a:: KroneckerArray ) = argument_types (typeof (a))
139137argument_types (:: Type{<:KroneckerArray{<:Any,<:Any,A,B}} ) where {A,B} = (A, B)
140138
141139function Base. print_array (io:: IO , a:: KroneckerArray )
142- Base. print_array (io, a . a )
140+ Base. print_array (io, arg1 (a) )
143141 println (io, " \n ⊗" )
144- Base. print_array (io, a . b )
142+ Base. print_array (io, arg2 (a) )
145143 return nothing
146144end
147145function Base. show (io:: IO , a:: KroneckerArray )
148- show (io, a . a )
146+ show (io, arg1 (a) )
149147 print (io, " ⊗ " )
150- show (io, a . b )
148+ show (io, arg2 (a) )
151149 return nothing
152150end
153151
@@ -172,14 +170,14 @@ function Base.getindex(a::KroneckerMatrix, i1::Integer, i2::Integer)
172170 GPUArraysCore. assertscalar (" getindex" )
173171 # Code logic from Kronecker.jl:
174172 # https://github.com/MichielStock/Kronecker.jl/blob/v0.5.5/src/base.jl#L101-L105
175- k, l = size (a . b )
176- return a . a [cld (i1, k), cld (i2, l)] * a . b [(i1 - 1 ) % k + 1 , (i2 - 1 ) % l + 1 ]
173+ k, l = size (arg2 (a) )
174+ return arg1 (a) [cld (i1, k), cld (i2, l)] * arg2 (a) [(i1 - 1 ) % k + 1 , (i2 - 1 ) % l + 1 ]
177175end
178176
179177function Base. getindex (a:: KroneckerVector , i:: Integer )
180178 GPUArraysCore. assertscalar (" getindex" )
181- k = length (a . b )
182- return a . a [cld (i, k)] * a . b [(i - 1 ) % k + 1 ]
179+ k = length (arg2 (a) )
180+ return arg1 (a) [cld (i, k)] * arg2 (a) [(i - 1 ) % k + 1 ]
183181end
184182
185183# Allow customizing for `FillArrays.Eye`.
@@ -191,49 +189,49 @@ function Base.getindex(a::KroneckerArray{<:Any,N}, I::Vararg{CartesianPair,N}) w
191189 return _getindex (arg1 (a), arg1 .(I)... ) ⊗ _getindex (arg2 (a), arg2 .(I)... )
192190end
193191# Fix ambigiuity error.
194- Base. getindex (a:: KroneckerArray{<:Any,0} ) = a . a [] * a . b []
192+ Base. getindex (a:: KroneckerArray{<:Any,0} ) = arg1 (a) [] * arg2 (a) []
195193
196194function Base.:(== )(a:: KroneckerArray , b:: KroneckerArray )
197- return a . a == b . a && a . b == b . b
195+ return arg1 (a) == arg1 (b) && arg2 (a) == arg2 (b)
198196end
199197function Base. isapprox (a:: KroneckerArray , b:: KroneckerArray ; kwargs... )
200- return isapprox (a . a, b . a ; kwargs... ) && isapprox (a . b, b . b ; kwargs... )
198+ return isapprox (arg1 (a), arg1 (b) ; kwargs... ) && isapprox (arg2 (a), arg2 (b) ; kwargs... )
201199end
202200function Base. iszero (a:: KroneckerArray )
203- return iszero (a . a) || iszero (a . b )
201+ return iszero (arg1 (a)) || iszero (arg2 (a) )
204202end
205203function Base. isreal (a:: KroneckerArray )
206- return isreal (a . a) && isreal (a . b )
204+ return isreal (arg1 (a)) && isreal (arg2 (a) )
207205end
208206
209207using DiagonalArrays: DiagonalArrays, diagonal
210208function DiagonalArrays. diagonal (a:: KroneckerArray )
211- return diagonal (a . a) ⊗ diagonal (a . b )
209+ return diagonal (arg1 (a)) ⊗ diagonal (arg2 (a) )
212210end
213211
214212Base. real (a:: KroneckerArray{<:Real} ) = a
215213function Base. real (a:: KroneckerArray )
216- if iszero (imag (a . a)) || iszero (imag (a . b ))
217- return real (a . a) ⊗ real (a . b )
218- elseif iszero (real (a . a)) || iszero (real (a . b ))
219- return - imag (a . a) ⊗ imag (a . b )
214+ if iszero (imag (arg1 (a))) || iszero (imag (arg2 (a) ))
215+ return real (arg1 (a)) ⊗ real (arg2 (a) )
216+ elseif iszero (real (arg1 (a))) || iszero (real (arg2 (a) ))
217+ return - imag (arg1 (a)) ⊗ imag (arg2 (a) )
220218 end
221- return real (a . a) ⊗ real (a . b) - imag (a . a) ⊗ imag (a . b )
219+ return real (arg1 (a)) ⊗ real (arg2 (a)) - imag (arg1 (a)) ⊗ imag (arg2 (a) )
222220end
223221Base. imag (a:: KroneckerArray{<:Real} ) = zero (a)
224222function Base. imag (a:: KroneckerArray )
225- if iszero (imag (a . a)) || iszero (real (a . b ))
226- return real (a . a) ⊗ imag (a . b )
227- elseif iszero (real (a . a)) || iszero (imag (a . b ))
228- return imag (a . a) ⊗ real (a . b )
223+ if iszero (imag (arg1 (a))) || iszero (real (arg2 (a) ))
224+ return real (arg1 (a)) ⊗ imag (arg2 (a) )
225+ elseif iszero (real (arg1 (a))) || iszero (imag (arg2 (a) ))
226+ return imag (arg1 (a)) ⊗ real (arg2 (a) )
229227 end
230- return real (a . a) ⊗ imag (a . b) + imag (a . a) ⊗ real (a . b )
228+ return real (arg1 (a)) ⊗ imag (arg2 (a)) + imag (arg1 (a)) ⊗ real (arg2 (a) )
231229end
232230
233231for f in [:transpose , :adjoint , :inv ]
234232 @eval begin
235233 function Base. $f (a:: KroneckerArray )
236- return $ f (a . a) ⊗ $ f (a . b )
234+ return $ f (arg1 (a)) ⊗ $ f (arg2 (a) )
237235 end
238236 end
239237end
0 commit comments