Skip to content

Commit c6c3436

Browse files
authored
More similar, map, broadcast (#11)
1 parent b9010bb commit c6c3436

File tree

5 files changed

+541
-118
lines changed

5 files changed

+541
-118
lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,19 @@
11
name = "KroneckerArrays"
22
uuid = "05d0b138-81bc-4ff7-84be-08becefb1ccc"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.1.8"
4+
version = "0.1.9"
55

66
[deps]
77
DerivableInterfaces = "6c5e35bf-e59e-4898-b73c-732dcc4ba65f"
8+
DiagonalArrays = "74fd4be6-21e2-4f6f-823a-4360d37c7a77"
89
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
910
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
1011
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1112
MatrixAlgebraKit = "6c742aac-3347-4629-af66-fc926824e5e4"
1213

1314
[compat]
1415
DerivableInterfaces = "0.4.5"
16+
DiagonalArrays = "0.3.5"
1517
FillArrays = "1.13.0"
1618
GPUArraysCore = "0.2.0"
1719
LinearAlgebra = "1.10"

src/KroneckerArrays.jl

Lines changed: 231 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,9 @@ end
250250
function Base.iszero(a::KroneckerArray)
251251
return iszero(a.a) || iszero(a.b)
252252
end
253+
function Base.isreal(a::KroneckerArray)
254+
return isreal(a.a) && isreal(a.b)
255+
end
253256
function Base.inv(a::KroneckerArray)
254257
return inv(a.a) inv(a.b)
255258
end
@@ -270,6 +273,9 @@ end
270273
function Base.:*(a::KroneckerArray, b::Number)
271274
return a.a (a.b * b)
272275
end
276+
function Base.:/(a::KroneckerArray, b::Number)
277+
return a * inv(b)
278+
end
273279

274280
function Base.:-(a::KroneckerArray)
275281
return (-a.a) a.b
@@ -291,26 +297,82 @@ for op in (:+, :-)
291297
end
292298
end
293299

300+
using Base.Broadcast: AbstractArrayStyle, BroadcastStyle, Broadcasted
301+
struct KroneckerStyle{N,A,B} <: AbstractArrayStyle{N} end
302+
function KroneckerStyle{N}(a::BroadcastStyle, b::BroadcastStyle) where {N}
303+
return KroneckerStyle{N,a,b}()
304+
end
305+
function KroneckerStyle(a::AbstractArrayStyle{N}, b::AbstractArrayStyle{N}) where {N}
306+
return KroneckerStyle{N}(a, b)
307+
end
308+
function KroneckerStyle{N,A,B}(v::Val{M}) where {N,A,B,M}
309+
return KroneckerStyle{M,typeof(A)(v),typeof(B)(v)}()
310+
end
311+
function Base.BroadcastStyle(::Type{<:KroneckerArray{<:Any,N,A,B}}) where {N,A,B}
312+
return KroneckerStyle{N}(BroadcastStyle(A), BroadcastStyle(B))
313+
end
314+
function Base.BroadcastStyle(style1::KroneckerStyle{N}, style2::KroneckerStyle{N}) where {N}
315+
return KroneckerStyle{N}(
316+
BroadcastStyle(style1.a, style2.a), BroadcastStyle(style1.b, style2.b)
317+
)
318+
end
319+
function Base.similar(bc::Broadcasted{<:KroneckerStyle{N,A,B}}, elt::Type) where {N,A,B}
320+
ax_a = map(ax -> ax.product.a, axes(bc))
321+
ax_b = map(ax -> ax.product.b, axes(bc))
322+
bc_a = Broadcasted(A, nothing, (), ax_a)
323+
bc_b = Broadcasted(B, nothing, (), ax_b)
324+
a = similar(bc_a, elt)
325+
b = similar(bc_b, elt)
326+
return a b
327+
end
328+
function Base.copyto!(dest::AbstractArray, bc::Broadcasted{<:KroneckerStyle})
329+
return throw(
330+
ArgumentError(
331+
"Arbitrary broadcasting is not supported for KroneckerArrays since they might not preserve the Kronecker structure.",
332+
),
333+
)
334+
end
335+
336+
function Base.map(f, a1::KroneckerArray, a_rest::KroneckerArray...)
337+
return throw(
338+
ArgumentError(
339+
"Arbitrary mapping is not supported for KroneckerArrays since they might not preserve the Kronecker structure.",
340+
),
341+
)
342+
end
343+
function Base.map!(f, dest::KroneckerArray, a1::KroneckerArray, a_rest::KroneckerArray...)
344+
return throw(
345+
ArgumentError(
346+
"Arbitrary mapping is not supported for KroneckerArrays since they might not preserve the Kronecker structure.",
347+
),
348+
)
349+
end
294350
function Base.map!(::typeof(identity), dest::KroneckerArray, a::KroneckerArray)
295351
dest.a .= a.a
296352
dest.b .= a.b
297353
return dest
298354
end
299-
function Base.map!(::typeof(+), dest::KroneckerArray, a::KroneckerArray, b::KroneckerArray)
300-
if a.b == b.b
301-
map!(+, dest.a, a.a, b.a)
302-
dest.b .= a.b
303-
elseif a.a == b.a
304-
dest.a .= a.a
305-
map!(+, dest.b, a.b, b.b)
306-
else
307-
throw(
308-
ArgumentError(
309-
"KroneckerArray addition is only supported when the first or second arguments match.",
310-
),
355+
for f in [:+, :-]
356+
@eval begin
357+
function Base.map!(
358+
::typeof($f), dest::KroneckerArray, a::KroneckerArray, b::KroneckerArray
311359
)
360+
if a.b == b.b
361+
map!($f, dest.a, a.a, b.a)
362+
dest.b .= a.b
363+
elseif a.a == b.a
364+
dest.a .= a.a
365+
map!($f, dest.b, a.b, b.b)
366+
else
367+
throw(
368+
ArgumentError(
369+
"KroneckerArray addition is only supported when the first or second arguments match.",
370+
),
371+
)
372+
end
373+
return dest
374+
end
312375
end
313-
return dest
314376
end
315377
function Base.map!(
316378
f::Base.Fix1{typeof(*),<:Number}, dest::KroneckerArray, a::KroneckerArray
@@ -326,6 +388,16 @@ function Base.map!(
326388
dest.b .= f.f.(a.b, f.x)
327389
return dest
328390
end
391+
function Base.map!(
392+
f::Base.Fix2{typeof(/),<:Number}, dest::KroneckerArray, a::KroneckerArray
393+
)
394+
return map!(Base.Fix2(*, inv(f.x)), dest, a)
395+
end
396+
function Base.map!(::typeof(conj), dest::KroneckerArray, a::KroneckerArray)
397+
dest.a .= conj.(a.a)
398+
dest.b .= conj.(a.b)
399+
return dest
400+
end
329401

330402
using LinearAlgebra:
331403
LinearAlgebra,
@@ -343,9 +415,10 @@ using LinearAlgebra:
343415
svd,
344416
svdvals,
345417
tr
346-
diagonal(a::AbstractArray) = Diagonal(a)
347-
function diagonal(a::KroneckerArray)
348-
return Diagonal(a.a) Diagonal(a.b)
418+
419+
using DiagonalArrays: DiagonalArrays, diagonal
420+
function DiagonalArrays.diagonal(a::KroneckerArray)
421+
return diagonal(a.a) diagonal(a.b)
349422
end
350423

351424
function Base.:*(a::KroneckerArray, b::KroneckerArray)
@@ -372,6 +445,23 @@ function LinearAlgebra.norm(a::KroneckerArray, p::Int=2)
372445
return norm(a.a, p) norm(a.b, p)
373446
end
374447

448+
function Base.real(a::KroneckerArray)
449+
if iszero(imag(a.a)) || iszero(imag(a.b))
450+
return real(a.a) real(a.b)
451+
elseif iszero(real(a.a)) || iszero(real(a.b))
452+
return -imag(a.a) imag(a.b)
453+
end
454+
return real(a.a) real(a.b) - imag(a.a) imag(a.b)
455+
end
456+
function Base.imag(a::KroneckerArray)
457+
if iszero(imag(a.a)) || iszero(real(a.b))
458+
return real(a.a) imag(a.b)
459+
elseif iszero(real(a.a)) || iszero(imag(a.b))
460+
return imag(a.a) real(a.b)
461+
end
462+
return real(a.a) imag(a.b) + imag(a.a) real(a.b)
463+
end
464+
375465
using MatrixAlgebraKit: MatrixAlgebraKit, diagview
376466
function MatrixAlgebraKit.diagview(a::KroneckerMatrix)
377467
return diagview(a.a) diagview(a.b)
@@ -506,6 +596,19 @@ const EyeKronecker{T,A<:Eye{T},B<:AbstractMatrix{T}} = KroneckerMatrix{T,A,B}
506596
const KroneckerEye{T,A<:AbstractMatrix{T},B<:Eye{T}} = KroneckerMatrix{T,A,B}
507597
const EyeEye{T,A<:Eye{T},B<:Eye{T}} = KroneckerMatrix{T,A,B}
508598

599+
using DerivableInterfaces: DerivableInterfaces, zero!
600+
function DerivableInterfaces.zero!(a::EyeKronecker)
601+
zero!(a.b)
602+
return a
603+
end
604+
function DerivableInterfaces.zero!(a::KroneckerEye)
605+
zero!(a.a)
606+
return a
607+
end
608+
function DerivableInterfaces.zero!(a::EyeEye)
609+
return throw(ArgumentError("Can't zero out `Eye ⊗ Eye`."))
610+
end
611+
509612
function Base.:*(a::Number, b::EyeKronecker)
510613
return b.a (a * b.b)
511614
end
@@ -580,29 +683,44 @@ end
580683
function Base.map!(::typeof(identity), dest::EyeEye, a::EyeEye)
581684
return error("Can't write in-place.")
582685
end
583-
function Base.map!(f::typeof(+), dest::EyeKronecker, a::EyeKronecker, b::EyeKronecker)
584-
if dest.a a.a b.a
585-
throw(
586-
ArgumentError(
587-
"KroneckerArray addition is only supported when the first or second arguments match.",
588-
),
589-
)
686+
for f in [:+, :-]
687+
@eval begin
688+
function Base.map!(::typeof($f), dest::EyeKronecker, a::EyeKronecker, b::EyeKronecker)
689+
if dest.a a.a b.a
690+
throw(
691+
ArgumentError(
692+
"KroneckerArray addition is only supported when the first or second arguments match.",
693+
),
694+
)
695+
end
696+
map!($f, dest.b, a.b, b.b)
697+
return dest
698+
end
699+
function Base.map!(::typeof($f), dest::KroneckerEye, a::KroneckerEye, b::KroneckerEye)
700+
if dest.b a.b b.b
701+
throw(
702+
ArgumentError(
703+
"KroneckerArray addition is only supported when the first or second arguments match.",
704+
),
705+
)
706+
end
707+
map!($f, dest.a, a.a, b.a)
708+
return dest
709+
end
710+
function Base.map!(::typeof($f), dest::EyeEye, a::EyeEye, b::EyeEye)
711+
return error("Can't write in-place.")
712+
end
590713
end
591-
map!(f, dest.b, a.b, b.b)
714+
end
715+
function Base.map!(f::typeof(-), dest::EyeKronecker, a::EyeKronecker)
716+
map!(f, dest.b, a.b)
592717
return dest
593718
end
594-
function Base.map!(f::typeof(+), dest::KroneckerEye, a::KroneckerEye, b::KroneckerEye)
595-
if dest.b a.b b.b
596-
throw(
597-
ArgumentError(
598-
"KroneckerArray addition is only supported when the first or second arguments match.",
599-
),
600-
)
601-
end
602-
map!(f, dest.a, a.a, b.a)
719+
function Base.map!(f::typeof(-), dest::KroneckerEye, a::KroneckerEye)
720+
map!(f, dest.a, a.a)
603721
return dest
604722
end
605-
function Base.map!(f::typeof(+), dest::EyeEye, a::EyeEye, b::EyeEye)
723+
function Base.map!(f::typeof(-), dest::EyeEye, a::EyeEye)
606724
return error("Can't write in-place.")
607725
end
608726
function Base.map!(f::Base.Fix1{typeof(*),<:Number}, dest::EyeKronecker, a::EyeKronecker)
@@ -812,6 +930,74 @@ const SquareEyeKronecker{T,A<:SquareEye{T},B<:AbstractMatrix{T}} = KroneckerMatr
812930
const KroneckerSquareEye{T,A<:AbstractMatrix{T},B<:SquareEye{T}} = KroneckerMatrix{T,A,B}
813931
const SquareEyeSquareEye{T,A<:SquareEye{T},B<:SquareEye{T}} = KroneckerMatrix{T,A,B}
814932

933+
# Special case of similar for `SquareEye ⊗ A` and `A ⊗ SquareEye`.
934+
function Base.similar(
935+
a::SquareEyeKronecker,
936+
elt::Type,
937+
axs::Tuple{
938+
CartesianProductUnitRange{<:Integer},Vararg{CartesianProductUnitRange{<:Integer}}
939+
},
940+
)
941+
ax_a = map(ax -> ax.product.a, axs)
942+
ax_b = map(ax -> ax.product.b, axs)
943+
eye_ax_a = (only(unique(ax_a)),)
944+
return Eye{elt}(eye_ax_a) similar(a.b, elt, ax_b)
945+
end
946+
function Base.similar(
947+
a::KroneckerSquareEye,
948+
elt::Type,
949+
axs::Tuple{
950+
CartesianProductUnitRange{<:Integer},Vararg{CartesianProductUnitRange{<:Integer}}
951+
},
952+
)
953+
ax_a = map(ax -> ax.product.a, axs)
954+
ax_b = map(ax -> ax.product.b, axs)
955+
eye_ax_b = (only(unique(ax_b)),)
956+
return similar(a.a, elt, ax_a) Eye{elt}(eye_ax_b)
957+
end
958+
function Base.similar(
959+
a::SquareEyeSquareEye,
960+
elt::Type,
961+
axs::Tuple{
962+
CartesianProductUnitRange{<:Integer},Vararg{CartesianProductUnitRange{<:Integer}}
963+
},
964+
)
965+
ax_a = map(ax -> ax.product.a, axs)
966+
ax_b = map(ax -> ax.product.b, axs)
967+
eye_ax_a = (only(unique(ax_a)),)
968+
eye_ax_b = (only(unique(ax_b)),)
969+
return Eye{elt}(eye_ax_a) Eye{elt}(eye_ax_b)
970+
end
971+
972+
function Base.similar(
973+
arrayt::Type{<:SquareEyeKronecker{<:Any,<:Any,A}},
974+
axs::NTuple{2,CartesianProductUnitRange{<:Integer}},
975+
) where {A}
976+
ax_a = map(ax -> ax.product.a, axs)
977+
ax_b = map(ax -> ax.product.b, axs)
978+
eye_ax_a = (only(unique(ax_a)),)
979+
return Eye{eltype(arrayt)}(eye_ax_a) similar(A, ax_b)
980+
end
981+
function Base.similar(
982+
arrayt::Type{<:KroneckerSquareEye{<:Any,A}},
983+
axs::NTuple{2,CartesianProductUnitRange{<:Integer}},
984+
) where {A}
985+
ax_a = map(ax -> ax.product.a, axs)
986+
ax_b = map(ax -> ax.product.b, axs)
987+
eye_ax_b = (only(unique(ax_b)),)
988+
return similar(A, ax_a) Eye{eltype(arrayt)}(eye_ax_b)
989+
end
990+
function Base.similar(
991+
arrayt::Type{<:SquareEyeSquareEye}, axs::NTuple{2,CartesianProductUnitRange{<:Integer}}
992+
)
993+
elt = eltype(arrayt)
994+
ax_a = map(ax -> ax.product.a, axs)
995+
ax_b = map(ax -> ax.product.b, axs)
996+
eye_ax_a = (only(unique(ax_a)),)
997+
eye_ax_b = (only(unique(ax_b)),)
998+
return Eye{elt}(eye_ax_a) Eye{elt}(eye_ax_b)
999+
end
1000+
8151001
struct SquareEyeAlgorithm{KWargs<:NamedTuple} <: AbstractAlgorithm
8161002
kwargs::KWargs
8171003
end
@@ -884,8 +1070,6 @@ for f in [:left_null!, :right_null!]
8841070
end
8851071
end
8861072
for f in [
887-
:eig_full!,
888-
:eigh_full!,
8891073
:qr_compact!,
8901074
:qr_full!,
8911075
:left_orth!,
@@ -900,10 +1084,14 @@ for f in [
9001084
_initialize_output_squareeye(::typeof($f), a::SquareEye, alg) = (a, a)
9011085
end
9021086
end
1087+
_initialize_output_squareeye(::typeof(eig_full!), a::SquareEye) = complex.((a, a))
1088+
_initialize_output_squareeye(::typeof(eig_full!), a::SquareEye, alg) = complex.((a, a))
1089+
_initialize_output_squareeye(::typeof(eigh_full!), a::SquareEye) = (real(a), a)
1090+
_initialize_output_squareeye(::typeof(eigh_full!), a::SquareEye, alg) = (real(a), a)
9031091
for f in [:svd_compact!, :svd_full!]
9041092
@eval begin
905-
_initialize_output_squareeye(::typeof($f), a::SquareEye) = (a, a, a)
906-
_initialize_output_squareeye(::typeof($f), a::SquareEye, alg) = (a, a, a)
1093+
_initialize_output_squareeye(::typeof($f), a::SquareEye) = (a, real(a), a)
1094+
_initialize_output_squareeye(::typeof($f), a::SquareEye, alg) = (a, real(a), a)
9071095
end
9081096
end
9091097

@@ -987,10 +1175,12 @@ function MatrixAlgebraKit.right_null!(
9871175
return throw(MethodError(right_null!, (a, F)))
9881176
end
9891177

990-
for f in [:eig_vals!, :eigh_vals!, :svd_vals!]
1178+
_initialize_output_squareeye(::typeof(eig_vals!), a::SquareEye) = parent(a)
1179+
_initialize_output_squareeye(::typeof(eig_vals!), a::SquareEye, alg) = parent(a)
1180+
for f in [:eigh_vals!, svd_vals!]
9911181
@eval begin
992-
_initialize_output_squareeye(::typeof($f), a::SquareEye) = parent(a)
993-
_initialize_output_squareeye(::typeof($f), a::SquareEye, alg) = parent(a)
1182+
_initialize_output_squareeye(::typeof($f), a::SquareEye) = real(parent(a))
1183+
_initialize_output_squareeye(::typeof($f), a::SquareEye, alg) = real(parent(a))
9941184
end
9951185
end
9961186

test/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
[deps]
22
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
3+
DerivableInterfaces = "6c5e35bf-e59e-4898-b73c-732dcc4ba65f"
34
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
45
KroneckerArrays = "05d0b138-81bc-4ff7-84be-08becefb1ccc"
56
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

0 commit comments

Comments
 (0)