Skip to content

Commit 8ec022f

Browse files
committed
Special cases for FillArrays.Eye
1 parent d3de597 commit 8ec022f

File tree

4 files changed

+286
-39
lines changed

4 files changed

+286
-39
lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,18 @@
11
name = "KroneckerArrays"
22
uuid = "05d0b138-81bc-4ff7-84be-08becefb1ccc"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.1.3"
4+
version = "0.1.4"
55

66
[deps]
77
DerivableInterfaces = "6c5e35bf-e59e-4898-b73c-732dcc4ba65f"
8+
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
89
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
910
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1011
MatrixAlgebraKit = "6c742aac-3347-4629-af66-fc926824e5e4"
1112

1213
[compat]
1314
DerivableInterfaces = "0.4.5"
15+
FillArrays = "1.13.0"
1416
GPUArraysCore = "0.2.0"
1517
LinearAlgebra = "1.10"
1618
MatrixAlgebraKit = "0.2.0"

src/KroneckerArrays.jl

Lines changed: 267 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,62 @@ function Base.:*(a::KroneckerArray, b::Number)
234234
return a.a (a.b * b)
235235
end
236236

237+
function Base.:-(a::KroneckerArray)
238+
return (-a.a) a.b
239+
end
240+
for op in (:+, :-)
241+
@eval begin
242+
function Base.$op(a::KroneckerArray, b::KroneckerArray)
243+
if a.b == b.b
244+
return $op(a.a, b.a) a.b
245+
elseif a.a == b.a
246+
return a.a $op(a.b, b.b)
247+
end
248+
return throw(
249+
ArgumentError(
250+
"KroneckerArray addition is only supported when the first or secord arguments match.",
251+
),
252+
)
253+
end
254+
end
255+
end
256+
257+
function Base.map!(::typeof(identity), dest::KroneckerArray, a::KroneckerArray)
258+
dest.a .= a.a
259+
dest.b .= a.b
260+
return dest
261+
end
262+
function Base.map!(::typeof(+), dest::KroneckerArray, a::KroneckerArray, b::KroneckerArray)
263+
if a.b == b.b
264+
map!(+, dest.a, a.a, b.a)
265+
dest.b .= a.b
266+
elseif a.a == b.a
267+
dest.a .= a.a
268+
map!(+, dest.b, a.b, b.b)
269+
else
270+
throw(
271+
ArgumentError(
272+
"KroneckerArray addition is only supported when the first or second arguments match.",
273+
),
274+
)
275+
end
276+
return dest
277+
end
278+
function Base.map!(
279+
f::Base.Fix1{typeof(*),<:Number}, dest::KroneckerArray, a::KroneckerArray
280+
)
281+
dest.a .= f.f.(f.x, a.a)
282+
dest.b .= a.b
283+
return dest
284+
end
285+
function Base.map!(
286+
f::Base.Fix2{typeof(*),<:Number}, dest::KroneckerArray, a::KroneckerArray
287+
)
288+
dest.a .= a.a
289+
dest.b .= f.f.(a.b, f.x)
290+
return dest
291+
end
292+
237293
using LinearAlgebra:
238294
LinearAlgebra,
239295
Diagonal,
@@ -346,67 +402,138 @@ function LinearAlgebra.lq(a::KroneckerArray)
346402
return KroneckerLQ(Fa.L Fb.L, Fa.Q Fb.Q)
347403
end
348404

