@@ -31,7 +31,7 @@ function KroneckerArray(a::AbstractArray, b::AbstractArray)
3131 )
3232 end
3333 elt = promote_type (eltype (a), eltype (b))
34- return KroneckerArray ( _convert (AbstractArray{elt}, a), _convert (AbstractArray{elt}, b) )
34+ return _convert (AbstractArray{elt}, a) ⊗ _convert (AbstractArray{elt}, b)
3535end
3636const KroneckerMatrix{T,A<: AbstractMatrix{T} ,B<: AbstractMatrix{T} } = KroneckerArray{T,2 ,A,B}
3737const KroneckerVector{T,A<: AbstractVector{T} ,B<: AbstractVector{T} } = KroneckerArray{T,1 ,A,B}
@@ -70,63 +70,69 @@ function Base.convert(::Type{KroneckerArray{T,N,A,B}}, a::KroneckerArray) where
7070 return _convert (A, arg1 (a)) ⊗ _convert (B, arg2 (a))
7171end
7272
73- # Like `similar` but allows some custom behavior, such as for `FillArrays.Eye`.
74- function _similar (a:: AbstractArray , elt:: Type , axs:: Tuple )
75- return similar (a, elt, axs)
76- end
77- function _similar (a:: AbstractArray , ax:: Tuple )
78- return _similar (a, eltype (a), ax)
79- end
80- function _similar (a:: AbstractArray , elt:: Type )
81- return _similar (a, elt, axes (a))
82- end
83- function _similar (a:: AbstractArray )
84- return _similar (a, eltype (a), axes (a))
85- end
86- function _similar (arrayt:: Type{<:AbstractArray} , axs:: Tuple )
87- return similar (arrayt, axs)
88- end
89-
90- function Base. similar (
91- a:: AbstractArray ,
92- elt:: Type ,
93- axs:: Tuple {
94- CartesianProductUnitRange{<: Integer },Vararg{CartesianProductUnitRange{<: Integer }}
95- },
96- )
97- return _similar (a, elt, map (arg1, axs)) ⊗ _similar (a, elt, map (arg2, axs))
98- end
9973function Base. similar (
10074 a:: KroneckerArray ,
10175 elt:: Type ,
10276 axs:: Tuple {
10377 CartesianProductUnitRange{<: Integer },Vararg{CartesianProductUnitRange{<: Integer }}
10478 },
10579)
106- return _similar (arg1 (a), elt, map (arg1, axs)) ⊗ _similar (arg2 (a), elt, map (arg2, axs))
80+ return similar (arg1 (a), elt, map (arg1, axs)) ⊗ similar (arg2 (a), elt, map (arg2, axs))
81+ end
82+ function Base. similar (a:: KroneckerArray , elt:: Type )
83+ # TODO : Is this a good definition?
84+ return if isactive (arg1 (a)) == isactive (arg2 (a))
85+ similar (arg1 (a), elt) ⊗ similar (arg2 (a), elt)
86+ elseif isactive (arg1 (a))
87+ similar (arg1 (a), elt) ⊗ elt .(arg2 (a))
88+ elseif isactive (arg2 (a))
89+ elt .(arg1 (a)) ⊗ similar (arg2 (a), elt)
90+ end
91+ end
92+ function Base. similar (a:: KroneckerArray )
93+ # TODO : Is this a good definition?
94+ return if isactive (arg1 (a)) == isactive (arg2 (a))
95+ similar (arg1 (a)) ⊗ similar (arg2 (a))
96+ elseif isactive (arg1 (a))
97+ similar (arg1 (a)) ⊗ arg2 (a)
98+ elseif isactive (arg2 (a))
99+ arg1 (a) ⊗ similar (arg2 (a))
100+ end
107101end
102+
108103function Base. similar (
109- arrayt:: Type{<:AbstractArray} ,
104+ a:: AbstractArray ,
105+ elt:: Type ,
110106 axs:: Tuple {
111107 CartesianProductUnitRange{<: Integer },Vararg{CartesianProductUnitRange{<: Integer }}
112108 },
113109)
114- return _similar (arrayt, map (arg1, axs)) ⊗ _similar (arrayt , map (arg2, axs))
110+ return similar (a, elt, map (arg1, axs)) ⊗ similar (a, elt , map (arg2, axs))
115111end
112+
116113function Base. similar (
117114 arrayt:: Type{<:KroneckerArray{<:Any,<:Any,A,B}} ,
118115 axs:: Tuple {
119116 CartesianProductUnitRange{<: Integer },Vararg{CartesianProductUnitRange{<: Integer }}
120117 },
121118) where {A,B}
122- return _similar (A, map (arg1, axs)) ⊗ _similar (B, map (arg2, axs))
119+ return similar (A, map (arg1, axs)) ⊗ similar (B, map (arg2, axs))
123120end
124121function Base. similar (
125122 :: Type{<:KroneckerArray{<:Any,<:Any,A,B}} , sz:: Tuple{Int,Vararg{Int}}
126123) where {A,B}
127124 return similar (promote_type (A, B), sz)
128125end
129126
127+ function Base. similar (
128+ arrayt:: Type{<:AbstractArray} ,
129+ axs:: Tuple {
130+ CartesianProductUnitRange{<: Integer },Vararg{CartesianProductUnitRange{<: Integer }}
131+ },
132+ )
133+ return similar (arrayt, map (arg1, axs)) ⊗ similar (arrayt, map (arg2, axs))
134+ end
135+
130136function Base. permutedims (a:: KroneckerArray , perm)
131137 return permutedims (arg1 (a), perm) ⊗ permutedims (arg2 (a), perm)
132138end
@@ -168,7 +174,17 @@ kron_nd(a::AbstractVector, b::AbstractVector) = kron(a, b)
168174# Eagerly collect arguments to make more general on GPU.
169175Base. collect (a:: KroneckerArray ) = kron_nd (collect (arg1 (a)), collect (arg2 (a)))
170176
171- Base. zero (a:: KroneckerArray ) = zero (arg1 (a)) ⊗ zero (arg2 (a))
177+ function Base. zero (a:: KroneckerArray )
178+ return if isactive (arg1 (a)) == isactive (arg2 (a))
179+ # TODO : Maybe this should zero both arguments?
180+ # This is how `a * false` would behave.
181+ arg1 (a) ⊗ zero (arg2 (a))
182+ elseif isactive (arg1 (a))
183+ zero (arg1 (a)) ⊗ arg2 (a)
184+ elseif isactive (arg2 (a))
185+ arg1 (a) ⊗ zero (arg2 (a))
186+ end
187+ end
172188
173189using DerivableInterfaces: DerivableInterfaces, zero!
174190function DerivableInterfaces. zero! (a:: KroneckerArray )
@@ -240,19 +256,15 @@ function Base.to_indices(
240256 return I1 .× I2
241257end
242258
243- # Allow customizing for `FillArrays.Eye`.
244- _getindex (a:: AbstractArray , I... ) = a[I... ]
245259function Base. getindex (
246260 a:: KroneckerArray{<:Any,N} , I:: Vararg{Union{CartesianPair,CartesianProduct},N}
247261) where {N}
248262 I′ = to_indices (a, I)
249- return _getindex ( arg1 (a), arg1 .(I′)... ) ⊗ _getindex ( arg2 (a), arg2 .(I′)... )
263+ return arg1 (a)[ arg1 .(I′)... ] ⊗ arg2 (a)[ arg2 .(I′)... ]
250264end
251265# Fix ambigiuity error.
252266Base. getindex (a:: KroneckerArray{<:Any,0} ) = arg1 (a)[] * arg2 (a)[]
253267
254- # Allow customizing for `FillArrays.Eye`.
255- _view (a:: AbstractArray , I... ) = view (a, I... )
256268arg1 (:: Colon ) = (:)
257269arg2 (:: Colon ) = (:)
258270arg1 (:: Base.Slice ) = (:)
@@ -261,13 +273,13 @@ function Base.view(
261273 a:: KroneckerArray{<:Any,N} ,
262274 I:: Vararg{Union{CartesianProduct,CartesianProductUnitRange,Base.Slice,Colon},N} ,
263275) where {N}
264- return _view (arg1 (a), arg1 .(I)... ) ⊗ _view (arg2 (a), arg2 .(I)... )
276+ return view (arg1 (a), arg1 .(I)... ) ⊗ view (arg2 (a), arg2 .(I)... )
265277end
266278function Base. view (a:: KroneckerArray{<:Any,N} , I:: Vararg{CartesianPair,N} ) where {N}
267- return _view (arg1 (a), arg1 .(I)... ) ⊗ _view (arg2 (a), arg2 .(I)... )
279+ return view (arg1 (a), arg1 .(I)... ) ⊗ view (arg2 (a), arg2 .(I)... )
268280end
269281# Fix ambigiuity error.
270- Base. view (a:: KroneckerArray{<:Any,0} ) = _view (arg1 (a)) * _view (arg2 (a))
282+ Base. view (a:: KroneckerArray{<:Any,0} ) = view (arg1 (a)) ⊗ view (arg2 (a))
271283
272284function Base.:(== )(a:: KroneckerArray , b:: KroneckerArray )
273285 return arg1 (a) == arg1 (b) && arg2 (a) == arg2 (b)
0 commit comments