|
1 | 1 | module KroneckerArrays
|
2 | 2 |
|
| 3 | +using GPUArraysCore: GPUArraysCore |
| 4 | + |
3 | 5 | export ⊗, ×
|
4 | 6 |
|
5 | 7 | struct CartesianProduct{A,B}
|
|
28 | 30 | Base.first(r::CartesianProductUnitRange) = first(r.range)
|
29 | 31 | Base.last(r::CartesianProductUnitRange) = last(r.range)
|
30 | 32 |
|
| 33 | +function Base.axes(r::CartesianProductUnitRange) |
| 34 | + return (CartesianProductUnitRange(r.product, only(axes(r.range))),) |
| 35 | +end |
| 36 | + |
| 37 | +using Base.Broadcast: DefaultArrayStyle |
| 38 | +for f in (:+, :-) |
| 39 | + @eval begin |
| 40 | + function Broadcast.broadcasted( |
| 41 | + ::DefaultArrayStyle{1}, ::typeof($f), r::CartesianProductUnitRange, x::Integer |
| 42 | + ) |
| 43 | + return CartesianProductUnitRange(r.product, $f.(r.range, x)) |
| 44 | + end |
| 45 | + function Broadcast.broadcasted( |
| 46 | + ::DefaultArrayStyle{1}, ::typeof($f), x::Integer, r::CartesianProductUnitRange |
| 47 | + ) |
| 48 | + return CartesianProductUnitRange(r.product, $f.(x, r.range)) |
| 49 | + end |
| 50 | + end |
| 51 | +end |
| 52 | + |
31 | 53 | struct KroneckerArray{T,N,A<:AbstractArray{T,N},B<:AbstractArray{T,N}} <: AbstractArray{T,N}
|
32 | 54 | a::A
|
33 | 55 | b::B
|
|
44 | 66 | const KroneckerMatrix{T,A<:AbstractMatrix{T},B<:AbstractMatrix{T}} = KroneckerArray{T,2,A,B}
|
45 | 67 | const KroneckerVector{T,A<:AbstractVector{T},B<:AbstractVector{T}} = KroneckerArray{T,1,A,B}
|
46 | 68 |
|
| 69 | +function Base.copy(a::KroneckerArray) |
| 70 | + return copy(a.a) ⊗ copy(a.b) |
| 71 | +end |
| 72 | +function Base.copyto!(dest::KroneckerArray, src::KroneckerArray) |
| 73 | + copyto!(dest.a, src.a) |
| 74 | + copyto!(dest.b, src.b) |
| 75 | + return dest |
| 76 | +end |
| 77 | + |
47 | 78 | function Base.similar(
|
48 | 79 | a::AbstractArray,
|
49 | 80 | elt::Type,
|
@@ -73,9 +104,21 @@ function Base.similar(
|
73 | 104 | return similar(arrayt, map(ax -> ax.product.a, axs)) ⊗
|
74 | 105 | similar(arrayt, map(ax -> ax.product.b, axs))
|
75 | 106 | end
|
| 107 | +function Base.similar( |
| 108 | + arrayt::Type{<:KroneckerArray{<:Any,<:Any,A,B}}, |
| 109 | + axs::Tuple{ |
| 110 | + CartesianProductUnitRange{<:Integer},Vararg{CartesianProductUnitRange{<:Integer}} |
| 111 | + }, |
| 112 | +) where {A,B} |
| 113 | + return similar(A, map(ax -> ax.product.a, axs)) ⊗ similar(B, map(ax -> ax.product.b, axs)) |
| 114 | +end |
76 | 115 |
|
77 | 116 | Base.collect(a::KroneckerArray) = kron(a.a, a.b)
|
78 | 117 |
|
| 118 | +function Base.Array{T,N}(a::KroneckerArray{S,N}) where {T,S,N} |
| 119 | + return convert(Array{T,N}, collect(a)) |
| 120 | +end |
| 121 | + |
79 | 122 | Base.size(a::KroneckerArray) = ntuple(dim -> size(a.a, dim) * size(a.b, dim), ndims(a))
|
80 | 123 |
|
81 | 124 | function Base.axes(a::KroneckerArray)
|
@@ -107,12 +150,23 @@ end
|
107 | 150 | ⊗(a::Number, b::AbstractVecOrMat) = a * b
|
108 | 151 | ⊗(a::AbstractVecOrMat, b::Number) = a * b
|
109 | 152 |
|
110 |
| -function Base.getindex(::KroneckerArray, ::Int) |
111 |
| - return throw(ArgumentError("Scalar indexing of KroneckerArray is not supported.")) |
| 153 | +function Base.getindex(a::KroneckerMatrix, i1::Integer, i2::Integer) |
| 154 | + GPUArraysCore.assertscalar("getindex") |
| 155 | + # Code logic from Kronecker.jl: |
| 156 | + # https://github.com/MichielStock/Kronecker.jl/blob/v0.5.5/src/base.jl#L101-L105 |
| 157 | + k, l = size(a.b) |
| 158 | + return a.a[cld(i1, k), cld(i2, l)] * a.b[(i1 - 1) % k + 1, (i2 - 1) % l + 1] |
| 159 | +end |
| 160 | +function Base.getindex(a::KroneckerMatrix, i::Integer) |
| 161 | + return a[CartesianIndices(a)[i]] |
112 | 162 | end
|
113 |
| -function Base.getindex(::KroneckerArray{<:Any,N}, ::Vararg{Int,N}) where {N} |
114 |
| - return throw(ArgumentError("Scalar indexing of KroneckerArray is not supported.")) |
| 163 | + |
| 164 | +function Base.getindex(a::KroneckerVector, i::Integer) |
| 165 | + GPUArraysCore.assertscalar("getindex") |
| 166 | + k = length(a.b) |
| 167 | + return a.a[cld(i, k)] * a.b[(i - 1) % k + 1] |
115 | 168 | end
|
| 169 | + |
116 | 170 | function Base.getindex(a::KroneckerVector, i::CartesianProduct)
|
117 | 171 | return a.a[i.a] ⊗ a.b[i.b]
|
118 | 172 | end
|
|
169 | 223 | function Base.:*(a::KroneckerArray, b::KroneckerArray)
|
170 | 224 | return (a.a * b.a) ⊗ (a.b * b.b)
|
171 | 225 | end
|
172 |
| -function LinearAlgebra.mul!(c::KroneckerArray, a::KroneckerArray, b::KroneckerArray) |
| 226 | +function LinearAlgebra.mul!( |
| 227 | + c::KroneckerArray, a::KroneckerArray, b::KroneckerArray, α::Number, β::Number |
| 228 | +) |
| 229 | + iszero(β) || |
| 230 | + iszero(c) || |
| 231 | + throw( |
| 232 | + ArgumentError( |
| 233 | + "Can't multiple KroneckerArrays with nonzero β and nonzero destination." |
| 234 | + ), |
| 235 | + ) |
173 | 236 | mul!(c.a, a.a, b.a)
|
174 |
| - mul!(c.b, a.b, b.b) |
| 237 | + mul!(c.b, a.b, b.b, α, β) |
175 | 238 | return c
|
176 | 239 | end
|
177 | 240 | function LinearAlgebra.tr(a::KroneckerArray)
|
@@ -269,4 +332,168 @@ for op in (:+, :-)
|
269 | 332 | end
|
270 | 333 | end
|
271 | 334 |
|
| 335 | +function Base.map!(::typeof(identity), dest::KroneckerArray, a::KroneckerArray) |
| 336 | + dest.a .= a.a |
| 337 | + dest.b .= a.b |
| 338 | + return dest |
| 339 | +end |
| 340 | +function Base.map!(::typeof(+), dest::KroneckerArray, a::KroneckerArray, b::KroneckerArray) |
| 341 | + if a.b == b.b |
| 342 | + map!(+, dest.a, a.a, b.a) |
| 343 | + dest.b .= a.b |
| 344 | + elseif a.a == b.a |
| 345 | + dest.a .= a.a |
| 346 | + map!(+, dest.b, a.b, b.b) |
| 347 | + else |
| 348 | + throw( |
| 349 | + ArgumentError( |
| 350 | + "KroneckerArray addition is only supported when the first or second arguments match.", |
| 351 | + ), |
| 352 | + ) |
| 353 | + end |
| 354 | + return dest |
| 355 | +end |
| 356 | +function Base.map!( |
| 357 | + f::Base.Fix1{typeof(*),<:Number}, dest::KroneckerArray, a::KroneckerArray |
| 358 | +) |
| 359 | + dest.a .= f.x .* a.a |
| 360 | + dest.b .= a.b |
| 361 | + return dest |
| 362 | +end |
| 363 | +function Base.map!( |
| 364 | + f::Base.Fix2{typeof(*),<:Number}, dest::KroneckerArray, a::KroneckerArray |
| 365 | +) |
| 366 | + dest.a .= a.a |
| 367 | + dest.b .= a.b .* f.x |
| 368 | + return dest |
| 369 | +end |
| 370 | + |
| 371 | +using DerivableInterfaces: DerivableInterfaces, zero! |
| 372 | +function DerivableInterfaces.zero!(a::KroneckerArray) |
| 373 | + zero!(a.a) |
| 374 | + zero!(a.b) |
| 375 | + return a |
| 376 | +end |
| 377 | + |
| 378 | +using MatrixAlgebraKit: |
| 379 | + MatrixAlgebraKit, |
| 380 | + AbstractAlgorithm, |
| 381 | + TruncationStrategy, |
| 382 | + default_eig_algorithm, |
| 383 | + default_eigh_algorithm, |
| 384 | + default_lq_algorithm, |
| 385 | + default_polar_algorithm, |
| 386 | + default_qr_algorithm, |
| 387 | + default_svd_algorithm, |
| 388 | + eig_full!, |
| 389 | + eig_trunc!, |
| 390 | + eig_vals!, |
| 391 | + eigh_full!, |
| 392 | + eigh_trunc!, |
| 393 | + eigh_vals!, |
| 394 | + initialize_output, |
| 395 | + left_null!, |
| 396 | + left_orth!, |
| 397 | + left_polar!, |
| 398 | + lq_compact!, |
| 399 | + lq_full!, |
| 400 | + qr_compact!, |
| 401 | + qr_full!, |
| 402 | + right_null!, |
| 403 | + right_orth!, |
| 404 | + right_polar!, |
| 405 | + svd_compact!, |
| 406 | + svd_full!, |
| 407 | + svd_trunc!, |
| 408 | + svd_vals!, |
| 409 | + truncate! |
| 410 | + |
| 411 | +struct KroneckerAlgorithm{A,B} <: AbstractAlgorithm |
| 412 | + a::A |
| 413 | + b::B |
| 414 | +end |
| 415 | + |
| 416 | +for f in (:eig, :eigh, :lq, :qr, :polar, :svd) |
| 417 | + ff = Symbol("default_", f, "_algorithm") |
| 418 | + @eval begin |
| 419 | + function MatrixAlgebraKit.$ff(a::KroneckerMatrix; kwargs...) |
| 420 | + return KroneckerAlgorithm($ff(a.a; kwargs...), $ff(a.b; kwargs...)) |
| 421 | + end |
| 422 | + end |
| 423 | +end |
| 424 | + |
| 425 | +for f in ( |
| 426 | + :eig_full!, |
| 427 | + :eigh_full!, |
| 428 | + :qr_compact!, |
| 429 | + :qr_full!, |
| 430 | + :left_polar!, |
| 431 | + :lq_compact!, |
| 432 | + :lq_full!, |
| 433 | + :right_polar!, |
| 434 | + :svd_compact!, |
| 435 | + :svd_full!, |
| 436 | +) |
| 437 | + @eval begin |
| 438 | + function MatrixAlgebraKit.initialize_output( |
| 439 | + ::typeof($f), a::KroneckerMatrix, alg::KroneckerAlgorithm |
| 440 | + ) |
| 441 | + return initialize_output($f, a.a, alg.a) .⊗ initialize_output($f, a.b, alg.b) |
| 442 | + end |
| 443 | + function MatrixAlgebraKit.$f(a::KroneckerMatrix, F, alg::KroneckerAlgorithm; kwargs...) |
| 444 | + $f(a.a, Base.Fix2(getfield, :a).(F), alg.a; kwargs...) |
| 445 | + $f(a.b, Base.Fix2(getfield, :b).(F), alg.b; kwargs...) |
| 446 | + return F |
| 447 | + end |
| 448 | + end |
| 449 | +end |
| 450 | + |
| 451 | +for f in (:eig_vals!, :eigh_vals!, :svd_vals!) |
| 452 | + @eval begin |
| 453 | + function MatrixAlgebraKit.initialize_output( |
| 454 | + ::typeof($f), a::KroneckerMatrix, alg::KroneckerAlgorithm |
| 455 | + ) |
| 456 | + return initialize_output($f, a.a, alg.a) ⊗ initialize_output($f, a.b, alg.b) |
| 457 | + end |
| 458 | + function MatrixAlgebraKit.$f(a::KroneckerMatrix, F, alg::KroneckerAlgorithm) |
| 459 | + $f(a.a, F.a, alg.a) |
| 460 | + $f(a.b, F.b, alg.b) |
| 461 | + return F |
| 462 | + end |
| 463 | + end |
| 464 | +end |
| 465 | + |
| 466 | +for f in (:eig_trunc!, :eigh_trunc!, :svd_trunc!) |
| 467 | + @eval begin |
| 468 | + function MatrixAlgebraKit.truncate!( |
| 469 | + ::typeof($f), |
| 470 | + (D, V)::Tuple{KroneckerMatrix,KroneckerMatrix}, |
| 471 | + strategy::TruncationStrategy, |
| 472 | + ) |
| 473 | + return throw(MethodError(truncate!, ($f, (D, V), strategy))) |
| 474 | + end |
| 475 | + end |
| 476 | +end |
| 477 | + |
| 478 | +for f in (:left_orth!, :right_orth!) |
| 479 | + @eval begin |
| 480 | + function MatrixAlgebraKit.initialize_output(::typeof($f), a::KroneckerMatrix) |
| 481 | + return initialize_output($f, a.a) .⊗ initialize_output($f, a.b) |
| 482 | + end |
| 483 | + end |
| 484 | +end |
| 485 | + |
| 486 | +for f in (:left_null!, :right_null!) |
| 487 | + @eval begin |
| 488 | + function MatrixAlgebraKit.initialize_output(::typeof($f), a::KroneckerMatrix) |
| 489 | + return initialize_output($f, a.a) ⊗ initialize_output($f, a.b) |
| 490 | + end |
| 491 | + function MatrixAlgebraKit.$f(a::KroneckerMatrix, F; kwargs...) |
| 492 | + $f(a.a, F.a; kwargs...) |
| 493 | + $f(a.b, F.b; kwargs...) |
| 494 | + return F |
| 495 | + end |
| 496 | + end |
| 497 | +end |
| 498 | + |
272 | 499 | end
|
0 commit comments