349-
function Base.:-(a::KroneckerArray)
405+
using DerivableInterfaces: DerivableInterfaces, zero!
406+
function DerivableInterfaces.zero!(a::KroneckerArray)
407+
zero!(a.a)
408+
zero!(a.b)
409+
return a
410+
end
411+
412+
using FillArrays: Eye
413+
const EyeKronecker{T,A<:Eye{T},B<:AbstractMatrix{T}} = KroneckerMatrix{T,A,B}
414+
const KroneckerEye{T,A<:AbstractMatrix{T},B<:Eye{T}} = KroneckerMatrix{T,A,B}
415+
const EyeEye{T,A<:Eye{T},B<:Eye{T}} = KroneckerMatrix{T,A,B}
416+
417+
function Base.:*(a::Number, b::EyeKronecker)
418+
return b.a (a * b.b)
419+
end
420+
function Base.:*(a::Number, b::KroneckerEye)
421+
return (a * b.a) b.b
422+
end
423+
function Base.:*(a::Number, b::EyeEye)
424+
return (a * b.a) b.b
425+
end
426+
function Base.:*(a::EyeKronecker, b::Number)
427+
return a.a (a.b * b)
428+
end
429+
function Base.:*(a::KroneckerEye, b::Number)
430+
return (a.a * b) a.b
431+
end
432+
function Base.:*(a::EyeEye, b::Number)
433+
return a.a (a.b * b)
434+
end
435+
436+
function Base.:-(a::EyeKronecker)
437+
return a.a (-a.b)
438+
end
439+
function Base.:-(a::KroneckerEye)
440+
return (-a.a) a.b
441+
end
442+
function Base.:-(a::EyeEye)
350443
return (-a.a) a.b
351444
end
352445
for op in (:+, :-)
353446
@eval begin
354-
function Base.$op(a::KroneckerArray, b::KroneckerArray)
355-
if a.b == b.b
356-
return $op(a.a, b.a) a.b
357-
elseif a.a == b.a
358-
return a.a $op(a.b, b.b)
447+
function Base.$op(a::EyeKronecker, b::EyeKronecker)
448+
if a.a b.a
449+
return throw(
450+
ArgumentError(
451+
"KroneckerArray addition is only supported when the first or secord arguments match.",
452+
),
453+
)
359454
end
360-
return throw(
361-
ArgumentError(
362-
"KroneckerArray addition is only supported when the first or secord arguments match.",
363-
),
364-
)
455+
return a.a $op(a.b, b.b)
456+
end
457+
function Base.$op(a::KroneckerEye, b::KroneckerEye)
458+
if a.b b.b
459+
return throw(
460+
ArgumentError(
461+
"KroneckerArray addition is only supported when the first or secord arguments match.",
462+
),
463+
)
464+
end
465+
return $op(a.a, b.a) a.b
466+
end
467+
function Base.$op(a::EyeEye, b::EyeEye)
468+
if a.b b.b
469+
return throw(
470+
ArgumentError(
471+
"KroneckerArray addition is only supported when the first or secord arguments match.",
472+
),
473+
)
474+
end
475+
return $op(a.a, b.a) a.b
365476
end
366477
end
367478
end
368479

369-
function Base.map!(::typeof(identity), dest::KroneckerArray, a::KroneckerArray)
370-
dest.a .= a.a
480+
function Base.map!(::typeof(identity), dest::EyeKronecker, a::EyeKronecker)
371481
dest.b .= a.b
372482
return dest
373483
end
374-
function Base.map!(::typeof(+), dest::KroneckerArray, a::KroneckerArray, b::KroneckerArray)
375-
if a.b == b.b
376-
map!(+, dest.a, a.a, b.a)
377-
dest.b .= a.b
378-
elseif a.a == b.a
379-
dest.a .= a.a
380-
map!(+, dest.b, a.b, b.b)
381-
else
484+
function Base.map!(::typeof(identity), dest::KroneckerEye, a::KroneckerEye)
485+
dest.a .= a.a
486+
return dest
487+
end
488+
function Base.map!(::typeof(identity), dest::EyeEye, a::EyeEye)
489+
return error("Can't write in-place.")
490+
end
491+
function Base.map!(f::typeof(+), dest::EyeKronecker, a::EyeKronecker, b::EyeKronecker)
492+
if dest.a a.a b.a
382493
throw(
383494
ArgumentError(
384495
"KroneckerArray addition is only supported when the first or second arguments match.",
385496
),
386497
)
387498
end
499+
map!(f, dest.b, a.b, b.b)
388500
return dest
389501
end
390-
function Base.map!(
391-
f::Base.Fix1{typeof(*),<:Number}, dest::KroneckerArray, a::KroneckerArray
392-
)
393-
dest.a .= f.x .* a.a
394-
dest.b .= a.b
502+
function Base.map!(f::typeof(+), dest::KroneckerEye, a::KroneckerEye, b::KroneckerEye)
503+
if dest.b a.b b.b
504+
throw(
505+
ArgumentError(
506+
"KroneckerArray addition is only supported when the first or second arguments match.",
507+
),
508+
)
509+
end
510+
map!(f, dest.a, a.a, b.a)
395511
return dest
396512
end
397-
function Base.map!(
398-
f::Base.Fix2{typeof(*),<:Number}, dest::KroneckerArray, a::KroneckerArray
399-
)
400-
dest.a .= a.a
401-
dest.b .= a.b .* f.x
513+
function Base.map!(f::typeof(+), dest::EyeEye, a::EyeEye, b::EyeEye)
514+
return error("Can't write in-place.")
515+
end
516+
function Base.map!(f::Base.Fix1{typeof(*),<:Number}, dest::EyeKronecker, a::EyeKronecker)
517+
dest.b .= f.f.(f.x, a.b)
402518
return dest
403519
end
404-
405-
using DerivableInterfaces: DerivableInterfaces, zero!
406-
function DerivableInterfaces.zero!(a::KroneckerArray)
407-
zero!(a.a)
408-
zero!(a.b)
409-
return a
520+
function Base.map!(f::Base.Fix1{typeof(*),<:Number}, dest::KroneckerEye, a::KroneckerEye)
521+
dest.a .= f.f.(f.x, a.a)
522+
return dest
523+
end
524+
function Base.map!(f::Base.Fix1{typeof(*),<:Number}, dest::EyeEye, a::EyeEye)
525+
return error("Can't write in-place.")
526+
end
527+
function Base.map!(f::Base.Fix2{typeof(*),<:Number}, dest::EyeKronecker, a::EyeKronecker)
528+
dest.b .= f.f.(a.b, f.x)
529+
return dest
530+
end
531+
function Base.map!(f::Base.Fix2{typeof(*),<:Number}, dest::KroneckerEye, a::KroneckerEye)
532+
dest.a .= f.f.(a.a, f.x)
533+
return dest
534+
end
535+
function Base.map!(f::Base.Fix2{typeof(*),<:Number}, dest::EyeEye, a::EyeEye)
536+
return error("Can't write in-place.")
410537
end
411538

