Skip to content

Commit c029980

Browse files
committed
Update argument names
1 parent ab22c2f commit c029980

File tree

8 files changed

+235
-188
lines changed

8 files changed

+235
-188
lines changed

src/cartesianproduct.jl

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,17 @@
1-
struct CartesianPair{A,B}
2-
a::A
3-
b::B
1+
struct CartesianPair{A1,A2}
2+
arg1::A1
3+
arg2::A2
44
end
5-
arguments(a::CartesianPair) = (a.a, a.b)
5+
arguments(a::CartesianPair) = (arg1(a), arg2(a))
66
arguments(a::CartesianPair, n::Int) = arguments(a)[n]
77

8-
arg1(a::CartesianPair) = a.a
9-
arg2(a::CartesianPair) = a.b
8+
arg1(a::CartesianPair) = getfield(a, :arg1)
9+
arg2(a::CartesianPair) = getfield(a, :arg2)
1010

1111
×(a, b) = CartesianPair(a, b)
1212

1313
function Base.show(io::IO, a::CartesianPair)
14-
print(io, a.a, " × ", a.b)
14+
print(io, arg1(a), " × ", arg2(a))
1515
return nothing
1616
end
1717

@@ -20,16 +20,16 @@ struct CartesianProduct{TA,TB,A<:AbstractVector{TA},B<:AbstractVector{TB}} <:
2020
a::A
2121
b::B
2222
end
23-
arguments(a::CartesianProduct) = (a.a, a.b)
23+
arguments(a::CartesianProduct) = (arg1(a), arg2(a))
2424
arguments(a::CartesianProduct, n::Int) = arguments(a)[n]
2525

26-
arg1(a::CartesianProduct) = a.a
27-
arg2(a::CartesianProduct) = a.b
26+
arg1(a::CartesianProduct) = getfield(a, :a)
27+
arg2(a::CartesianProduct) = getfield(a, :b)
2828

2929
Base.copy(a::CartesianProduct) = copy(arg1(a)) × copy(arg2(a))
3030

3131
function Base.show(io::IO, a::CartesianProduct)
32-
print(io, a.a, " × ", a.b)
32+
print(io, arg1(a), " × ", arg2(a))
3333
return nothing
3434
end
3535
function Base.show(io::IO, ::MIME"text/plain", a::CartesianProduct)
@@ -38,7 +38,7 @@ function Base.show(io::IO, ::MIME"text/plain", a::CartesianProduct)
3838
end
3939

4040
×(a::AbstractVector, b::AbstractVector) = CartesianProduct(a, b)
41-
Base.length(a::CartesianProduct) = length(a.a) * length(a.b)
41+
Base.length(a::CartesianProduct) = length(arg1(a)) * length(arg2(a))
4242
Base.size(a::CartesianProduct) = (length(a),)
4343

4444
function Base.getindex(a::CartesianProduct, i::CartesianProduct)

src/kroneckerarray.jl

Lines changed: 49 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,10 @@ function _convert(A::Type{<:Diagonal}, a::AbstractMatrix)
2525
return isdiag(a) ? _construct(A, a) : throw(InexactError(:convert, A, a))
2626
end
2727

28-
struct KroneckerArray{T,N,A<:AbstractArray{T,N},B<:AbstractArray{T,N}} <: AbstractArray{T,N}
29-
a::A
30-
b::B
28+
struct KroneckerArray{T,N,A1<:AbstractArray{T,N},A2<:AbstractArray{T,N}} <:
29+
AbstractArray{T,N}
30+
arg1::A1
31+
arg2::A2
3132
end
3233
function KroneckerArray(a::AbstractArray, b::AbstractArray)
3334
if ndims(a) != ndims(b)
@@ -38,11 +39,15 @@ function KroneckerArray(a::AbstractArray, b::AbstractArray)
3839
elt = promote_type(eltype(a), eltype(b))
3940
return _convert(AbstractArray{elt}, a) _convert(AbstractArray{elt}, b)
4041
end
41-
const KroneckerMatrix{T,A<:AbstractMatrix{T},B<:AbstractMatrix{T}} = KroneckerArray{T,2,A,B}
42-
const KroneckerVector{T,A<:AbstractVector{T},B<:AbstractVector{T}} = KroneckerArray{T,1,A,B}
42+
const KroneckerMatrix{T,A1<:AbstractMatrix{T},A2<:AbstractMatrix{T}} = KroneckerArray{
43+
T,2,A1,A2
44+
}
45+
const KroneckerVector{T,A1<:AbstractVector{T},A2<:AbstractVector{T}} = KroneckerArray{
46+
T,1,A1,A2
47+
}
4348

