Skip to content

Commit 27e77f1

Browse files
committed
More similar, map, broadcast
1 parent b9010bb commit 27e77f1

File tree

2 files changed

+163
-34
lines changed

2 files changed

+163
-34
lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,19 @@
11
name = "KroneckerArrays"
22
uuid = "05d0b138-81bc-4ff7-84be-08becefb1ccc"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.1.8"
4+
version = "0.1.9"
55

66
[deps]
77
DerivableInterfaces = "6c5e35bf-e59e-4898-b73c-732dcc4ba65f"
8+
DiagonalArrays = "74fd4be6-21e2-4f6f-823a-4360d37c7a77"
89
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
910
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
1011
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1112
MatrixAlgebraKit = "6c742aac-3347-4629-af66-fc926824e5e4"
1213

1314
[compat]
1415
DerivableInterfaces = "0.4.5"
16+
DiagonalArrays = "0.3.5"
1517
FillArrays = "1.13.0"
1618
GPUArraysCore = "0.2.0"
1719
LinearAlgebra = "1.10"

src/KroneckerArrays.jl

Lines changed: 160 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,9 @@ end
250250
function Base.iszero(a::KroneckerArray)
251251
return iszero(a.a) || iszero(a.b)
252252
end
253+
function Base.isreal(a::KroneckerArray)
254+
return isreal(a.a) && isreal(a.b)
255+
end
253256
function Base.inv(a::KroneckerArray)
254257
return inv(a.a) inv(a.b)
255258
end
@@ -270,6 +273,9 @@ end
270273
function Base.:*(a::KroneckerArray, b::Number)
271274
return a.a (a.b * b)
272275
end
276+
function Base.:/(a::KroneckerArray, b::Number)
277+
return a * inv(b)
278+
end
273279

274280
function Base.:-(a::KroneckerArray)
275281
return (-a.a) a.b
@@ -291,26 +297,79 @@ for op in (:+, :-)
291297
end
292298
end
293299