412539
using MatrixAlgebraKit:
@@ -447,6 +574,38 @@ struct KroneckerAlgorithm{A,B} <: AbstractAlgorithm
447574
b::B
448575
end
449576

577+
using MatrixAlgebraKit:
578+
copy_input,
579+
eig_full,
580+
eigh_full,
581+
qr_compact,
582+
qr_full,
583+
left_polar,
584+
lq_compact,
585+
lq_full,
586+
right_polar,
587+
svd_compact,
588+
svd_full
589+
590+
for f in [
591+
:eig_full,
592+
:eigh_full,
593+
:qr_compact,
594+
:qr_full,
595+
:left_polar,
596+
:lq_compact,
597+
:lq_full,
598+
:right_polar,
599+
:svd_compact,
600+
:svd_full,
601+
]
602+
@eval begin
603+
function MatrixAlgebraKit.copy_input(::typeof($f), a::KroneckerMatrix)
604+
return copy_input($f, a.a) copy_input($f, a.b)
605+
end
606+
end
607+
end
608+
450609
for f in (:eig, :eigh, :lq, :qr, :polar, :svd)
451610
ff = Symbol("default_", f, "_algorithm")
452611
@eval begin
@@ -530,4 +689,75 @@ for f in (:left_null!, :right_null!)
530689
end
531690
end
532691

692+
# Special case for `FillArrays.Eye` matrices.
693+
struct EyeAlgorithm <: AbstractAlgorithm end
694+
695+
for f in [
696+
:eig_full,
697+
:eigh_full,
698+
:qr_compact,
699+
:qr_full,
700+
:left_polar,
701+
:lq_compact,
702+
:lq_full,
703+
:right_polar,
704+
:svd_compact,
705+
:svd_full,
706+
]
707+
@eval begin
708+
MatrixAlgebraKit.copy_input(::typeof($f), a::Eye) = a
709+
end
710+
end
711+
712+
for f in (:eig, :eigh, :lq, :qr, :polar, :svd)
713+
ff = Symbol("default_", f, "_algorithm")
714+
@eval begin
715+
function MatrixAlgebraKit.$ff(a::Eye; kwargs...)
716+
return EyeAlgorithm()
717+
end
718+
end
719+
end
720+
721+
for f in (
722+
:eig_full!,
723+
:eigh_full!,
724+
:qr_compact!,
725+
:qr_full!,
726+
:left_polar!,
727+
:lq_compact!,
728+
:lq_full!,
729+
:right_polar!,
730+
)
731+
@eval begin
732+
nfactors(::typeof($f)) = 2
733+
end
734+
end
735+
for f in (:svd_compact!, :svd_full!)
736+
@eval begin
737+
nfactors(::typeof($f)) = 3
738+
end
739+
end
740+
741+
for f in (
742+
:eig_full!,
743+
:eigh_full!,
744+
:qr_compact!,
745+
:qr_full!,
746+
:left_polar!,
747+
:lq_compact!,
748+
:lq_full!,
749+
:right_polar!,
750+
:svd_compact!,
751+
:svd_full!,
752+
)
753+
@eval begin
754+
function MatrixAlgebraKit.initialize_output(::typeof($f), a::Eye, alg::EyeAlgorithm)
755+
return ntuple(_ -> a, nfactors($f))
756+
end
757+
function MatrixAlgebraKit.$f(a::Eye, F, alg::EyeAlgorithm; kwargs...)
758+
return ntuple(_ -> a, nfactors($f))
759+
end
760+
end
761+
end
762+
533763
end

test/test_aqua.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,5 @@ using Aqua: Aqua
33
using Test: @testset
44

55
@testset "Code quality (Aqua.jl)" begin
6-
Aqua.test_all(KroneckerArrays)
6+
Aqua.test_all(KroneckerArrays; piracies=false)
77
end

0 commit comments

Comments
 (0)