Skip to content

Commit 1208e5b

Browse files
committed
start replacing kroneckerarray with AbstractKroneckerArray
1 parent 13f4e79 commit 1208e5b

File tree

2 files changed

+78
-71
lines changed

2 files changed

+78
-71
lines changed

src/kroneckerarray.jl

Lines changed: 60 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,10 @@ arg1type(::Type{<:AbstractKroneckerArray}) = error("`AbstractKroneckerArray` sub
2323
arg2type(x::AbstractKroneckerArray) = arg2type(typeof(x))
2424
arg2type(::Type{<:AbstractKroneckerArray}) = error("`AbstractKroneckerArray` subtypes have to implement `arg2type`.")
2525

26+
arguments(a::AbstractKroneckerArray) = (arg1(a), arg2(a))
27+
arguments(a::AbstractKroneckerArray, n::Int) = arguments(a)[n]
28+
argument_types(a::AbstractKroneckerArray) = argument_types(typeof(a))
29+
2630
function unwrap_array(a::AbstractArray)
2731
p = parent(a)
2832
p a && return a
@@ -51,7 +55,7 @@ function _convert(A::Type{<:Diagonal}, a::AbstractMatrix)
5155
end
5256

5357
struct KroneckerArray{T, N, A1 <: AbstractArray{T, N}, A2 <: AbstractArray{T, N}} <:
54-
AbstractKroneckerArray{T, N, A1, A2}
58+
AbstractKroneckerArray{T, N}
5559
arg1::A1
5660
arg2::A2
5761
end
@@ -76,6 +80,8 @@ const KroneckerVector{T, A1 <: AbstractVector{T}, A2 <: AbstractVector{T}} = Kro
7680
arg1type(::Type{KroneckerArray{T, N, A1, A2}}) where {T, N, A1, A2} = A1
7781
arg2type(::Type{KroneckerArray{T, N, A1, A2}}) where {T, N, A1, A2} = A2
7882

83+
argument_types(::Type{<:KroneckerArray{<:Any, <:Any, A1, A2}}) where {A1, A2} = (A1, A2)
84+
7985
function mutate_active_args!(f!, f, dest, src)
8086
(isactive(arg1(dest)) || isactive(arg2(dest))) ||
8187
error("Can't mutate immutable KroneckerArray.")
@@ -93,7 +99,7 @@ function mutate_active_args!(f!, f, dest, src)
9399
end
94100

95101
using Adapt: Adapt, adapt
96-
function Adapt.adapt_structure(to, a::KroneckerArray)
102+
function Adapt.adapt_structure(to, a::AbstractKroneckerArray)
97103
# TODO: Is this a good definition? It is similar to
98104
# the definition of `similar`.
99105
return if isactive(arg1(a)) == isactive(arg2(a))
@@ -105,18 +111,22 @@ function Adapt.adapt_structure(to, a::KroneckerArray)
105111
end
106112
end
107113

108-
function Base.copy(a::KroneckerArray)
109-
return copy(arg1(a)) copy(arg2(a))
114+
Base.copy(a::AbstractKroneckerArray) = copy(arg1(a)) copy(arg2(a))
115+
function Base.copy!(dest::AbstractKroneckerArray, src::AbstractKroneckerArray)
116+
return mutate_active_args!(copy!, copy, dest, src)
110117
end
111118

119+
# TODO: copyto! is typically reserved for contiguous copies (i.e. also for copying from a
120+
# vector into an array), it might be better to not define that here.
112121
function Base.copyto!(dest::KroneckerArray{<:Any, N}, src::KroneckerArray{<:Any, N}) where {N}
113122
return mutate_active_args!(copyto!, copy, dest, src)
114123
end
115124

116125
function Base.convert(
117-
::Type{KroneckerArray{T, N, A1, A2}}, a::KroneckerArray
118-
) where {T, N, A1, A2}
119-
return _convert(A1, arg1(a)) _convert(A2, arg2(a))
126+
::Type{KroneckerArray{T, N, A1, A2}}, a::AbstractKroneckerArray
127+
)::KroneckerArray{T, N, A1, A2} where {T, N, A1, A2}
128+
typeof(a) === KroneckerArray{T, N, A1, A2} && return a
129+
return KroneckerArray(_convert(A1, arg1(a)), _convert(A2, arg2(a)))
120130
end
121131

