|
158 | 158 |
|
159 | 159 | arguments(a::KroneckerArray) = (a.a, a.b)
|
160 | 160 | arguments(a::KroneckerArray, n::Int) = arguments(a)[n]
|
| 161 | +argument_types(a::KroneckerArray) = argument_types(typeof(a)) |
| 162 | +argument_types(::Type{<:KroneckerArray{<:Any,<:Any,A,B}}) where {A,B} = (A, B) |
161 | 163 |
|
162 | 164 | function Base.print_array(io::IO, a::KroneckerArray)
|
163 | 165 | Base.print_array(io, a.a)
|
@@ -234,6 +236,62 @@ function Base.:*(a::KroneckerArray, b::Number)
|
234 | 236 | return a.a ⊗ (a.b * b)
|
235 | 237 | end
|
236 | 238 |
|
| 239 | +function Base.:-(a::KroneckerArray) |
| 240 | + return (-a.a) ⊗ a.b |
| 241 | +end |
| 242 | +for op in (:+, :-) |
| 243 | + @eval begin |
| 244 | + function Base.$op(a::KroneckerArray, b::KroneckerArray) |
| 245 | + if a.b == b.b |
| 246 | + return $op(a.a, b.a) ⊗ a.b |
| 247 | + elseif a.a == b.a |
| 248 | + return a.a ⊗ $op(a.b, b.b) |
| 249 | + end |
| 250 | + return throw( |
| 251 | + ArgumentError( |
| 252 | + "KroneckerArray addition is only supported when the first or secord arguments match.", |
| 253 | + ), |
| 254 | + ) |
| 255 | + end |
| 256 | + end |
| 257 | +end |
| 258 | + |
| 259 | +function Base.map!(::typeof(identity), dest::KroneckerArray, a::KroneckerArray) |
| 260 | + dest.a .= a.a |
| 261 | + dest.b .= a.b |
| 262 | + return dest |
| 263 | +end |
| 264 | +function Base.map!(::typeof(+), dest::KroneckerArray, a::KroneckerArray, b::KroneckerArray) |
| 265 | + if a.b == b.b |
| 266 | + map!(+, dest.a, a.a, b.a) |
| 267 | + dest.b .= a.b |
| 268 | + elseif a.a == b.a |
| 269 | + dest.a .= a.a |
| 270 | + map!(+, dest.b, a.b, b.b) |
| 271 | + else |
| 272 | + throw( |
| 273 | + ArgumentError( |
| 274 | + "KroneckerArray addition is only supported when the first or second arguments match.", |
| 275 | + ), |
| 276 | + ) |
| 277 | + end |
| 278 | + return dest |
| 279 | +end |
| 280 | +function Base.map!( |
| 281 | + f::Base.Fix1{typeof(*),<:Number}, dest::KroneckerArray, a::KroneckerArray |
| 282 | +) |
| 283 | + dest.a .= f.f.(f.x, a.a) |
| 284 | + dest.b .= a.b |
| 285 | + return dest |
| 286 | +end |
| 287 | +function Base.map!( |
| 288 | + f::Base.Fix2{typeof(*),<:Number}, dest::KroneckerArray, a::KroneckerArray |
| 289 | +) |
| 290 | + dest.a .= a.a |
| 291 | + dest.b .= f.f.(a.b, f.x) |
| 292 | + return dest |
| 293 | +end |
| 294 | + |
237 | 295 | using LinearAlgebra:
|
238 | 296 | LinearAlgebra,
|
239 | 297 | Diagonal,
|
@@ -346,67 +404,138 @@ function LinearAlgebra.lq(a::KroneckerArray)
|
346 | 404 | return KroneckerLQ(Fa.L ⊗ Fb.L, Fa.Q ⊗ Fb.Q)
|
347 | 405 | end
|
348 | 406 |
|
349 |
| -function Base.:-(a::KroneckerArray) |
| 407 | +using DerivableInterfaces: DerivableInterfaces, zero! |
| 408 | +function DerivableInterfaces.zero!(a::KroneckerArray) |
| 409 | + zero!(a.a) |
| 410 | + zero!(a.b) |
| 411 | + return a |
| 412 | +end |
| 413 | + |
| 414 | +using FillArrays: Eye |
| 415 | +const EyeKronecker{T,A<:Eye{T},B<:AbstractMatrix{T}} = KroneckerMatrix{T,A,B} |
| 416 | +const KroneckerEye{T,A<:AbstractMatrix{T},B<:Eye{T}} = KroneckerMatrix{T,A,B} |
| 417 | +const EyeEye{T,A<:Eye{T},B<:Eye{T}} = KroneckerMatrix{T,A,B} |
| 418 | + |
| 419 | +function Base.:*(a::Number, b::EyeKronecker) |
| 420 | + return b.a ⊗ (a * b.b) |
| 421 | +end |
| 422 | +function Base.:*(a::Number, b::KroneckerEye) |
| 423 | + return (a * b.a) ⊗ b.b |
| 424 | +end |
| 425 | +function Base.:*(a::Number, b::EyeEye) |
| 426 | + return (a * b.a) ⊗ b.b |
| 427 | +end |
| 428 | +function Base.:*(a::EyeKronecker, b::Number) |
| 429 | + return a.a ⊗ (a.b * b) |
| 430 | +end |
| 431 | +function Base.:*(a::KroneckerEye, b::Number) |
| 432 | + return (a.a * b) ⊗ a.b |
| 433 | +end |
| 434 | +function Base.:*(a::EyeEye, b::Number) |
| 435 | + return a.a ⊗ (a.b * b) |
| 436 | +end |
| 437 | + |
| 438 | +function Base.:-(a::EyeKronecker) |
| 439 | + return a.a ⊗ (-a.b) |
| 440 | +end |
| 441 | +function Base.:-(a::KroneckerEye) |
| 442 | + return (-a.a) ⊗ a.b |
| 443 | +end |
| 444 | +function Base.:-(a::EyeEye) |
350 | 445 | return (-a.a) ⊗ a.b
|
351 | 446 | end
|
352 | 447 | for op in (:+, :-)
|
353 | 448 | @eval begin
|
354 |
| - function Base.$op(a::KroneckerArray, b::KroneckerArray) |
355 |
| - if a.b == b.b |
356 |
| - return $op(a.a, b.a) ⊗ a.b |
357 |
| - elseif a.a == b.a |
358 |
| - return a.a ⊗ $op(a.b, b.b) |
| 449 | + function Base.$op(a::EyeKronecker, b::EyeKronecker) |
| 450 | + if a.a ≠ b.a |
| 451 | + return throw( |
| 452 | + ArgumentError( |
| 453 | + "KroneckerArray addition is only supported when the first or secord arguments match.", |
| 454 | + ), |
| 455 | + ) |
359 | 456 | end
|
360 |
| - return throw( |
361 |
| - ArgumentError( |
362 |
| - "KroneckerArray addition is only supported when the first or secord arguments match.", |
363 |
| - ), |
364 |
| - ) |
| 457 | + return a.a ⊗ $op(a.b, b.b) |
| 458 | + end |
| 459 | + function Base.$op(a::KroneckerEye, b::KroneckerEye) |
| 460 | + if a.b ≠ b.b |
| 461 | + return throw( |
| 462 | + ArgumentError( |
| 463 | + "KroneckerArray addition is only supported when the first or secord arguments match.", |
| 464 | + ), |
| 465 | + ) |
| 466 | + end |
| 467 | + return $op(a.a, b.a) ⊗ a.b |
| 468 | + end |
| 469 | + function Base.$op(a::EyeEye, b::EyeEye) |
| 470 | + if a.b ≠ b.b |
| 471 | + return throw( |
| 472 | + ArgumentError( |
| 473 | + "KroneckerArray addition is only supported when the first or secord arguments match.", |
| 474 | + ), |
| 475 | + ) |
| 476 | + end |
| 477 | + return $op(a.a, b.a) ⊗ a.b |
365 | 478 | end
|
366 | 479 | end
|
367 | 480 | end
|
368 | 481 |
|
369 |
| -function Base.map!(::typeof(identity), dest::KroneckerArray, a::KroneckerArray) |
370 |
| - dest.a .= a.a |
| 482 | +function Base.map!(::typeof(identity), dest::EyeKronecker, a::EyeKronecker) |
371 | 483 | dest.b .= a.b
|
372 | 484 | return dest
|
373 | 485 | end
|
374 |
| -function Base.map!(::typeof(+), dest::KroneckerArray, a::KroneckerArray, b::KroneckerArray) |
375 |
| - if a.b == b.b |
376 |
| - map!(+, dest.a, a.a, b.a) |
377 |
| - dest.b .= a.b |
378 |
| - elseif a.a == b.a |
379 |
| - dest.a .= a.a |
380 |
| - map!(+, dest.b, a.b, b.b) |
381 |
| - else |
| 486 | +function Base.map!(::typeof(identity), dest::KroneckerEye, a::KroneckerEye) |
| 487 | + dest.a .= a.a |
| 488 | + return dest |
| 489 | +end |
| 490 | +function Base.map!(::typeof(identity), dest::EyeEye, a::EyeEye) |
| 491 | + return error("Can't write in-place.") |
| 492 | +end |
| 493 | +function Base.map!(f::typeof(+), dest::EyeKronecker, a::EyeKronecker, b::EyeKronecker) |
| 494 | + if dest.a ≠ a.a ≠ b.a |
382 | 495 | throw(
|
383 | 496 | ArgumentError(
|
384 | 497 | "KroneckerArray addition is only supported when the first or second arguments match.",
|
385 | 498 | ),
|
386 | 499 | )
|
387 | 500 | end
|
| 501 | + map!(f, dest.b, a.b, b.b) |
388 | 502 | return dest
|
389 | 503 | end
|
390 |
| -function Base.map!( |
391 |
| - f::Base.Fix1{typeof(*),<:Number}, dest::KroneckerArray, a::KroneckerArray |
392 |
| -) |
393 |
| - dest.a .= f.x .* a.a |
394 |
| - dest.b .= a.b |
| 504 | +function Base.map!(f::typeof(+), dest::KroneckerEye, a::KroneckerEye, b::KroneckerEye) |
| 505 | + if dest.b ≠ a.b ≠ b.b |
| 506 | + throw( |
| 507 | + ArgumentError( |
| 508 | + "KroneckerArray addition is only supported when the first or second arguments match.", |
| 509 | + ), |
| 510 | + ) |
| 511 | + end |
| 512 | + map!(f, dest.a, a.a, b.a) |
395 | 513 | return dest
|
396 | 514 | end
|
397 |
| -function Base.map!( |
398 |
| - f::Base.Fix2{typeof(*),<:Number}, dest::KroneckerArray, a::KroneckerArray |
399 |
| -) |
400 |
| - dest.a .= a.a |
401 |
| - dest.b .= a.b .* f.x |
| 515 | +function Base.map!(f::typeof(+), dest::EyeEye, a::EyeEye, b::EyeEye) |
| 516 | + return error("Can't write in-place.") |
| 517 | +end |
| 518 | +function Base.map!(f::Base.Fix1{typeof(*),<:Number}, dest::EyeKronecker, a::EyeKronecker) |
| 519 | + dest.b .= f.f.(f.x, a.b) |
402 | 520 | return dest
|
403 | 521 | end
|
404 |
| - |
405 |
| -using DerivableInterfaces: DerivableInterfaces, zero! |
406 |
| -function DerivableInterfaces.zero!(a::KroneckerArray) |
407 |
| - zero!(a.a) |
408 |
| - zero!(a.b) |
409 |
| - return a |
| 522 | +function Base.map!(f::Base.Fix1{typeof(*),<:Number}, dest::KroneckerEye, a::KroneckerEye) |
| 523 | + dest.a .= f.f.(f.x, a.a) |
| 524 | + return dest |
| 525 | +end |
| 526 | +function Base.map!(f::Base.Fix1{typeof(*),<:Number}, dest::EyeEye, a::EyeEye) |
| 527 | + return error("Can't write in-place.") |
| 528 | +end |
| 529 | +function Base.map!(f::Base.Fix2{typeof(*),<:Number}, dest::EyeKronecker, a::EyeKronecker) |
| 530 | + dest.b .= f.f.(a.b, f.x) |
| 531 | + return dest |
| 532 | +end |
| 533 | +function Base.map!(f::Base.Fix2{typeof(*),<:Number}, dest::KroneckerEye, a::KroneckerEye) |
| 534 | + dest.a .= f.f.(a.a, f.x) |
| 535 | + return dest |
| 536 | +end |
| 537 | +function Base.map!(f::Base.Fix2{typeof(*),<:Number}, dest::EyeEye, a::EyeEye) |
| 538 | + return error("Can't write in-place.") |
410 | 539 | end
|
411 | 540 |
|
412 | 541 | using MatrixAlgebraKit:
|
@@ -447,15 +576,61 @@ struct KroneckerAlgorithm{A,B} <: AbstractAlgorithm
|
447 | 576 | b::B
|
448 | 577 | end
|
449 | 578 |
|
| 579 | +using MatrixAlgebraKit: |
| 580 | + copy_input, |
| 581 | + eig_full, |
| 582 | + eigh_full, |
| 583 | + qr_compact, |
| 584 | + qr_full, |
| 585 | + left_polar, |
| 586 | + lq_compact, |
| 587 | + lq_full, |
| 588 | + right_polar, |
| 589 | + svd_compact, |
| 590 | + svd_full |
| 591 | + |
| 592 | +for f in [ |
| 593 | + :eig_full, |
| 594 | + :eigh_full, |
| 595 | + :qr_compact, |
| 596 | + :qr_full, |
| 597 | + :left_polar, |
| 598 | + :lq_compact, |
| 599 | + :lq_full, |
| 600 | + :right_polar, |
| 601 | + :svd_compact, |
| 602 | + :svd_full, |
| 603 | +] |
| 604 | + @eval begin |
| 605 | + function MatrixAlgebraKit.copy_input(::typeof($f), a::KroneckerMatrix) |
| 606 | + return copy_input($f, a.a) ⊗ copy_input($f, a.b) |
| 607 | + end |
| 608 | + end |
| 609 | +end |
| 610 | + |
450 | 611 | for f in (:eig, :eigh, :lq, :qr, :polar, :svd)
|
451 | 612 | ff = Symbol("default_", f, "_algorithm")
|
452 | 613 | @eval begin
|
453 |
| - function MatrixAlgebraKit.$ff(a::KroneckerMatrix; kwargs...) |
454 |
| - return KroneckerAlgorithm($ff(a.a; kwargs...), $ff(a.b; kwargs...)) |
| 614 | + function MatrixAlgebraKit.$ff(A::Type{<:KroneckerMatrix}; kwargs...) |
| 615 | + A1, A2 = argument_types(A) |
| 616 | + return KroneckerAlgorithm($ff(A1; kwargs...), $ff(A2; kwargs...)) |
455 | 617 | end
|
456 | 618 | end
|
457 | 619 | end
|
458 | 620 |
|
| 621 | +# TODO: Delete this once https://github.com/QuantumKitHub/MatrixAlgebraKit.jl/pull/32 is merged. |
| 622 | +function MatrixAlgebraKit.default_algorithm( |
| 623 | + ::typeof(qr_compact!), A::Type{<:KroneckerMatrix}; kwargs... |
| 624 | +) |
| 625 | + return default_qr_algorithm(A; kwargs...) |
| 626 | +end |
| 627 | +# TODO: Delete this once https://github.com/QuantumKitHub/MatrixAlgebraKit.jl/pull/32 is merged. |
| 628 | +function MatrixAlgebraKit.default_algorithm( |
| 629 | + ::typeof(qr_full!), A::Type{<:KroneckerMatrix}; kwargs... |
| 630 | +) |
| 631 | + return default_qr_algorithm(A; kwargs...) |
| 632 | +end |
| 633 | + |
459 | 634 | for f in (
|
460 | 635 | :eig_full!,
|
461 | 636 | :eigh_full!,
|
|
0 commit comments