Skip to content

Commit e5325c1

Browse files
authored
Special cases for FillArrays.Eye, fix MatrixAlgebraKit factorizations (#6)
1 parent d3de597 commit e5325c1

File tree

5 files changed

+240
-44
lines changed

5 files changed

+240
-44
lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,18 @@
11
name = "KroneckerArrays"
22
uuid = "05d0b138-81bc-4ff7-84be-08becefb1ccc"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.1.3"
4+
version = "0.1.4"
55

66
[deps]
77
DerivableInterfaces = "6c5e35bf-e59e-4898-b73c-732dcc4ba65f"
8+
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
89
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
910
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1011
MatrixAlgebraKit = "6c742aac-3347-4629-af66-fc926824e5e4"
1112

1213
[compat]
1314
DerivableInterfaces = "0.4.5"
15+
FillArrays = "1.13.0"
1416
GPUArraysCore = "0.2.0"
1517
LinearAlgebra = "1.10"
1618
MatrixAlgebraKit = "0.2.0"

src/KroneckerArrays.jl

Lines changed: 214 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,8 @@ end
158158

159159
arguments(a::KroneckerArray) = (a.a, a.b)
160160
arguments(a::KroneckerArray, n::Int) = arguments(a)[n]
161+
argument_types(a::KroneckerArray) = argument_types(typeof(a))
162+
argument_types(::Type{<:KroneckerArray{<:Any,<:Any,A,B}}) where {A,B} = (A, B)
161163

162164
function Base.print_array(io::IO, a::KroneckerArray)
163165
Base.print_array(io, a.a)
@@ -234,6 +236,62 @@ function Base.:*(a::KroneckerArray, b::Number)
234236
return a.a (a.b * b)
235237
end
236238

239+
function Base.:-(a::KroneckerArray)
240+
return (-a.a) a.b
241+
end
242+
for op in (:+, :-)
243+
@eval begin
244+
function Base.$op(a::KroneckerArray, b::KroneckerArray)
245+
if a.b == b.b
246+
return $op(a.a, b.a) a.b
247+
elseif a.a == b.a
248+
return a.a $op(a.b, b.b)
249+
end
250+
return throw(
251+
ArgumentError(
252+
"KroneckerArray addition is only supported when the first or secord arguments match.",
253+
),
254+
)
255+
end
256+
end
257+
end
258+
259+
function Base.map!(::typeof(identity), dest::KroneckerArray, a::KroneckerArray)
260+
dest.a .= a.a
261+
dest.b .= a.b
262+
return dest
263+
end
264+
function Base.map!(::typeof(+), dest::KroneckerArray, a::KroneckerArray, b::KroneckerArray)
265+
if a.b == b.b
266+
map!(+, dest.a, a.a, b.a)
267+
dest.b .= a.b
268+
elseif a.a == b.a
269+
dest.a .= a.a
270+
map!(+, dest.b, a.b, b.b)
271+
else
272+
throw(
273+
ArgumentError(
274+
"KroneckerArray addition is only supported when the first or second arguments match.",
275+
),
276+
)
277+
end
278+
return dest
279+
end
280+
function Base.map!(
281+
f::Base.Fix1{typeof(*),<:Number}, dest::KroneckerArray, a::KroneckerArray
282+
)
283+
dest.a .= f.f.(f.x, a.a)
284+
dest.b .= a.b
285+
return dest
286+
end
287+
function Base.map!(
288+
f::Base.Fix2{typeof(*),<:Number}, dest::KroneckerArray, a::KroneckerArray
289+
)
290+
dest.a .= a.a
291+
dest.b .= f.f.(a.b, f.x)
292+
return dest
293+
end
294+
237295
using LinearAlgebra:
238296
LinearAlgebra,
239297
Diagonal,
@@ -346,67 +404,138 @@ function LinearAlgebra.lq(a::KroneckerArray)
346404
return KroneckerLQ(Fa.L Fb.L, Fa.Q Fb.Q)
347405
end
348406

349-
function Base.:-(a::KroneckerArray)
407+
using DerivableInterfaces: DerivableInterfaces, zero!
408+
function DerivableInterfaces.zero!(a::KroneckerArray)
409+
zero!(a.a)
410+
zero!(a.b)
411+
return a
412+
end
413+
414+
using FillArrays: Eye
415+
const EyeKronecker{T,A<:Eye{T},B<:AbstractMatrix{T}} = KroneckerMatrix{T,A,B}
416+
const KroneckerEye{T,A<:AbstractMatrix{T},B<:Eye{T}} = KroneckerMatrix{T,A,B}
417+
const EyeEye{T,A<:Eye{T},B<:Eye{T}} = KroneckerMatrix{T,A,B}
418+
419+
function Base.:*(a::Number, b::EyeKronecker)
420+
return b.a (a * b.b)
421+
end
422+
function Base.:*(a::Number, b::KroneckerEye)
423+
return (a * b.a) b.b
424+
end
425+
function Base.:*(a::Number, b::EyeEye)
426+
return (a * b.a) b.b
427+
end
428+
function Base.:*(a::EyeKronecker, b::Number)
429+
return a.a (a.b * b)
430+
end
431+
function Base.:*(a::KroneckerEye, b::Number)
432+
return (a.a * b) a.b
433+
end
434+
function Base.:*(a::EyeEye, b::Number)
435+
return a.a (a.b * b)
436+
end
437+
438+
function Base.:-(a::EyeKronecker)
439+
return a.a (-a.b)
440+
end
441+
function Base.:-(a::KroneckerEye)
442+
return (-a.a) a.b
443+
end
444+
function Base.:-(a::EyeEye)
350445
return (-a.a) a.b
351446
end
352447
for op in (:+, :-)
353448
@eval begin
354-
function Base.$op(a::KroneckerArray, b::KroneckerArray)
355-
if a.b == b.b
356-
return $op(a.a, b.a) a.b
357-
elseif a.a == b.a
358-
return a.a $op(a.b, b.b)
449+
function Base.$op(a::EyeKronecker, b::EyeKronecker)
450+
if a.a b.a
451+
return throw(
452+
ArgumentError(
453+
"KroneckerArray addition is only supported when the first or secord arguments match.",
454+
),
455+
)
359456
end
360-
return throw(
361-
ArgumentError(
362-
"KroneckerArray addition is only supported when the first or secord arguments match.",
363-
),
364-
)
457+
return a.a $op(a.b, b.b)
458+
end
459+
function Base.$op(a::KroneckerEye, b::KroneckerEye)
460+
if a.b b.b
461+
return throw(
462+
ArgumentError(
463+
"KroneckerArray addition is only supported when the first or secord arguments match.",
464+
),
465+
)
466+
end
467+
return $op(a.a, b.a) a.b
468+
end
469+
function Base.$op(a::EyeEye, b::EyeEye)
470+
if a.b b.b
471+
return throw(
472+
ArgumentError(
473+
"KroneckerArray addition is only supported when the first or secord arguments match.",
474+
),
475+
)
476+
end
477+
return $op(a.a, b.a) a.b
365478
end
366479
end
367480
end
368481

369-
function Base.map!(::typeof(identity), dest::KroneckerArray, a::KroneckerArray)
370-
dest.a .= a.a
482+
function Base.map!(::typeof(identity), dest::EyeKronecker, a::EyeKronecker)
371483
dest.b .= a.b
372484
return dest
373485
end
374-
function Base.map!(::typeof(+), dest::KroneckerArray, a::KroneckerArray, b::KroneckerArray)
375-
if a.b == b.b
376-
map!(+, dest.a, a.a, b.a)
377-
dest.b .= a.b
378-
elseif a.a == b.a
379-
dest.a .= a.a
380-
map!(+, dest.b, a.b, b.b)
381-
else
486+
function Base.map!(::typeof(identity), dest::KroneckerEye, a::KroneckerEye)
487+
dest.a .= a.a
488+
return dest
489+
end
490+
function Base.map!(::typeof(identity), dest::EyeEye, a::EyeEye)
491+
return error("Can't write in-place.")
492+
end
493+
function Base.map!(f::typeof(+), dest::EyeKronecker, a::EyeKronecker, b::EyeKronecker)
494+
if dest.a a.a b.a
382495
throw(
383496
ArgumentError(
384497
"KroneckerArray addition is only supported when the first or second arguments match.",
385498
),
386499
)
387500
end
501+
map!(f, dest.b, a.b, b.b)
388502
return dest
389503
end
390-
function Base.map!(
391-
f::Base.Fix1{typeof(*),<:Number}, dest::KroneckerArray, a::KroneckerArray
392-
)
393-
dest.a .= f.x .* a.a
394-
dest.b .= a.b
504+
function Base.map!(f::typeof(+), dest::KroneckerEye, a::KroneckerEye, b::KroneckerEye)
505+
if dest.b a.b b.b
506+
throw(
507+
ArgumentError(
508+
"KroneckerArray addition is only supported when the first or second arguments match.",
509+
),
510+
)
511+
end
512+
map!(f, dest.a, a.a, b.a)
395513
return dest
396514
end
397-
function Base.map!(
398-
f::Base.Fix2{typeof(*),<:Number}, dest::KroneckerArray, a::KroneckerArray
399-
)
400-
dest.a .= a.a
401-
dest.b .= a.b .* f.x
515+
function Base.map!(f::typeof(+), dest::EyeEye, a::EyeEye, b::EyeEye)
516+
return error("Can't write in-place.")
517+
end
518+
function Base.map!(f::Base.Fix1{typeof(*),<:Number}, dest::EyeKronecker, a::EyeKronecker)
519+
dest.b .= f.f.(f.x, a.b)
402520
return dest
403521
end
404-
405-
using DerivableInterfaces: DerivableInterfaces, zero!
406-
function DerivableInterfaces.zero!(a::KroneckerArray)
407-
zero!(a.a)
408-
zero!(a.b)
409-
return a
522+
function Base.map!(f::Base.Fix1{typeof(*),<:Number}, dest::KroneckerEye, a::KroneckerEye)
523+
dest.a .= f.f.(f.x, a.a)
524+
return dest
525+
end
526+
function Base.map!(f::Base.Fix1{typeof(*),<:Number}, dest::EyeEye, a::EyeEye)
527+
return error("Can't write in-place.")
528+
end
529+
function Base.map!(f::Base.Fix2{typeof(*),<:Number}, dest::EyeKronecker, a::EyeKronecker)
530+
dest.b .= f.f.(a.b, f.x)
531+
return dest
532+
end
533+
function Base.map!(f::Base.Fix2{typeof(*),<:Number}, dest::KroneckerEye, a::KroneckerEye)
534+
dest.a .= f.f.(a.a, f.x)
535+
return dest
536+
end
537+
function Base.map!(f::Base.Fix2{typeof(*),<:Number}, dest::EyeEye, a::EyeEye)
538+
return error("Can't write in-place.")
410539
end
411540

412541
using MatrixAlgebraKit:
@@ -447,15 +576,61 @@ struct KroneckerAlgorithm{A,B} <: AbstractAlgorithm
447576
b::B
448577
end
449578

579+
using MatrixAlgebraKit:
580+
copy_input,
581+
eig_full,
582+
eigh_full,
583+
qr_compact,
584+
qr_full,
585+
left_polar,
586+
lq_compact,
587+
lq_full,
588+
right_polar,
589+
svd_compact,
590+
svd_full
591+
592+
for f in [
593+
:eig_full,
594+
:eigh_full,
595+
:qr_compact,
596+
:qr_full,
597+
:left_polar,
598+
:lq_compact,
599+
:lq_full,
600+
:right_polar,
601+
:svd_compact,
602+
:svd_full,
603+
]
604+
@eval begin
605+
function MatrixAlgebraKit.copy_input(::typeof($f), a::KroneckerMatrix)
606+
return copy_input($f, a.a) copy_input($f, a.b)
607+
end
608+
end
609+
end
610+
450611
for f in (:eig, :eigh, :lq, :qr, :polar, :svd)
451612
ff = Symbol("default_", f, "_algorithm")
452613
@eval begin
453-
function MatrixAlgebraKit.$ff(a::KroneckerMatrix; kwargs...)
454-
return KroneckerAlgorithm($ff(a.a; kwargs...), $ff(a.b; kwargs...))
614+
function MatrixAlgebraKit.$ff(A::Type{<:KroneckerMatrix}; kwargs...)
615+
A1, A2 = argument_types(A)
616+
return KroneckerAlgorithm($ff(A1; kwargs...), $ff(A2; kwargs...))
455617
end
456618
end
457619
end
458620

621+
# TODO: Delete this once https://github.com/QuantumKitHub/MatrixAlgebraKit.jl/pull/32 is merged.
622+
function MatrixAlgebraKit.default_algorithm(
623+
::typeof(qr_compact!), A::Type{<:KroneckerMatrix}; kwargs...
624+
)
625+
return default_qr_algorithm(A; kwargs...)
626+
end
627+
# TODO: Delete this once https://github.com/QuantumKitHub/MatrixAlgebraKit.jl/pull/32 is merged.
628+
function MatrixAlgebraKit.default_algorithm(
629+
::typeof(qr_full!), A::Type{<:KroneckerMatrix}; kwargs...
630+
)
631+
return default_qr_algorithm(A; kwargs...)
632+
end
633+
459634
for f in (
460635
:eig_full!,
461636
:eigh_full!,

test/Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
[deps]
22
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
3+
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
34
KroneckerArrays = "05d0b138-81bc-4ff7-84be-08becefb1ccc"
45
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
56
MatrixAlgebraKit = "6c742aac-3347-4629-af66-fc926824e5e4"
@@ -9,6 +10,7 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
910

1011
[compat]
1112
Aqua = "0.8"
13+
FillArrays = "1"
1214
KroneckerArrays = "0.1"
1315
LinearAlgebra = "1.10"
1416
MatrixAlgebraKit = "0.2"

test/test_basics.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
using FillArrays: Eye
12
using KroneckerArrays: KroneckerArrays, , ×, diagonal, kron_nd
23
using LinearAlgebra: Diagonal, I, eigen, eigvals, lq, qr, svd, svdvals, tr
34
using Test: @test, @testset
@@ -66,3 +67,17 @@ const elts = (Float32, Float64, ComplexF32, ComplexF64)
6667
@test collect(Q * R) collect(a)
6768
@test collect(Q'Q) I
6869
end
70+
71+
@testset "FillArrays.Eye" begin
72+
a = Eye(2) randn(3, 3)
73+
@test size(a) == (6, 6)
74+
@test a + a == Eye(2) (2a.b)
75+
@test 2a == Eye(2) (2a.b)
76+
@test a * a == Eye(2) (a.b * a.b)
77+
78+
a = randn(3, 3) Eye(2)
79+
@test size(a) == (6, 6)
80+
@test a + a == (2a.a) Eye(2)
81+
@test 2a == (2a.a) Eye(2)
82+
@test a * a == (a.a * a.a) Eye(2)
83+
end

0 commit comments

Comments
 (0)