122132
# Promote the element type if needed.
@@ -125,7 +135,7 @@ end
125135
maybe_promot_eltype(a, elt) = eltype(a) <: elt ? a : elt.(a)
126136

127137
function Base.similar(
128-
a::KroneckerArray,
138+
a::AbstractKroneckerArray,
129139
elt::Type,
130140
axs::Tuple{
131141
CartesianProductUnitRange{<:Integer}, Vararg{CartesianProductUnitRange{<:Integer}},
@@ -142,7 +152,7 @@ function Base.similar(
142152
maybe_promot_eltype(arg1(a), elt) similar(arg2(a), elt, arg2.(axs))
143153
end
144154
end
145-
function Base.similar(a::KroneckerArray, elt::Type)
155+
function Base.similar(a::AbstractKroneckerArray, elt::Type)
146156
# TODO: Is this a good definition?
147157
return if isactive(arg1(a)) == isactive(arg2(a))
148158
similar(arg1(a), elt) similar(arg2(a), elt)
@@ -152,7 +162,7 @@ function Base.similar(a::KroneckerArray, elt::Type)
152162
maybe_promot_eltype(arg1(a), elt) similar(arg2(a), elt)
153163
end
154164
end
155-
function Base.similar(a::KroneckerArray)
165+
function Base.similar(a::AbstractKroneckerArray)
156166
# TODO: Is this a good definition?
157167
return if isactive(arg1(a)) == isactive(arg2(a))
158168
similar(arg1(a)) similar(arg2(a))
@@ -174,16 +184,18 @@ function Base.similar(
174184
end
175185

176186
function Base.similar(
177-
arrayt::Type{<:KroneckerArray{<:Any, <:Any, A1, A2}},
187+
::Type{ArrayT},
178188
axs::Tuple{
179189
CartesianProductUnitRange{<:Integer}, Vararg{CartesianProductUnitRange{<:Integer}},
180190
},
181-
) where {A1, A2}
191+
) where {ArrayT <: AbstractKroneckerArray}
192+
A1, A2 = arg1type(ArrayT), arg2type(ArrayT)
182193
return similar(A1, map(arg1, axs)) similar(A2, map(arg2, axs))
183194
end
184195
function Base.similar(
185-
::Type{<:KroneckerArray{<:Any, <:Any, A1, A2}}, sz::Tuple{Int, Vararg{Int}}
186-
) where {A1, A2}
196+
::Type{ArrayT}, sz::Tuple{Int, Vararg{Int}}
197+
) where {ArrayT <: AbstractKroneckerArray}
198+
A1, A2 = arg1type(ArrayT), arg2type(ArrayT)
187199
return similar(promote_type(A1, A2), sz)
188200
end
189201

@@ -196,15 +208,15 @@ function Base.similar(
196208
return similar(arrayt, map(arg1, axs)) similar(arrayt, map(arg2, axs))
197209
end
198210

199-
function Base.permutedims(a::KroneckerArray, perm)
211+
function Base.permutedims(a::AbstractKroneckerArray, perm)
200212
return permutedims(arg1(a), perm) permutedims(arg2(a), perm)
201213
end
202214
using DerivableInterfaces: DerivableInterfaces, permuteddims
203-
function DerivableInterfaces.permuteddims(a::KroneckerArray, perm)
215+
function DerivableInterfaces.permuteddims(a::AbstractKroneckerArray, perm)
204216
return permuteddims(arg1(a), perm) permuteddims(arg2(a), perm)
205217
end
206218

207-
function Base.permutedims!(dest::KroneckerArray, src::KroneckerArray, perm)
219+
function Base.permutedims!(dest::AbstractKroneckerArray, src::AbstractKroneckerArray, perm)
208220
return mutate_active_args!(
209221
(dest, src) -> permutedims!(dest, src, perm), Base.Fix2(permutedims, perm), dest, src
210222
)
@@ -235,9 +247,10 @@ kron_nd(a1::AbstractMatrix, a2::AbstractMatrix) = kron(a1, a2)
235247
kron_nd(a1::AbstractVector, a2::AbstractVector) = kron(a1, a2)
236248

237249
# Eagerly collect arguments to make more general on GPU.
238-
Base.collect(a::KroneckerArray) = kron_nd(collect(arg1(a)), collect(arg2(a)))
250+
Base.collect(a::AbstractKroneckerArray) = kron_nd(collect(arg1(a)), collect(arg2(a)))
251+
Base.collect(T::Type, a::AbstractKroneckerArray) = kron_nd(collect(T, arg1(a)), collect(T, arg2(a)))
239252

240-
function Base.zero(a::KroneckerArray)
253+
function Base.zero(a::AbstractKroneckerArray)
241254
return if isactive(arg1(a)) == isactive(arg2(a))
242255
# TODO: Maybe this should zero both arguments?
243256
# This is how `a * false` would behave.
@@ -250,35 +263,28 @@ function Base.zero(a::KroneckerArray)
250263
end
251264

252265
using DerivableInterfaces: DerivableInterfaces, zero!
253-
function DerivableInterfaces.zero!(a::KroneckerArray)
266+
function DerivableInterfaces.zero!(a::AbstractKroneckerArray)
254267
(isactive(arg1(a)) || isactive(arg2(a))) ||
255268
error("Can't mutate immutable KroneckerArray.")
256269
isactive(arg1(a)) && zero!(arg1(a))
257270
isactive(arg2(a)) && zero!(arg2(a))
258271
return a
259272
end
260273

261-
function Base.Array{T, N}(a::KroneckerArray{S, N}) where {T, S, N}
262-
return convert(Array{T, N}, collect(a))
274+
function Base.Array{T, N}(a::AbstractKroneckerArray{S, N}) where {T, S, N}
275+
return convert(Array{T, N}, collect(T, a))
263276
end
264277

265-
function Base.size(a::KroneckerArray)
266-
return ntuple(dim -> size(arg1(a), dim) * size(arg2(a), dim), ndims(a))
267-
end
278+
Base.size(a::AbstractKroneckerArray) = size(arg1(a)) .* size(arg2(a))
268279

269-
function Base.axes(a::KroneckerArray)
280+
function Base.axes(a::AbstractKroneckerArray)
270281
return ntuple(ndims(a)) do dim
271282
return CartesianProductUnitRange(
272283
axes(arg1(a), dim) × axes(arg2(a), dim), Base.OneTo(size(a, dim))
273284
)
274285
end
275286
end
276287

277-
arguments(a::KroneckerArray) = (arg1(a), arg2(a))
278-
arguments(a::KroneckerArray, n::Int) = arguments(a)[n]
279-
argument_types(a::KroneckerArray) = argument_types(typeof(a))
280-
argument_types(::Type{<:KroneckerArray{<:Any, <:Any, A1, A2}}) where {A1, A2} = (A1, A2)
281-
282288
function Base.print_array(io::IO, a::KroneckerArray)
283289
Base.print_array(io, arg1(a))
284290
println(io, "\n")
@@ -312,45 +318,48 @@ end
312318

313319
# Indexing logic.
314320
function Base.to_indices(
315-
a::KroneckerArray, inds, I::Tuple{Union{CartesianPair, CartesianProduct}, Vararg}
321+
a::AbstractKroneckerArray, inds, I::Tuple{Union{CartesianPair, CartesianProduct}, Vararg}
316322
)
317323
I1 = to_indices(arg1(a), arg1.(inds), arg1.(I))
318324
I2 = to_indices(arg2(a), arg2.(inds), arg2.(I))
319325
return I1 I2
320326
end
321327

322328
function Base.getindex(
323-
a::KroneckerArray{<:Any, N}, I::Vararg{Union{CartesianPair, CartesianProduct}, N}
329+
a::AbstractKroneckerArray{<:Any, N}, I::Vararg{Union{CartesianPair, CartesianProduct}, N}
324330
) where {N}
325331
I′ = to_indices(a, I)
326332
return arg1(a)[arg1.(I′)...] arg2(a)[arg2.(I′)...]
327333
end
328334
# Fix ambigiuity error.
329-
Base.getindex(a::KroneckerArray{<:Any, 0}) = arg1(a)[] * arg2(a)[]
335+
Base.getindex(a::AbstractKroneckerArray{<:Any, 0}) = arg1(a)[] * arg2(a)[]
330336

331337
arg1(::Colon) = (:)
332338
arg2(::Colon) = (:)
333339
arg1(::Base.Slice) = (:)
334340
arg2(::Base.Slice) = (:)
335341
function Base.view(
336-
a::KroneckerArray{<:Any, N},
342+
a::AbstractKroneckerArray{<:Any, N},
337343
I::Vararg{Union{CartesianProduct, CartesianProductUnitRange, Base.Slice, Colon}, N},
338344
) where {N}
339345
return view(arg1(a), arg1.(I)...) view(arg2(a), arg2.(I)...)
340346
end
341-
function Base.view(a::KroneckerArray{<:Any, N}, I::Vararg{CartesianPair, N}) where {N}
347+
function Base.view(a::AbstractKroneckerArray{<:Any, N}, I::Vararg{CartesianPair, N}) where {N}
342348
return view(arg1(a), arg1.(I)...) view(arg2(a), arg2.(I)...)
343349
end
344350
# Fix ambigiuity error.
345-
Base.view(a::KroneckerArray{<:Any, 0}) = view(arg1(a)) view(arg2(a))
351+
Base.view(a::AbstractKroneckerArray{<:Any, 0}) = view(arg1(a)) view(arg2(a))
346352

347-
function Base.:(==)(a::KroneckerArray, b::KroneckerArray)
353+
function Base.:(==)(a::AbstractKroneckerArray, b::AbstractKroneckerArray)
348354
return arg1(a) == arg1(b) && arg2(a) == arg2(b)
349355
end
350-
function Base.isapprox(a::KroneckerArray, b::KroneckerArray; kwargs...)
356+
357+
# TODO: this definition doesn't fully retain the original meaning:
358+
# ‖a - b‖ < atol could be true even if the following check isn't
359+
function Base.isapprox(a::AbstractKroneckerArray, b::AbstractKroneckerArray; kwargs...)
351360
return isapprox(arg1(a), arg1(b); kwargs...) && isapprox(arg2(a), arg2(b); kwargs...)
352361
end
353-
function Base.iszero(a::KroneckerArray)
362+
function Base.iszero(a::AbstractKroneckerArray)
354363
return iszero(arg1(a)) || iszero(arg2(a))
355364
end
356365
function Base.isreal(a::KroneckerArray)
@@ -362,17 +371,17 @@ function DiagonalArrays.diagonal(a::KroneckerArray)
362371
return diagonal(arg1(a)) diagonal(arg2(a))
363372
end
364373

365-
Base.real(a::KroneckerArray{<:Real}) = a
366-
function Base.real(a::KroneckerArray)
374+
Base.real(a::AbstractKroneckerArray{<:Real}) = a
375+
function Base.real(a::AbstractKroneckerArray)
367376
if iszero(imag(arg1(a))) || iszero(imag(arg2(a)))
368377
return real(arg1(a)) real(arg2(a))
369378
elseif iszero(real(arg1(a))) || iszero(real(arg2(a)))
370379
return -(imag(arg1(a)) imag(arg2(a)))
371380
end
372381
return real(arg1(a)) real(arg2(a)) - imag(arg1(a)) imag(arg2(a))
373382
end
374-
Base.imag(a::KroneckerArray{<:Real}) = zero(a)
375-
function Base.imag(a::KroneckerArray)
383+
Base.imag(a::AbstractKroneckerArray{<:Real}) = zero(a)
384+
function Base.imag(a::AbstractKroneckerArray)
376385
if iszero(imag(arg1(a))) || iszero(real(arg2(a)))
377386
return real(arg1(a)) imag(arg2(a))
378387
elseif iszero(real(arg1(a))) || iszero(imag(arg2(a)))
@@ -383,14 +392,14 @@ end
383392

384393
for f in [:transpose, :adjoint, :inv]
385394
@eval begin
386-
function Base.$f(a::KroneckerArray)
395+
function Base.$f(a::AbstractKroneckerArray)
387396
return $f(arg1(a)) $f(arg2(a))
388397
end
389398
end
390399
end
391400

392401
function Base.reshape(
393-
a::KroneckerArray, ax::Tuple{CartesianProductUnitRange, Vararg{CartesianProductUnitRange}}
402+
a::AbstractKroneckerArray, ax::Tuple{CartesianProductUnitRange, Vararg{CartesianProductUnitRange}}
394403
)
395404
return reshape(arg1(a), map(arg1, ax)) reshape(arg2(a), map(arg2, ax))
396405
end
@@ -410,8 +419,8 @@ end
410419
function KroneckerStyle{N, A1, A2}(v::Val{M}) where {N, A1, A2, M}
411420
return KroneckerStyle{M, typeof(A1)(v), typeof(A2)(v)}()
412421
end
413-
function Base.BroadcastStyle(::Type{<:KroneckerArray{<:Any, N, A1, A2}}) where {N, A1, A2}
414-
return KroneckerStyle{N}(BroadcastStyle(A1), BroadcastStyle(A2))
422+
function Base.BroadcastStyle(::Type{T}) where {T <: AbstractKroneckerArray}
423+
return KroneckerStyle{ndims(T)}(BroadcastStyle(arg1type(T)), BroadcastStyle(arg2type(T)))
415424
end
416425
function Base.BroadcastStyle(style1::KroneckerStyle{N}, style2::KroneckerStyle{N}) where {N}
417426
style_a = BroadcastStyle(arg1(style1), arg1(style2))
@@ -430,10 +439,10 @@ function Base.similar(
430439
return a b
431440
end
432441

433-
function Base.map(f, a1::KroneckerArray, a_rest::KroneckerArray...)
442+
function Base.map(f, a1::AbstractKroneckerArray, a_rest::AbstractKroneckerArray...)
434443
return Broadcast.broadcast_preserving_zero_d(f, a1, a_rest...)
435444
end
436-
function Base.map!(f, dest::KroneckerArray, a1::KroneckerArray, a_rest::KroneckerArray...)
445+
function Base.map!(f, dest::AbstractKroneckerArray, a1::AbstractKroneckerArray, a_rest::AbstractKroneckerArray...)
437446
dest .= f.(a1, a_rest...)
438447
return dest
439448
end
@@ -465,7 +474,7 @@ end
465474
function Base.copy(a::Summed{<:KroneckerStyle})
466475
return copy(KroneckerBroadcast(a))
467476
end
468-
function Base.copyto!(dest::KroneckerArray, a::Summed{<:KroneckerStyle})
477+
function Base.copyto!(dest::AbstractKroneckerArray, a::Summed{<:KroneckerStyle})
469478
return copyto!(dest, KroneckerBroadcast(a))
470479
end
471480

0 commit comments

Comments
 (0)