300+
using Base.Broadcast: AbstractArrayStyle, BroadcastStyle, Broadcasted
301+
struct KroneckerStyle{N,A,B} <: AbstractArrayStyle{N} end
302+
function KroneckerStyle{N}(a::BroadcastStyle, b::BroadcastStyle) where {N}
303+
return KroneckerStyle{N,a,b}()
304+
end
305+
function KroneckerStyle{N,A,B}(v::Val{M}) where {N,A,B,M}
306+
return KroneckerStyle{M,typeof(A)(v),typeof(B)(v)}()
307+
end
308+
function Base.BroadcastStyle(::Type{<:KroneckerArray{<:Any,N,A,B}}) where {N,A,B}
309+
return KroneckerStyle{N}(BroadcastStyle(A), BroadcastStyle(B))
310+
end
311+
function Base.BroadcastStyle(style1::KroneckerStyle{N}, style2::KroneckerStyle{N}) where {N}
312+
return KroneckerStyle{N}(
313+
BroadcastStyle(style1.a, style2.a), BroadcastStyle(style1.b, style2.b)
314+
)
315+
end
316+
function Base.similar(bc::Broadcasted{<:KroneckerStyle{N,A,B}}, elt::Type) where {N,A,B}
317+
ax_a = map(ax -> ax.product.a, axes(bc))
318+
ax_b = map(ax -> ax.product.b, axes(bc))
319+
bc_a = Broadcasted(A, ax_a)
320+
bc_b = Broadcasted(B, ax_b)
321+
a = similar(bc_a, elt)
322+
b = similar(bc_b, elt)
323+
return a b
324+
end
325+
function Base.copyto!(dest::AbstractArray, bc::Broadcasted{<:KroneckerStyle})
326+
return throw(
327+
ArgumentError(
328+
"Arbitrary broadcasting is not supported for KroneckerArrays since they might not preserve the Kronecker structure.",
329+
),
330+
)
331+
end
332+
333+
function Base.map(f, a1::KroneckerArray, a_rest::KroneckerArray...)
334+
return throw(
335+
ArgumentError(
336+
"Arbitrary mapping is not supported for KroneckerArrays since they might not preserve the Kronecker structure.",
337+
),
338+
)
339+
end
340+
function Base.map!(f, dest::KroneckerArray, a1::KroneckerArray, a_rest::KroneckerArray...)
341+
return throw(
342+
ArgumentError(
343+
"Arbitrary mapping is not supported for KroneckerArrays since they might not preserve the Kronecker structure.",
344+
),
345+
)
346+
end
294347
function Base.map!(::typeof(identity), dest::KroneckerArray, a::KroneckerArray)
295348
dest.a .= a.a
296349
dest.b .= a.b
297350
return dest
298351
end
299-
function Base.map!(::typeof(+), dest::KroneckerArray, a::KroneckerArray, b::KroneckerArray)
300-
if a.b == b.b
301-
map!(+, dest.a, a.a, b.a)
302-
dest.b .= a.b
303-
elseif a.a == b.a
304-
dest.a .= a.a
305-
map!(+, dest.b, a.b, b.b)
306-
else
307-
throw(
308-
ArgumentError(
309-
"KroneckerArray addition is only supported when the first or second arguments match.",
310-
),
352+
for f in [:+, :-]
353+
@eval begin
354+
function Base.map!(
355+
::typeof($f), dest::KroneckerArray, a::KroneckerArray, b::KroneckerArray
311356
)
357+
if a.b == b.b
358+
map!($f, dest.a, a.a, b.a)
359+
dest.b .= a.b
360+
elseif a.a == b.a
361+
dest.a .= a.a
362+
map!($f, dest.b, a.b, b.b)
363+
else
364+
throw(
365+
ArgumentError(
366+
"KroneckerArray addition is only supported when the first or second arguments match.",
367+
),
368+
)
369+
end
370+
return dest
371+
end
312372
end
313-
return dest
314373
end
315374
function Base.map!(
316375
f::Base.Fix1{typeof(*),<:Number}, dest::KroneckerArray, a::KroneckerArray
@@ -326,6 +385,11 @@ function Base.map!(
326385
dest.b .= f.f.(a.b, f.x)
327386
return dest
328387
end
388+
function Base.map!(
389+
f::Base.Fix2{typeof(/),<:Number}, dest::KroneckerArray, a::KroneckerArray
390+
)
391+
return map!(Base.Fix2(*, inv(f.x)), dest, a)
392+
end
329393

330394
using LinearAlgebra:
331395
LinearAlgebra,
@@ -343,9 +407,11 @@ using LinearAlgebra:
343407
svd,
344408
svdvals,
345409
tr
346-
diagonal(a::AbstractArray) = Diagonal(a)
347-
function diagonal(a::KroneckerArray)
348-
return Diagonal(a.a) Diagonal(a.b)
410+
411+
using DiagonalArrays: DiagonalArrays, diagonal
412+
DiagonalArrays.diagonal(a::AbstractArray) = Diagonal(a)
413+
function DiagonalArrays.diagonal(a::KroneckerArray)
414+
return diagonal(a.a) diagonal(a.b)
349415
end
350416

351417
function Base.:*(a::KroneckerArray, b::KroneckerArray)
@@ -506,6 +572,19 @@ const EyeKronecker{T,A<:Eye{T},B<:AbstractMatrix{T}} = KroneckerMatrix{T,A,B}
506572
const KroneckerEye{T,A<:AbstractMatrix{T},B<:Eye{T}} = KroneckerMatrix{T,A,B}
507573
const EyeEye{T,A<:Eye{T},B<:Eye{T}} = KroneckerMatrix{T,A,B}
508574

575+
using DerivableInterfaces: DerivableInterfaces, zero!
576+
function DerivableInterfaces.zero!(a::EyeKronecker)
577+
zero!(a.b)
578+
return a
579+
end
580+
function DerivableInterfaces.zero!(a::KroneckerEye)
581+
zero!(a.a)
582+
return a
583+
end
584+
function DerivableInterfaces.zero!(a::EyeEye)
585+
return throw(ArgumentError("Can't zero out `Eye ⊗ Eye`."))
586+
end
587+
509588
function Base.:*(a::Number, b::EyeKronecker)
510589
return b.a (a * b.b)
511590
end
@@ -580,29 +659,44 @@ end
580659
function Base.map!(::typeof(identity), dest::EyeEye, a::EyeEye)
581660
return error("Can't write in-place.")
582661
end
583-
function Base.map!(f::typeof(+), dest::EyeKronecker, a::EyeKronecker, b::EyeKronecker)
584-
if dest.a a.a b.a
585-
throw(
586-
ArgumentError(
587-
"KroneckerArray addition is only supported when the first or second arguments match.",
588-
),
589-
)
662+
for f in [:+, :-]
663+
@eval begin
664+
function Base.map!(::typeof($f), dest::EyeKronecker, a::EyeKronecker, b::EyeKronecker)
665+
if dest.a a.a b.a
666+
throw(
667+
ArgumentError(
668+
"KroneckerArray addition is only supported when the first or second arguments match.",
669+
),
670+
)
671+
end
672+
map!($f, dest.b, a.b, b.b)
673+
return dest
674+
end
675+
function Base.map!(::typeof($f), dest::KroneckerEye, a::KroneckerEye, b::KroneckerEye)
676+
if dest.b a.b b.b
677+
throw(
678+
ArgumentError(
679+
"KroneckerArray addition is only supported when the first or second arguments match.",
680+
),
681+
)
682+
end
683+
map!($f, dest.a, a.a, b.a)
684+
return dest
685+
end
686+
function Base.map!(::typeof($f), dest::EyeEye, a::EyeEye, b::EyeEye)
687+
return error("Can't write in-place.")
688+
end
590689
end
591-
map!(f, dest.b, a.b, b.b)
690+
end
691+
function Base.map!(f::typeof(-), dest::EyeKronecker, a::EyeKronecker)
692+
map!(f, dest.b, a.b)
592693
return dest
593694
end
594-
function Base.map!(f::typeof(+), dest::KroneckerEye, a::KroneckerEye, b::KroneckerEye)
595-
if dest.b a.b b.b
596-
throw(
597-
ArgumentError(
598-
"KroneckerArray addition is only supported when the first or second arguments match.",
599-
),
600-
)
601-
end
695+
function Base.map!(f::typeof(-), dest::KroneckerEye, a::KroneckerEye)
602696
map!(f, dest.a, a.a, b.a)
603697
return dest
604698
end
605-
function Base.map!(f::typeof(+), dest::EyeEye, a::EyeEye, b::EyeEye)
699+
function Base.map!(f::typeof(-), dest::EyeEye, a::EyeEye)
606700
return error("Can't write in-place.")
607701
end
608702
function Base.map!(f::Base.Fix1{typeof(*),<:Number}, dest::EyeKronecker, a::EyeKronecker)
@@ -812,6 +906,39 @@ const SquareEyeKronecker{T,A<:SquareEye{T},B<:AbstractMatrix{T}} = KroneckerMatr
812906
const KroneckerSquareEye{T,A<:AbstractMatrix{T},B<:SquareEye{T}} = KroneckerMatrix{T,A,B}
813907
const SquareEyeSquareEye{T,A<:SquareEye{T},B<:SquareEye{T}} = KroneckerMatrix{T,A,B}
814908

909+
# Special case of similar for `SquareEye ⊗ A` and `A ⊗ SquareEye`.
910+
function Base.similar(
911+
arrayt::Type{<:SquareEyeKronecker{<:Any,<:Any,A}},
912+
elt::Type,
913+
axs::NTuple{2,CartesianProductUnitRange{<:Integer}},
914+
) where {A}
915+
ax_a = map(ax -> ax.product.a, axs)
916+
ax_b = map(ax -> ax.product.b, axs)
917+
eye_ax_a = (only(unique(ax_a)),)
918+
return Eye{elt}(eye_ax_a) similar(A, elt, ax_b)
919+
end
920+
function Base.similar(
921+
arrayt::Type{<:KroneckerSquareEye{<:Any,A}},
922+
elt::Type,
923+
axs::NTuple{2,CartesianProductUnitRange{<:Integer}},
924+
) where {A}
925+
ax_a = map(ax -> ax.product.a, axs)
926+
ax_b = map(ax -> ax.product.b, axs)
927+
eye_ax_b = (only(unique(ax_b)),)
928+
return similar(A, elt, ax_a) Eye{elt}(eye_ax_b)
929+
end
930+
function Base.similar(
931+
arrayt::Type{<:SquareEyeSquareEye},
932+
elt::Type,
933+
axs::NTuple{2,CartesianProductUnitRange{<:Integer}},
934+
)
935+
ax_a = map(ax -> ax.product.a, axs)
936+
ax_b = map(ax -> ax.product.b, axs)
937+
eye_ax_a = (only(unique(ax_a)),)
938+
eye_ax_b = (only(unique(ax_b)),)
939+
return Eye{elt}(eye_ax_a) Eye{elt}(eye_ax_b)
940+
end
941+
815942
struct SquareEyeAlgorithm{KWargs<:NamedTuple} <: AbstractAlgorithm
816943
kwargs::KWargs
817944
end

0 commit comments

Comments
 (0)