|
11 | 11 | arguments(a::CartesianProduct) = (a.a, a.b)
|
12 | 12 | arguments(a::CartesianProduct, n::Int) = arguments(a)[n]
|
13 | 13 |
|
| 14 | +function Base.show(io::IO, a::CartesianProduct) |
| 15 | + print(io, a.a, " × ", a.b) |
| 16 | + return nothing |
| 17 | +end |
| 18 | + |
14 | 19 | ×(a, b) = CartesianProduct(a, b)
|
15 | 20 | Base.length(a::CartesianProduct) = length(a.a) * length(a.b)
|
16 | 21 | Base.getindex(a::CartesianProduct, i::CartesianProduct) = a.a[i.a] × a.b[i.b]
|
@@ -130,6 +135,8 @@ function interleave(x::Tuple, y::Tuple)
|
130 | 135 | xy = ntuple(i -> (x[i], y[i]), length(x))
|
131 | 136 | return flatten(xy)
|
132 | 137 | end
|
| 138 | +# TODO: Maybe use scalar indexing based on KroneckerProducts.jl logic for cartesian indexing: |
| 139 | +# https://github.com/perrutquist/KroneckerProducts.jl/blob/8c0104caf1f17729eb067259ba1473986121d032/src/KroneckerProducts.jl#L59-L66 |
133 | 140 | function kron_nd(a::AbstractArray{<:Any,N}, b::AbstractArray{<:Any,N}) where {N}
|
134 | 141 | a′ = reshape(a, interleave(size(a), ntuple(one, N)))
|
135 | 142 | b′ = reshape(b, interleave(ntuple(one, N), size(b)))
|
@@ -183,6 +190,9 @@ function Base.getindex(a::KroneckerArray, i::Integer)
|
183 | 190 | return a[CartesianIndices(a)[i]]
|
184 | 191 | end
|
185 | 192 |
|
| 193 | +# TODO: Use this logic from KroneckerProducts.jl for cartesian indexing |
| 194 | +# in the n-dimensional case and use it to replace the matrix and vector cases: |
| 195 | +# https://github.com/perrutquist/KroneckerProducts.jl/blob/8c0104caf1f17729eb067259ba1473986121d032/src/KroneckerProducts.jl#L59-L66 |
186 | 196 | function Base.getindex(a::KroneckerArray{<:Any,N}, I::Vararg{Integer,N}) where {N}
|
187 | 197 | return error("Not implemented.")
|
188 | 198 | end
|
|
222 | 232 | function Base.inv(a::KroneckerArray)
|
223 | 233 | return inv(a.a) ⊗ inv(a.b)
|
224 | 234 | end
|
| 235 | +using LinearAlgebra: LinearAlgebra, pinv |
| 236 | +function LinearAlgebra.pinv(a::KroneckerArray; kwargs...) |
| 237 | + return pinv(a.a; kwargs...) ⊗ pinv(a.b; kwargs...) |
| 238 | +end |
225 | 239 | function Base.transpose(a::KroneckerArray)
|
226 | 240 | return transpose(a.a) ⊗ transpose(a.b)
|
227 | 241 | end
|
@@ -297,6 +311,7 @@ using LinearAlgebra:
|
297 | 311 | Diagonal,
|
298 | 312 | Eigen,
|
299 | 313 | SVD,
|
| 314 | + det, |
300 | 315 | diag,
|
301 | 316 | eigen,
|
302 | 317 | eigvals,
|
|
335 | 350 | function LinearAlgebra.norm(a::KroneckerArray, p::Int=2)
|
336 | 351 | return norm(a.a, p) ⊗ norm(a.b, p)
|
337 | 352 | end
|
| 353 | + |
| 354 | +using MatrixAlgebraKit: MatrixAlgebraKit, diagview |
| 355 | +function MatrixAlgebraKit.diagview(a::KroneckerMatrix) |
| 356 | + return diagview(a.a) ⊗ diagview(a.b) |
| 357 | +end |
338 | 358 | function LinearAlgebra.diag(a::KroneckerArray)
|
339 |
| - return diag(a.a) ⊗ diag(a.b) |
| 359 | + return copy(diagview(a.a)) ⊗ copy(diagview(a.b)) |
| 360 | +end |
| 361 | + |
| 362 | +# Matrix functions |
| 363 | +const MATRIX_FUNCTIONS = [ |
| 364 | + :exp, |
| 365 | + :cis, |
| 366 | + :log, |
| 367 | + :sqrt, |
| 368 | + :cbrt, |
| 369 | + :cos, |
| 370 | + :sin, |
| 371 | + :tan, |
| 372 | + :csc, |
| 373 | + :sec, |
| 374 | + :cot, |
| 375 | + :cosh, |
| 376 | + :sinh, |
| 377 | + :tanh, |
| 378 | + :csch, |
| 379 | + :sech, |
| 380 | + :coth, |
| 381 | + :acos, |
| 382 | + :asin, |
| 383 | + :atan, |
| 384 | + :acsc, |
| 385 | + :asec, |
| 386 | + :acot, |
| 387 | + :acosh, |
| 388 | + :asinh, |
| 389 | + :atanh, |
| 390 | + :acsch, |
| 391 | + :asech, |
| 392 | + :acoth, |
| 393 | +] |
| 394 | + |
| 395 | +for f in MATRIX_FUNCTIONS |
| 396 | + @eval begin |
| 397 | + function Base.$f(a::KroneckerArray) |
| 398 | + return throw(ArgumentError("Generic KroneckerArray `$($f)` is not supported.")) |
| 399 | + end |
| 400 | + end |
| 401 | +end |
| 402 | + |
| 403 | +using LinearAlgebra: checksquare |
| 404 | +function LinearAlgebra.det(a::KroneckerArray) |
| 405 | + checksquare(a.a) |
| 406 | + checksquare(a.b) |
| 407 | + return det(a.a) ^ size(a.b, 1) * det(a.b) ^ size(a.a, 1) |
340 | 408 | end
|
| 409 | + |
341 | 410 | function LinearAlgebra.svd(a::KroneckerArray)
|
342 | 411 | Fa = svd(a.a)
|
343 | 412 | Fb = svd(a.b)
|
@@ -690,18 +759,6 @@ for f in [:eig_vals!, :eigh_vals!, :svd_vals!]
|
690 | 759 | end
|
691 | 760 | end
|
692 | 761 |
|
693 |
| -for f in [:eig_trunc!, :eigh_trunc!, :svd_trunc!] |
694 |
| - @eval begin |
695 |
| - function MatrixAlgebraKit.truncate!( |
696 |
| - ::typeof($f), |
697 |
| - (D, V)::Tuple{KroneckerMatrix,KroneckerMatrix}, |
698 |
| - strategy::TruncationStrategy, |
699 |
| - ) |
700 |
| - return throw(MethodError(truncate!, ($f, (D, V), strategy))) |
701 |
| - end |
702 |
| - end |
703 |
| -end |
704 |
| - |
705 | 762 | for f in [:left_orth!, :right_orth!]
|
706 | 763 | @eval begin
|
707 | 764 | function MatrixAlgebraKit.initialize_output(::typeof($f), a::KroneckerMatrix)
|
@@ -941,4 +998,110 @@ for f in [:eig_vals!, :eigh_vals!, :svd_vals!]
|
941 | 998 | end
|
942 | 999 | end
|
943 | 1000 |
|
| 1001 | +using MatrixAlgebraKit: TruncationStrategy, diagview, findtruncated, truncate! |
| 1002 | + |
| 1003 | +struct KroneckerTruncationStrategy{T<:TruncationStrategy} <: TruncationStrategy |
| 1004 | + strategy::T |
| 1005 | +end |
| 1006 | + |
| 1007 | +# Avoid instantiating the identity. |
| 1008 | +function Base.getindex(a::SquareEyeKronecker, I::Vararg{CartesianProduct{Colon},2}) |
| 1009 | + return a.a ⊗ a.b[I[1].b, I[2].b] |
| 1010 | +end |
| 1011 | +function Base.getindex(a::KroneckerSquareEye, I::Vararg{CartesianProduct{<:Any,Colon},2}) |
| 1012 | + return a.a[I[1].a, I[2].a] ⊗ a.b |
| 1013 | +end |
| 1014 | +function Base.getindex(a::SquareEyeSquareEye, I::Vararg{CartesianProduct{Colon,Colon},2}) |
| 1015 | + return a |
| 1016 | +end |
| 1017 | + |
| 1018 | +using FillArrays: OnesVector |
| 1019 | +const OnesKroneckerVector{T,A<:OnesVector{T},B<:AbstractVector{T}} = KroneckerVector{T,A,B} |
| 1020 | +const KroneckerOnesVector{T,A<:AbstractVector{T},B<:OnesVector{T}} = KroneckerVector{T,A,B} |
| 1021 | +const OnesVectorOnesVector{T,A<:OnesVector{T},B<:OnesVector{T}} = KroneckerVector{T,A,B} |
| 1022 | + |
| 1023 | +function MatrixAlgebraKit.findtruncated( |
| 1024 | + values::OnesKroneckerVector, strategy::KroneckerTruncationStrategy |
| 1025 | +) |
| 1026 | + I = findtruncated(Vector(values), strategy.strategy) |
| 1027 | + prods = collect(only(axes(values)).product)[I] |
| 1028 | + I_data = unique(map(x -> x.a, prods)) |
| 1029 | + # Drop truncations that occur within the identity. |
| 1030 | + I_data = filter(I_data) do i |
| 1031 | + return count(x -> x.a == i, prods) == length(values.a) |
| 1032 | + end |
| 1033 | + return (:) × I_data |
| 1034 | +end |
| 1035 | +function MatrixAlgebraKit.findtruncated( |
| 1036 | + values::KroneckerOnesVector, strategy::KroneckerTruncationStrategy |
| 1037 | +) |
| 1038 | + I = findtruncated(Vector(values), strategy.strategy) |
| 1039 | + prods = collect(only(axes(values)).product)[I] |
| 1040 | + I_data = unique(map(x -> x.b, prods)) |
| 1041 | + # Drop truncations that occur within the identity. |
| 1042 | + I_data = filter(I_data) do i |
| 1043 | + return count(x -> x.b == i, prods) == length(values.b) |
| 1044 | + end |
| 1045 | + return I_data × (:) |
| 1046 | +end |
| 1047 | +function MatrixAlgebraKit.findtruncated( |
| 1048 | + values::OnesVectorOnesVector, strategy::KroneckerTruncationStrategy |
| 1049 | +) |
| 1050 | + return throw(ArgumentError("Can't truncate Eye ⊗ Eye.")) |
| 1051 | +end |
| 1052 | + |
| 1053 | +for f in [:eig_trunc!, :eigh_trunc!] |
| 1054 | + @eval begin |
| 1055 | + function MatrixAlgebraKit.truncate!( |
| 1056 | + ::typeof($f), DV::NTuple{2,KroneckerMatrix}, strategy::TruncationStrategy |
| 1057 | + ) |
| 1058 | + return truncate!($f, DV, KroneckerTruncationStrategy(strategy)) |
| 1059 | + end |
| 1060 | + function MatrixAlgebraKit.truncate!( |
| 1061 | + ::typeof($f), (D, V)::NTuple{2,KroneckerMatrix}, strategy::KroneckerTruncationStrategy |
| 1062 | + ) |
| 1063 | + I = findtruncated(diagview(D), strategy) |
| 1064 | + return (D[I, I], V[(:) × (:), I]) |
| 1065 | + end |
| 1066 | + end |
| 1067 | +end |
| 1068 | + |
| 1069 | +function MatrixAlgebraKit.truncate!( |
| 1070 | + f::typeof(svd_trunc!), USVᴴ::NTuple{3,KroneckerMatrix}, strategy::TruncationStrategy |
| 1071 | +) |
| 1072 | + return truncate!(f, USVᴴ, KroneckerTruncationStrategy(strategy)) |
| 1073 | +end |
| 1074 | +function MatrixAlgebraKit.truncate!( |
| 1075 | + ::typeof(svd_trunc!), |
| 1076 | + (U, S, Vᴴ)::NTuple{3,KroneckerMatrix}, |
| 1077 | + strategy::KroneckerTruncationStrategy, |
| 1078 | +) |
| 1079 | + I = findtruncated(diagview(S), strategy) |
| 1080 | + return (U[(:) × (:), I], S[I, I], Vᴴ[I, (:) × (:)]) |
| 1081 | +end |
| 1082 | + |
| 1083 | +for f in MATRIX_FUNCTIONS |
| 1084 | + @eval begin |
| 1085 | + function Base.$f(a::SquareEyeKronecker) |
| 1086 | + return a.a ⊗ $f(a.b) |
| 1087 | + end |
| 1088 | + function Base.$f(a::KroneckerSquareEye) |
| 1089 | + return $f(a.a) ⊗ a.b |
| 1090 | + end |
| 1091 | + function Base.$f(a::SquareEyeSquareEye) |
| 1092 | + return throw(ArgumentError("`$($f)` on `Eye ⊗ Eye` is not supported.")) |
| 1093 | + end |
| 1094 | + end |
| 1095 | +end |
| 1096 | + |
| 1097 | +function LinearAlgebra.pinv(a::SquareEyeKronecker; kwargs...) |
| 1098 | + return a.a ⊗ pinv(a.b; kwargs...) |
| 1099 | +end |
| 1100 | +function LinearAlgebra.pinv(a::KroneckerSquareEye; kwargs...) |
| 1101 | + return pinv(a.a; kwargs...) ⊗ a.b |
| 1102 | +end |
| 1103 | +function LinearAlgebra.pinv(a::SquareEyeSquareEye; kwargs...) |
| 1104 | + return a |
| 1105 | +end |
| 1106 | + |
944 | 1107 | end
|
0 commit comments