Skip to content

WIP: review before version 3 #119

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 15 commits into from
Dec 9, 2020
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 61 additions & 45 deletions src/kronecker.jl
Original file line number Diff line number Diff line change
@@ -1,16 +1,14 @@
struct KroneckerMap{T, As<:Tuple{Vararg{LinearMap}}} <: LinearMap{T}
struct KroneckerMap{T, As<:LinearMapTuple} <: LinearMap{T}
maps::As
function KroneckerMap{T, As}(maps::As) where {T, As}
for n in eachindex(maps)
A = maps[n]
@assert promote_type(T, eltype(A)) == T "eltype $(eltype(A)) cannot be promoted to $T in KroneckerMap constructor"
function KroneckerMap{T}(maps::LinearMapTuple) where {T}
for TA in Base.Generator(eltype, maps)
promote_type(T, TA) == T ||
error("eltype $TA cannot be promoted to $T in KroneckerMap constructor")
end
return new{T,As}(maps)
return new{T,typeof(maps)}(maps)
end
end

KroneckerMap{T}(maps::As) where {T, As<:Tuple{Vararg{LinearMap}}} = KroneckerMap{T, As}(maps)

"""
kron(A::LinearMap, B::LinearMap)::KroneckerMap
kron(A, B, Cs...)::KroneckerMap
Expand Down Expand Up @@ -46,24 +44,30 @@ julia> Matrix(Δ)
0 1 1 -4
```
"""
Base.kron(A::LinearMap, B::LinearMap) = KroneckerMap{promote_type(eltype(A), eltype(B))}((A, B))
Base.kron(A::LinearMap, B::KroneckerMap) = KroneckerMap{promote_type(eltype(A), eltype(B))}(tuple(A, B.maps...))
Base.kron(A::KroneckerMap, B::LinearMap) = KroneckerMap{promote_type(eltype(A), eltype(B))}(tuple(A.maps..., B))
Base.kron(A::KroneckerMap, B::KroneckerMap) = KroneckerMap{promote_type(eltype(A), eltype(B))}(tuple(A.maps..., B.maps...))
Base.kron(A::LinearMap, B::LinearMap, Cs::LinearMap...) = KroneckerMap{promote_type(eltype(A), eltype(B), map(eltype, Cs)...)}(tuple(A, B, Cs...))
Base.kron(A::LinearMap, B::LinearMap) =
KroneckerMap{promote_type(eltype(A), eltype(B))}((A, B))
Base.kron(A::LinearMap, B::KroneckerMap) =
KroneckerMap{promote_type(eltype(A), eltype(B))}(tuple(A, B.maps...))
Base.kron(A::KroneckerMap, B::LinearMap) =
KroneckerMap{promote_type(eltype(A), eltype(B))}(tuple(A.maps..., B))
Base.kron(A::KroneckerMap, B::KroneckerMap) =
KroneckerMap{promote_type(eltype(A), eltype(B))}(tuple(A.maps..., B.maps...))
Base.kron(A::LinearMap, B::LinearMap, Cs::LinearMap...) =
KroneckerMap{promote_type(eltype(A), eltype(B), map(eltype, Cs)...)}(tuple(A, B, Cs...))
# could this just be a recursive definition: kron(A, B, C) = kron(kron(A, B), C) ?
Base.kron(A::AbstractMatrix, B::LinearMap) = kron(LinearMap(A), B)
Base.kron(A::LinearMap, B::AbstractMatrix) = kron(A, LinearMap(B))
# promote AbstractMatrix arguments to LinearMaps, then take LinearMap-Kronecker product
for k in 3:8 # is 8 sufficient?
Is = ntuple(n->:($(Symbol(:A,n))::AbstractMatrix), Val(k-1))
Is = ntuple(n->:($(Symbol(:A, n))::AbstractMatrix), Val(k-1))
# yields (:A1, :A2, :A3, ..., :A(k-1))
L = :($(Symbol(:A,k))::LinearMap)
L = :($(Symbol(:A, k))::LinearMap)
# yields :Ak::LinearMap
mapargs = ntuple(n -> :(LinearMap($(Symbol(:A,n)))), Val(k-1))
mapargs = ntuple(n -> :(LinearMap($(Symbol(:A, n)))), Val(k-1))
# yields (:LinearMap(A1), :LinearMap(A2), ..., :LinearMap(A(k-1)))

@eval Base.kron($(Is...), $L, As::MapOrMatrix...) =
kron($(mapargs...), $(Symbol(:A,k)), convert_to_lmaps(As...)...)
kron($(mapargs...), $(Symbol(:A, k)), convert_to_lmaps(As...)...)
end