44-
arg1(a::KroneckerArray) = a.a
45-
arg2(a::KroneckerArray) = a.b
49+
@inline arg1(a::KroneckerArray) = getfield(a, :arg1)
50+
@inline arg2(a::KroneckerArray) = getfield(a, :arg2)
4651

4752
function mutate_active_args!(f!, f, dest, src)
4853
(isactive(arg1(dest)) || isactive(arg2(dest))) ||
@@ -81,8 +86,10 @@ function Base.copyto!(dest::KroneckerArray{<:Any,N}, src::KroneckerArray{<:Any,N
8186
return mutate_active_args!(copyto!, copy, dest, src)
8287
end
8388

84-
function Base.convert(::Type{KroneckerArray{T,N,A,B}}, a::KroneckerArray) where {T,N,A,B}
85-
return _convert(A, arg1(a)) _convert(B, arg2(a))
89+
function Base.convert(
90+
::Type{KroneckerArray{T,N,A1,A2}}, a::KroneckerArray
91+
) where {T,N,A1,A2}
92+
return _convert(A1, arg1(a)) _convert(A2, arg2(a))
8693
end
8794

8895
# Promote the element type if needed.
@@ -140,17 +147,17 @@ function Base.similar(
140147
end
141148

142149
function Base.similar(
143-
arrayt::Type{<:KroneckerArray{<:Any,<:Any,A,B}},
150+
arrayt::Type{<:KroneckerArray{<:Any,<:Any,A1,A2}},
144151
axs::Tuple{
145152
CartesianProductUnitRange{<:Integer},Vararg{CartesianProductUnitRange{<:Integer}}
146153
},
147-
) where {A,B}
148-
return similar(A, map(arg1, axs)) similar(B, map(arg2, axs))
154+
) where {A1,A2}
155+
return similar(A1, map(arg1, axs)) similar(A2, map(arg2, axs))
149156
end
150157
function Base.similar(
151-
::Type{<:KroneckerArray{<:Any,<:Any,A,B}}, sz::Tuple{Int,Vararg{Int}}
152-
) where {A,B}
153-
return similar(promote_type(A, B), sz)
158+
::Type{<:KroneckerArray{<:Any,<:Any,A1,A2}}, sz::Tuple{Int,Vararg{Int}}
159+
) where {A1,A2}
160+
return similar(promote_type(A1, A2), sz)
154161
end
155162

156163
function Base.similar(
@@ -243,7 +250,7 @@ end
243250
arguments(a::KroneckerArray) = (arg1(a), arg2(a))
244251
arguments(a::KroneckerArray, n::Int) = arguments(a)[n]
245252
argument_types(a::KroneckerArray) = argument_types(typeof(a))
246-
argument_types(::Type{<:KroneckerArray{<:Any,<:Any,A,B}}) where {A,B} = (A, B)
253+
argument_types(::Type{<:KroneckerArray{<:Any,<:Any,A1,A2}}) where {A1,A2} = (A1, A2)
247254

248255
function Base.print_array(io::IO, a::KroneckerArray)
249256
Base.print_array(io, arg1(a))
@@ -362,22 +369,22 @@ function Base.reshape(
362369
end
363370

364371
using Base.Broadcast: Broadcast, AbstractArrayStyle, BroadcastStyle, Broadcasted
365-
struct KroneckerStyle{N,A,B} <: AbstractArrayStyle{N} end
366-
arg1(::Type{<:KroneckerStyle{<:Any,A}}) where {A} = A
372+
struct KroneckerStyle{N,A1,A2} <: AbstractArrayStyle{N} end
373+
arg1(::Type{<:KroneckerStyle{<:Any,A1}}) where {A1} = A1
367374
arg1(style::KroneckerStyle) = arg1(typeof(style))
368-
arg2(::Type{<:KroneckerStyle{<:Any,B}}) where {B} = B
375+
arg2(::Type{<:KroneckerStyle{<:Any,<:Any,A2}}) where {A2} = A2
369376
arg2(style::KroneckerStyle) = arg2(typeof(style))
370377
function KroneckerStyle{N}(a::BroadcastStyle, b::BroadcastStyle) where {N}
371378
return KroneckerStyle{N,a,b}()
372379
end
373380
function KroneckerStyle(a::AbstractArrayStyle{N}, b::AbstractArrayStyle{N}) where {N}
374381
return KroneckerStyle{N}(a, b)
375382
end
376-
function KroneckerStyle{N,A,B}(v::Val{M}) where {N,A,B,M}
377-
return KroneckerStyle{M,typeof(A)(v),typeof(B)(v)}()
383+
function KroneckerStyle{N,A1,A2}(v::Val{M}) where {N,A1,A2,M}
384+
return KroneckerStyle{M,typeof(A1)(v),typeof(A2)(v)}()
378385
end
379-
function Base.BroadcastStyle(::Type{<:KroneckerArray{<:Any,N,A,B}}) where {N,A,B}
380-
return KroneckerStyle{N}(BroadcastStyle(A), BroadcastStyle(B))
386+
function Base.BroadcastStyle(::Type{<:KroneckerArray{<:Any,N,A1,A2}}) where {N,A1,A2}
387+
return KroneckerStyle{N}(BroadcastStyle(A1), BroadcastStyle(A2))
381388
end
382389
function Base.BroadcastStyle(style1::KroneckerStyle{N}, style2::KroneckerStyle{N}) where {N}
383390
style_a = BroadcastStyle(arg1(style1), arg1(style2))
@@ -386,9 +393,11 @@ function Base.BroadcastStyle(style1::KroneckerStyle{N}, style2::KroneckerStyle{N
386393
(style_b isa Broadcast.Unknown) && return Broadcast.Unknown()
387394
return KroneckerStyle{N}(style_a, style_b)
388395
end
389-
function Base.similar(bc::Broadcasted{<:KroneckerStyle{N,A,B}}, elt::Type, ax) where {N,A,B}
390-
bc_a = Broadcasted(A, bc.f, arg1.(bc.args), arg1.(ax))
391-
bc_b = Broadcasted(B, bc.f, arg2.(bc.args), arg2.(ax))
396+
function Base.similar(
397+
bc::Broadcasted{<:KroneckerStyle{N,A1,A2}}, elt::Type, ax
398+
) where {N,A1,A2}
399+
bc_a = Broadcasted(A1, bc.f, arg1.(bc.args), arg1.(ax))
400+
bc_b = Broadcasted(A2, bc.f, arg2.(bc.args), arg2.(ax))
392401
a = similar(bc_a, elt)
393402
b = similar(bc_b, elt)
394403
return a b
@@ -497,12 +506,12 @@ using Base.Broadcast: broadcasted
497506
# Represents broadcast operations that can be applied Kronecker-wise,
498507
# i.e. independently to each argument of the Kronecker product.
499508
# Note that not all broadcast operations can be mapped to this.
500-
struct KroneckerBroadcasted{A,B}
501-
a::A
502-
b::B
509+
struct KroneckerBroadcasted{A1,A2}
510+
arg1::A1
511+
arg2::A2
503512
end
504-
arg1(a::KroneckerBroadcasted) = a.a
505-
arg2(a::KroneckerBroadcasted) = a.b
513+
@inline arg1(a::KroneckerBroadcasted) = getfield(a, :arg1)
514+
@inline arg2(a::KroneckerBroadcasted) = getfield(a, :arg2)
506515
(a::Broadcasted, b::Broadcasted) = KroneckerBroadcasted(a, b)
507516
(a::Broadcasted, b) = KroneckerBroadcasted(a, b)
508517
(a, b::Broadcasted) = KroneckerBroadcasted(a, b)
@@ -525,18 +534,20 @@ function Base.axes(a::KroneckerBroadcasted)
525534
end
526535

527536
function Base.BroadcastStyle(
528-
::Type{<:KroneckerBroadcasted{A,B}}
529-
) where {StyleA,StyleB,A<:Broadcasted{StyleA},B<:Broadcasted{StyleB}}
530-
@assert ndims(A) == ndims(B)
531-
N = ndims(A)
532-
return KroneckerStyle{N}(StyleA(), StyleB())
537+
::Type{<:KroneckerBroadcasted{A1,A2}}
538+
) where {StyleA1,StyleA2,A1<:Broadcasted{StyleA1},A2<:Broadcasted{StyleA2}}
539+
@assert ndims(A1) == ndims(A2)
540+
N = ndims(A1)
541+
return KroneckerStyle{N}(StyleA1(), StyleA2())
533542
end
534543

535544
# Operations that preserve the Kronecker structure.
536545
for f in [:identity, :conj]
537546
@eval begin
538-
function Broadcast.broadcasted(::KroneckerStyle{<:Any,A,B}, ::typeof($f), a) where {A,B}
539-
return broadcasted(A, $f, arg1(a)) broadcasted(B, $f, arg2(a))
547+
function Broadcast.broadcasted(
548+
::KroneckerStyle{<:Any,A1,A2}, ::typeof($f), a
549+
) where {A1,A2}
550+
return broadcasted(A1, $f, arg1(a)) broadcasted(A2, $f, arg2(a))
540551
end
541552
end
542553
end

src/linearalgebra.jl

Lines changed: 45 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
using DiagonalArrays: δ
12
using LinearAlgebra:
23
LinearAlgebra,
34
Diagonal,
@@ -17,7 +18,7 @@ using LinearAlgebra:
1718

1819
using LinearAlgebra: LinearAlgebra
1920
function KroneckerArray(J::LinearAlgebra.UniformScaling, ax::Tuple)
20-
return Eye{eltype(J)}(arg1.(ax)) Eye{eltype(J)}(arg2.(ax))
21+
return δ(eltype(J), arg1.(ax)) δ(eltype(J), arg2.(ax))
2122
end
2223
function Base.copyto!(a::KroneckerArray, J::LinearAlgebra.UniformScaling)
2324
copyto!(a, KroneckerArray(J, axes(a)))
@@ -26,21 +27,15 @@ end
2627

2728
using LinearAlgebra: LinearAlgebra, pinv
2829
function LinearAlgebra.pinv(a::KroneckerArray; kwargs...)
29-
return pinv(a.a; kwargs...) pinv(a.b; kwargs...)
30+
return pinv(arg1(a); kwargs...) pinv(arg2(a); kwargs...)
3031
end
3132

3233
function LinearAlgebra.diag(a::KroneckerArray)
3334
return copy(DiagonalArrays.diagview(a))
3435
end
3536

36-
# Allows customizing multiplication for specific types
37-
# such as `Eye * Eye`, which doesn't return `Eye`.
38-
function _mul(a::AbstractArray, b::AbstractArray)
39-
return a * b
40-
end
41-
4237
function Base.:*(a::KroneckerArray, b::KroneckerArray)
43-
return _mul(a.a, b.a) _mul(a.b, b.b)
38+
return (arg1(a) * arg1(b)) (arg2(a) * arg2(b))
4439
end
4540

4641
function LinearAlgebra.mul!(
@@ -53,17 +48,20 @@ function LinearAlgebra.mul!(
5348
"Can't multiple KroneckerArrays with nonzero β and nonzero destination."
5449
),
5550
)
56-
mul!(c.a, a.a, b.a)
57-
mul!(c.b, a.b, b.b, α, β)
51+
# TODO: Only perform in-place operation on the non-active argument(s).
52+
mul!(arg1(c), arg1(a), arg1(b))
53+
mul!(arg2(c), arg2(a), arg2(b), α, β)
5854
return c
5955
end
6056

57+
using LinearAlgebra: tr
6158
function LinearAlgebra.tr(a::KroneckerArray)
62-
return tr(a.a) tr(a.b)
59+
return tr(arg1(a)) * tr(arg2(a))
6360
end
6461

62+
using LinearAlgebra: norm
6563
function LinearAlgebra.norm(a::KroneckerArray, p::Int=2)
66-
return norm(a.a, p) norm(a.b, p)
64+
return norm(arg1(a), p) * norm(arg2(a), p)
6765
end
6866

6967
# Matrix functions
@@ -113,45 +111,54 @@ for f in MATRIX_FUNCTIONS
113111
end
114112
end
115113

116-
using LinearAlgebra: checksquare
114+
# `DiagonalArrays.issquare` and `DiagonalArrays.checksquare` are more general
115+
# than `LinearAlgebra.checksquare`, for example it compares axes and can check
116+
# that the codomain and domain are dual of each other.
117+
using DiagonalArrays: DiagonalArrays, checksquare, issquare
118+
function DiagonalArrays.issquare(a::KroneckerArray)
119+
return issquare(arg1(a)) && issquare(arg2(a))
120+
end
121+
122+
using LinearAlgebra: det
117123
function LinearAlgebra.det(a::KroneckerArray)
118-
checksquare(a.a)
119-
checksquare(a.b)
120-
return det(a.a) ^ size(a.b, 1) * det(a.b) ^ size(a.a, 1)
124+
checksquare(a)
125+
return det(arg1(a)) ^ size(arg2(a), 1) * det(arg2(a)) ^ size(arg1(a), 1)
121126
end
122127

123128
function LinearAlgebra.svd(a::KroneckerArray)
124-
Fa = svd(a.a)
125-
Fb = svd(a.b)
126-
return SVD(Fa.U Fb.U, Fa.S Fb.S, Fa.Vt Fb.Vt)
129+
F1 = svd(arg1(a))
130+
F2 = svd(arg2(a))
131+
return SVD(F1.U F2.U, F1.S F2.S, F1.Vt F2.Vt)
127132
end
128133
function LinearAlgebra.svdvals(a::KroneckerArray)
129-
return svdvals(a.a) svdvals(a.b)
134+
return svdvals(arg1(a)) svdvals(arg2(a))
130135
end
131136
function LinearAlgebra.eigen(a::KroneckerArray)
132-
Fa = eigen(a.a)
133-
Fb = eigen(a.b)
134-
return Eigen(Fa.values Fb.values, Fa.vectors Fb.vectors)
137+
F1 = eigen(arg1(a))
138+
F2 = eigen(arg2(a))
139+
return Eigen(F1.values F2.values, F1.vectors F2.vectors)
135140
end
136141
function LinearAlgebra.eigvals(a::KroneckerArray)
137-
return eigvals(a.a) eigvals(a.b)
142+
return eigvals(arg1(a)) eigvals(arg2(a))
138143
end
139144

140-
struct KroneckerQ{A,B}
141-
a::A
142-
b::B
145+
struct KroneckerQ{A1,A2}
146+
arg1::A1
147+
arg2::A2
143148
end
149+
@inline arg1(a::KroneckerQ) = getfield(a, :arg1)
150+
@inline arg2(a::KroneckerQ) = getfield(a, :arg2)
144151
function Base.:*(a::KroneckerQ, b::KroneckerQ)
145-
return (a.a * b.a) (a.b * b.b)
152+
return (arg1(a) * arg1(b)) (arg2(a) * arg2(b))
146153
end
147-
function Base.:*(a::KroneckerQ, b::KroneckerArray)
148-
return (a.a * b.a) (a.b * b.b)
154+
function Base.:*(a1::KroneckerQ, a2::KroneckerArray)
155+
return (arg1(a1) * arg1(a2)) (arg2(a1) * arg2(a2))
149156
end
150-
function Base.:*(a::KroneckerArray, b::KroneckerQ)
151-
return (a.a * b.a) (a.b * b.b)
157+
function Base.:*(a1::KroneckerArray, a2::KroneckerQ)
158+
return (arg1(a1) * arg1(a2)) (arg2(a1) * arg2(a2))
152159
end
153160
function Base.adjoint(a::KroneckerQ)
154-
return KroneckerQ(a.a', a.b')
161+
return KroneckerQ(arg1(a)', arg2(a)')
155162
end
156163

157164
struct KroneckerQR{QQ,RR}
@@ -165,8 +172,8 @@ function ⊗(a::LinearAlgebra.QRCompactWYQ, b::LinearAlgebra.QRCompactWYQ)
165172
return KroneckerQ(a, b)
166173
end
167174
function LinearAlgebra.qr(a::KroneckerArray)
168-
Fa = qr(a.a)
169-
Fb = qr(a.b)
175+
Fa = qr(arg1(a))
176+
Fb = qr(arg2(a))
170177
return KroneckerQR(Fa.Q Fb.Q, Fa.R Fb.R)
171178
end
172179

@@ -181,7 +188,7 @@ function ⊗(a::LinearAlgebra.LQPackedQ, b::LinearAlgebra.LQPackedQ)
181188
return KroneckerQ(a, b)
182189
end
183190
function LinearAlgebra.lq(a::KroneckerArray)
184-
Fa = lq(a.a)
185-
Fb = lq(a.b)
191+
Fa = lq(arg1(a))
192+
Fb = lq(arg2(a))
186193
return KroneckerLQ(Fa.L Fb.L, Fa.Q Fb.Q)
187194
end

0 commit comments

Comments
 (0)