struct KronPower{p}
Expand All @@ -81,10 +85,10 @@ Construct a lazy representation of the `k`-th Kronecker power
"""
⊗(k::Integer) = KronPower(k)

⊗(A, B, Cs...) = KroneckerMap{promote_type(eltype(A), eltype(B), map(eltype, Cs)...)}(convert_to_lmaps(A, B, Cs...))
⊗(A, B, Cs...) = kron(convert_to_lmaps(A, B, Cs...))

Base.:(^)(A::MapOrMatrix, ::KronPower{p}) where {p} =
(ntuple(n -> convert_to_lmaps_(A), Val(p))...)
kron(ntuple(n -> convert_to_lmaps_(A), Val(p))...)

Base.size(A::KroneckerMap) = map(*, size.(A.maps)...)

Expand All @@ -94,8 +98,8 @@ LinearAlgebra.issymmetric(A::KroneckerMap) = all(issymmetric, A.maps)
LinearAlgebra.ishermitian(A::KroneckerMap{<:Real}) = issymmetric(A)
LinearAlgebra.ishermitian(A::KroneckerMap) = all(ishermitian, A.maps)

LinearAlgebra.adjoint(A::KroneckerMap{T}) where {T} = KroneckerMap{T}(map(adjoint, A.maps))
LinearAlgebra.transpose(A::KroneckerMap{T}) where {T} = KroneckerMap{T}(map(transpose, A.maps))
LinearAlgebra.adjoint(A::KroneckerMap) = KroneckerMap{eltype(A)}(map(adjoint, A.maps))
LinearAlgebra.transpose(A::KroneckerMap) = KroneckerMap{eltype(A)}(map(transpose, A.maps))

Base.:(==)(A::KroneckerMap, B::KroneckerMap) = (eltype(A) == eltype(B) && A.maps == B.maps)

Expand All @@ -118,11 +122,11 @@ Base.:(==)(A::KroneckerMap, B::KroneckerMap) = (eltype(A) == eltype(B) && A.maps
end
return y
end
@inline function _kronmul!(y, B, X, At::Union{MatrixMap,UniformScalingMap}, T)
@inline function _kronmul!(y, B, X, At::Union{MatrixMap, UniformScalingMap}, T)
na, ma = size(At)
mb, nb = size(B)
Y = reshape(y, (mb, ma))
if nb*ma < mb*na
if nb*ma < mb*na
_unsafe_mul!(Y, B, Matrix(X*At))
else
_unsafe_mul!(Y, Matrix(B*X), parent(At))
Expand All @@ -134,24 +138,30 @@ end
# multiplication with vectors
#################

function _unsafe_mul!(y::AbstractVecOrMat, L::KroneckerMap{T,<:NTuple{2,LinearMap}}, x::AbstractVector) where {T}
const KroneckerMap2{T} = KroneckerMap{T, <:Tuple{LinearMap, LinearMap}}

function _unsafe_mul!(y::AbstractVecOrMat, L::KroneckerMap2, x::AbstractVector)
require_one_based_indexing(y)
A, B = L.maps
X = LinearMap(reshape(x, (size(B, 2), size(A, 2))); issymmetric=false, ishermitian=false, isposdef=false)
_kronmul!(y, B, X, transpose(A), T)
X = LinearMap(reshape(x, (size(B, 2), size(A, 2)));
issymmetric = false, ishermitian = false, isposdef = false)
_kronmul!(y, B, X, transpose(A), eltype(L))
return y
end
function _unsafe_mul!(y::AbstractVecOrMat, L::KroneckerMap{T}, x::AbstractVector) where {T}
function _unsafe_mul!(y::AbstractVecOrMat, L::KroneckerMap, x::AbstractVector)
require_one_based_indexing(y)
A = first(L.maps)
B = kron(Base.tail(L.maps)...)
X = LinearMap(reshape(x, (size(B, 2), size(A, 2))); issymmetric=false, ishermitian=false, isposdef=false)
_kronmul!(y, B, X, transpose(A), T)
X = LinearMap(reshape(x, (size(B, 2), size(A, 2)));
issymmetric = false, ishermitian = false, isposdef = false)
_kronmul!(y, B, X, transpose(A), eltype(L))
return y
end
# mixed-product rule, prefer the right if possible
# (A₁ ⊗ A₂ ⊗ ... ⊗ Aᵣ) * (B₁ ⊗ B₂ ⊗ ... ⊗ Bᵣ) = (A₁B₁) ⊗ (A₂B₂) ⊗ ... ⊗ (AᵣBᵣ)
function _unsafe_mul!(y::AbstractVecOrMat, L::CompositeMap{<:Any,<:Tuple{KroneckerMap,KroneckerMap}}, x::AbstractVector)
function _unsafe_mul!(y::AbstractVecOrMat,
L::CompositeMap{<:Any,<:Tuple{KroneckerMap,KroneckerMap}},
x::AbstractVector)
require_one_based_indexing(y)
B, A = L.maps
if length(A.maps) == length(B.maps) && all(_iscompatible, zip(A.maps, B.maps))
Expand All @@ -162,8 +172,10 @@ function _unsafe_mul!(y::AbstractVecOrMat, L::CompositeMap{<:Any,<:Tuple{Kroneck
return y
end
# mixed-product rule, prefer the right if possible
# (A₁ ⊗ B₁)*(A₂⊗B₂)*...*(Aᵣ⊗Bᵣ) = (A₁*A₂*...*Aᵣ) ⊗ (B₁*B₂*...*Bᵣ)
function _unsafe_mul!(y::AbstractVecOrMat, L::CompositeMap{T,<:Tuple{Vararg{KroneckerMap{<:Any,<:Tuple{LinearMap,LinearMap}}}}}, x::AbstractVector) where {T}
# (A₁⊗B₁) * (A₂⊗B₂) * ... * (Aᵣ⊗Bᵣ) = (A₁*A₂*...*Aᵣ) ⊗ (B₁*B₂*...*Bᵣ)
function _unsafe_mul!(y::AbstractVecOrMat,
L::CompositeMap{T, <:Tuple{Vararg{KroneckerMap2}}},
x::AbstractVector) where {T}
require_one_based_indexing(y)
As = map(AB -> AB.maps[1], L.maps)
Bs = map(AB -> AB.maps[2], L.maps)
Expand All @@ -181,20 +193,20 @@ end
###############
# KroneckerSumMap
###############
struct KroneckerSumMap{T, As<:Tuple{LinearMap,LinearMap}} <: LinearMap{T}
struct KroneckerSumMap{T, As<:Tuple{LinearMap, LinearMap}} <: LinearMap{T}
maps::As
function KroneckerSumMap{T, As}(maps::As) where {T, As}
for n in eachindex(maps)
A = maps[n]
size(A, 1) == size(A, 2) || throw(ArgumentError("operators need to be square in Kronecker sums"))
@assert promote_type(T, eltype(A)) == T "eltype $(eltype(A)) cannot be promoted to $T in KroneckerSumMap constructor"
function KroneckerSumMap{T}(maps::Tuple{LinearMap,LinearMap}) where {T}
A1, A2 = maps
(size(A1, 1) == size(A1, 2) && size(A2, 1) == size(A2, 2)) ||
throw(ArgumentError("operators need to be square in Kronecker sums"))
for TA in Base.Generator(eltype, maps)
promote_type(T, TA) == T ||
error("eltype $TA cannot be promoted to $T in KroneckerSumMap constructor")
end
return new{T,As}(maps)
return new{T, typeof(maps)}(maps)
end
end

KroneckerSumMap{T}(maps::As) where {T, As<:Tuple{LinearMap,LinearMap}} = KroneckerSumMap{T, As}(maps)

"""
kronsum(A, B)::KroneckerSumMap
kronsum(A, B, Cs...)::KroneckerSumMap
Expand Down Expand Up @@ -246,7 +258,8 @@ where `A` can be a square `AbstractMatrix` or a `LinearMap`.

⊕(a, b, c...) = kronsum(a, b, c...)

Base.:(^)(A::MapOrMatrix, ::KronSumPower{p}) where {p} = kronsum(ntuple(n -> convert_to_lmaps_(A), Val(p))...)
Base.:(^)(A::MapOrMatrix, ::KronSumPower{p}) where {p} =
kronsum(ntuple(n->convert_to_lmaps_(A), Val(p))...)

Base.size(A::KroneckerSumMap, i) = prod(size.(A.maps, i))
Base.size(A::KroneckerSumMap) = (size(A, 1), size(A, 2))
Expand All @@ -256,10 +269,13 @@ LinearAlgebra.issymmetric(A::KroneckerSumMap) = all(issymmetric, A.maps)
LinearAlgebra.ishermitian(A::KroneckerSumMap{<:Real}) = all(issymmetric, A.maps)
LinearAlgebra.ishermitian(A::KroneckerSumMap) = all(ishermitian, A.maps)

LinearAlgebra.adjoint(A::KroneckerSumMap{T}) where {T} = KroneckerSumMap{T}(map(adjoint, A.maps))
LinearAlgebra.transpose(A::KroneckerSumMap{T}) where {T} = KroneckerSumMap{T}(map(transpose, A.maps))
LinearAlgebra.adjoint(A::KroneckerSumMap) =
KroneckerSumMap{eltype(A)}(map(adjoint, A.maps))
LinearAlgebra.transpose(A::KroneckerSumMap) =
KroneckerSumMap{eltype(A)}(map(transpose, A.maps))

Base.:(==)(A::KroneckerSumMap, B::KroneckerSumMap) = (eltype(A) == eltype(B) && A.maps == B.maps)
Base.:(==)(A::KroneckerSumMap, B::KroneckerSumMap) =
(eltype(A) == eltype(B) && A.maps == B.maps)

function _unsafe_mul!(y::AbstractVecOrMat, L::KroneckerSumMap, x::AbstractVector)
A, B = L.maps
Expand Down
2 changes: 1 addition & 1 deletion test/kronecker.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ using Test, LinearMaps, LinearAlgebra, SparseArrays
LB = LinearMap(B)
LK = @inferred kron(LA, LB)
@test parent(LK) == (LA, LB)
@test_throws AssertionError LinearMaps.KroneckerMap{Float64}((LA, LB))
@test_throws ErrorException LinearMaps.KroneckerMap{Float64}((LA, LB))
@test occursin("6×6 LinearMaps.KroneckerMap{$(eltype(LK))}", sprint((t, s) -> show(t, "text/plain", s), LK))
@test @inferred size(LK) == size(K)
@test LinearMaps.MulStyle(LK) === LinearMaps.ThreeArg()
Expand Down