250
250
function Base. iszero (a:: KroneckerArray )
251
251
return iszero (a. a) || iszero (a. b)
252
252
end
253
+ function Base. isreal (a:: KroneckerArray )
254
+ return isreal (a. a) && isreal (a. b)
255
+ end
253
256
function Base. inv (a:: KroneckerArray )
254
257
return inv (a. a) ⊗ inv (a. b)
255
258
end
270
273
function Base.:* (a:: KroneckerArray , b:: Number )
271
274
return a. a ⊗ (a. b * b)
272
275
end
276
+ function Base.:/ (a:: KroneckerArray , b:: Number )
277
+ return a * inv (b)
278
+ end
273
279
274
280
function Base.:- (a:: KroneckerArray )
275
281
return (- a. a) ⊗ a. b
@@ -291,26 +297,82 @@ for op in (:+, :-)
291
297
end
292
298
end
293
299
300
+ using Base. Broadcast: AbstractArrayStyle, BroadcastStyle, Broadcasted
301
+ struct KroneckerStyle{N,A,B} <: AbstractArrayStyle{N} end
302
+ function KroneckerStyle {N} (a:: BroadcastStyle , b:: BroadcastStyle ) where {N}
303
+ return KroneckerStyle {N,a,b} ()
304
+ end
305
+ function KroneckerStyle (a:: AbstractArrayStyle{N} , b:: AbstractArrayStyle{N} ) where {N}
306
+ return KroneckerStyle {N} (a, b)
307
+ end
308
+ function KroneckerStyle {N,A,B} (v:: Val{M} ) where {N,A,B,M}
309
+ return KroneckerStyle {M,typeof(A)(v),typeof(B)(v)} ()
310
+ end
311
+ function Base. BroadcastStyle (:: Type{<:KroneckerArray{<:Any,N,A,B}} ) where {N,A,B}
312
+ return KroneckerStyle {N} (BroadcastStyle (A), BroadcastStyle (B))
313
+ end
314
+ function Base. BroadcastStyle (style1:: KroneckerStyle{N} , style2:: KroneckerStyle{N} ) where {N}
315
+ return KroneckerStyle {N} (
316
+ BroadcastStyle (style1. a, style2. a), BroadcastStyle (style1. b, style2. b)
317
+ )
318
+ end
319
+ function Base. similar (bc:: Broadcasted{<:KroneckerStyle{N,A,B}} , elt:: Type ) where {N,A,B}
320
+ ax_a = map (ax -> ax. product. a, axes (bc))
321
+ ax_b = map (ax -> ax. product. b, axes (bc))
322
+ bc_a = Broadcasted (A, nothing , (), ax_a)
323
+ bc_b = Broadcasted (B, nothing , (), ax_b)
324
+ a = similar (bc_a, elt)
325
+ b = similar (bc_b, elt)
326
+ return a ⊗ b
327
+ end
328
+ function Base. copyto! (dest:: AbstractArray , bc:: Broadcasted{<:KroneckerStyle} )
329
+ return throw (
330
+ ArgumentError (
331
+ " Arbitrary broadcasting is not supported for KroneckerArrays since they might not preserve the Kronecker structure." ,
332
+ ),
333
+ )
334
+ end
335
+
336
+ function Base. map (f, a1:: KroneckerArray , a_rest:: KroneckerArray... )
337
+ return throw (
338
+ ArgumentError (
339
+ " Arbitrary mapping is not supported for KroneckerArrays since they might not preserve the Kronecker structure." ,
340
+ ),
341
+ )
342
+ end
343
+ function Base. map! (f, dest:: KroneckerArray , a1:: KroneckerArray , a_rest:: KroneckerArray... )
344
+ return throw (
345
+ ArgumentError (
346
+ " Arbitrary mapping is not supported for KroneckerArrays since they might not preserve the Kronecker structure." ,
347
+ ),
348
+ )
349
+ end
294
350
function Base. map! (:: typeof (identity), dest:: KroneckerArray , a:: KroneckerArray )
295
351
dest. a .= a. a
296
352
dest. b .= a. b
297
353
return dest
298
354
end
299
- function Base. map! (:: typeof (+ ), dest:: KroneckerArray , a:: KroneckerArray , b:: KroneckerArray )
300
- if a. b == b. b
301
- map! (+ , dest. a, a. a, b. a)
302
- dest. b .= a. b
303
- elseif a. a == b. a
304
- dest. a .= a. a
305
- map! (+ , dest. b, a. b, b. b)
306
- else
307
- throw (
308
- ArgumentError (
309
- " KroneckerArray addition is only supported when the first or second arguments match." ,
310
- ),
355
+ for f in [:+ , :- ]
356
+ @eval begin
357
+ function Base. map! (
358
+ :: typeof ($ f), dest:: KroneckerArray , a:: KroneckerArray , b:: KroneckerArray
311
359
)
360
+ if a. b == b. b
361
+ map! ($ f, dest. a, a. a, b. a)
362
+ dest. b .= a. b
363
+ elseif a. a == b. a
364
+ dest. a .= a. a
365
+ map! ($ f, dest. b, a. b, b. b)
366
+ else
367
+ throw (
368
+ ArgumentError (
369
+ " KroneckerArray addition is only supported when the first or second arguments match." ,
370
+ ),
371
+ )
372
+ end
373
+ return dest
374
+ end
312
375
end
313
- return dest
314
376
end
315
377
function Base. map! (
316
378
f:: Base.Fix1{typeof(*),<:Number} , dest:: KroneckerArray , a:: KroneckerArray
@@ -326,6 +388,16 @@ function Base.map!(
326
388
dest. b .= f. f .(a. b, f. x)
327
389
return dest
328
390
end
391
+ function Base. map! (
392
+ f:: Base.Fix2{typeof(/),<:Number} , dest:: KroneckerArray , a:: KroneckerArray
393
+ )
394
+ return map! (Base. Fix2 (* , inv (f. x)), dest, a)
395
+ end
396
+ function Base. map! (:: typeof (conj), dest:: KroneckerArray , a:: KroneckerArray )
397
+ dest. a .= conj .(a. a)
398
+ dest. b .= conj .(a. b)
399
+ return dest
400
+ end
329
401
330
402
using LinearAlgebra:
331
403
LinearAlgebra,
@@ -343,9 +415,10 @@ using LinearAlgebra:
343
415
svd,
344
416
svdvals,
345
417
tr
346
- diagonal (a:: AbstractArray ) = Diagonal (a)
347
- function diagonal (a:: KroneckerArray )
348
- return Diagonal (a. a) ⊗ Diagonal (a. b)
418
+
419
+ using DiagonalArrays: DiagonalArrays, diagonal
420
+ function DiagonalArrays. diagonal (a:: KroneckerArray )
421
+ return diagonal (a. a) ⊗ diagonal (a. b)
349
422
end
350
423
351
424
function Base.:* (a:: KroneckerArray , b:: KroneckerArray )
@@ -372,6 +445,23 @@ function LinearAlgebra.norm(a::KroneckerArray, p::Int=2)
372
445
return norm (a. a, p) ⊗ norm (a. b, p)
373
446
end
374
447
448
+ function Base. real (a:: KroneckerArray )
449
+ if iszero (imag (a. a)) || iszero (imag (a. b))
450
+ return real (a. a) ⊗ real (a. b)
451
+ elseif iszero (real (a. a)) || iszero (real (a. b))
452
+ return - imag (a. a) ⊗ imag (a. b)
453
+ end
454
+ return real (a. a) ⊗ real (a. b) - imag (a. a) ⊗ imag (a. b)
455
+ end
456
+ function Base. imag (a:: KroneckerArray )
457
+ if iszero (imag (a. a)) || iszero (real (a. b))
458
+ return real (a. a) ⊗ imag (a. b)
459
+ elseif iszero (real (a. a)) || iszero (imag (a. b))
460
+ return imag (a. a) ⊗ real (a. b)
461
+ end
462
+ return real (a. a) ⊗ imag (a. b) + imag (a. a) ⊗ real (a. b)
463
+ end
464
+
375
465
using MatrixAlgebraKit: MatrixAlgebraKit, diagview
376
466
function MatrixAlgebraKit. diagview (a:: KroneckerMatrix )
377
467
return diagview (a. a) ⊗ diagview (a. b)
@@ -506,6 +596,19 @@ const EyeKronecker{T,A<:Eye{T},B<:AbstractMatrix{T}} = KroneckerMatrix{T,A,B}
506
596
const KroneckerEye{T,A<: AbstractMatrix{T} ,B<: Eye{T} } = KroneckerMatrix{T,A,B}
507
597
const EyeEye{T,A<: Eye{T} ,B<: Eye{T} } = KroneckerMatrix{T,A,B}
508
598
599
+ using DerivableInterfaces: DerivableInterfaces, zero!
600
+ function DerivableInterfaces. zero! (a:: EyeKronecker )
601
+ zero! (a. b)
602
+ return a
603
+ end
604
+ function DerivableInterfaces. zero! (a:: KroneckerEye )
605
+ zero! (a. a)
606
+ return a
607
+ end
608
+ function DerivableInterfaces. zero! (a:: EyeEye )
609
+ return throw (ArgumentError (" Can't zero out `Eye ⊗ Eye`." ))
610
+ end
611
+
509
612
function Base.:* (a:: Number , b:: EyeKronecker )
510
613
return b. a ⊗ (a * b. b)
511
614
end
@@ -580,29 +683,44 @@ end
580
683
function Base. map! (:: typeof (identity), dest:: EyeEye , a:: EyeEye )
581
684
return error (" Can't write in-place." )
582
685
end
583
- function Base. map! (f:: typeof (+ ), dest:: EyeKronecker , a:: EyeKronecker , b:: EyeKronecker )
584
- if dest. a ≠ a. a ≠ b. a
585
- throw (
586
- ArgumentError (
587
- " KroneckerArray addition is only supported when the first or second arguments match." ,
588
- ),
589
- )
686
+ for f in [:+ , :- ]
687
+ @eval begin
688
+ function Base. map! (:: typeof ($ f), dest:: EyeKronecker , a:: EyeKronecker , b:: EyeKronecker )
689
+ if dest. a ≠ a. a ≠ b. a
690
+ throw (
691
+ ArgumentError (
692
+ " KroneckerArray addition is only supported when the first or second arguments match." ,
693
+ ),
694
+ )
695
+ end
696
+ map! ($ f, dest. b, a. b, b. b)
697
+ return dest
698
+ end
699
+ function Base. map! (:: typeof ($ f), dest:: KroneckerEye , a:: KroneckerEye , b:: KroneckerEye )
700
+ if dest. b ≠ a. b ≠ b. b
701
+ throw (
702
+ ArgumentError (
703
+ " KroneckerArray addition is only supported when the first or second arguments match." ,
704
+ ),
705
+ )
706
+ end
707
+ map! ($ f, dest. a, a. a, b. a)
708
+ return dest
709
+ end
710
+ function Base. map! (:: typeof ($ f), dest:: EyeEye , a:: EyeEye , b:: EyeEye )
711
+ return error (" Can't write in-place." )
712
+ end
590
713
end
591
- map! (f, dest. b, a. b, b. b)
714
+ end
715
+ function Base. map! (f:: typeof (- ), dest:: EyeKronecker , a:: EyeKronecker )
716
+ map! (f, dest. b, a. b)
592
717
return dest
593
718
end
594
- function Base. map! (f:: typeof (+ ), dest:: KroneckerEye , a:: KroneckerEye , b:: KroneckerEye )
595
- if dest. b ≠ a. b ≠ b. b
596
- throw (
597
- ArgumentError (
598
- " KroneckerArray addition is only supported when the first or second arguments match." ,
599
- ),
600
- )
601
- end
602
- map! (f, dest. a, a. a, b. a)
719
+ function Base. map! (f:: typeof (- ), dest:: KroneckerEye , a:: KroneckerEye )
720
+ map! (f, dest. a, a. a)
603
721
return dest
604
722
end
605
- function Base. map! (f:: typeof (+ ), dest:: EyeEye , a:: EyeEye , b :: EyeEye )
723
+ function Base. map! (f:: typeof (- ), dest:: EyeEye , a:: EyeEye )
606
724
return error (" Can't write in-place." )
607
725
end
608
726
function Base. map! (f:: Base.Fix1{typeof(*),<:Number} , dest:: EyeKronecker , a:: EyeKronecker )
@@ -812,6 +930,74 @@ const SquareEyeKronecker{T,A<:SquareEye{T},B<:AbstractMatrix{T}} = KroneckerMatr
812
930
const KroneckerSquareEye{T,A<: AbstractMatrix{T} ,B<: SquareEye{T} } = KroneckerMatrix{T,A,B}
813
931
const SquareEyeSquareEye{T,A<: SquareEye{T} ,B<: SquareEye{T} } = KroneckerMatrix{T,A,B}
814
932
933
+ # Special case of similar for `SquareEye ⊗ A` and `A ⊗ SquareEye`.
934
+ function Base. similar (
935
+ a:: SquareEyeKronecker ,
936
+ elt:: Type ,
937
+ axs:: Tuple {
938
+ CartesianProductUnitRange{<: Integer },Vararg{CartesianProductUnitRange{<: Integer }}
939
+ },
940
+ )
941
+ ax_a = map (ax -> ax. product. a, axs)
942
+ ax_b = map (ax -> ax. product. b, axs)
943
+ eye_ax_a = (only (unique (ax_a)),)
944
+ return Eye {elt} (eye_ax_a) ⊗ similar (a. b, elt, ax_b)
945
+ end
946
+ function Base. similar (
947
+ a:: KroneckerSquareEye ,
948
+ elt:: Type ,
949
+ axs:: Tuple {
950
+ CartesianProductUnitRange{<: Integer },Vararg{CartesianProductUnitRange{<: Integer }}
951
+ },
952
+ )
953
+ ax_a = map (ax -> ax. product. a, axs)
954
+ ax_b = map (ax -> ax. product. b, axs)
955
+ eye_ax_b = (only (unique (ax_b)),)
956
+ return similar (a. a, elt, ax_a) ⊗ Eye {elt} (eye_ax_b)
957
+ end
958
+ function Base. similar (
959
+ a:: SquareEyeSquareEye ,
960
+ elt:: Type ,
961
+ axs:: Tuple {
962
+ CartesianProductUnitRange{<: Integer },Vararg{CartesianProductUnitRange{<: Integer }}
963
+ },
964
+ )
965
+ ax_a = map (ax -> ax. product. a, axs)
966
+ ax_b = map (ax -> ax. product. b, axs)
967
+ eye_ax_a = (only (unique (ax_a)),)
968
+ eye_ax_b = (only (unique (ax_b)),)
969
+ return Eye {elt} (eye_ax_a) ⊗ Eye {elt} (eye_ax_b)
970
+ end
971
+
972
+ function Base. similar (
973
+ arrayt:: Type{<:SquareEyeKronecker{<:Any,<:Any,A}} ,
974
+ axs:: NTuple{2,CartesianProductUnitRange{<:Integer}} ,
975
+ ) where {A}
976
+ ax_a = map (ax -> ax. product. a, axs)
977
+ ax_b = map (ax -> ax. product. b, axs)
978
+ eye_ax_a = (only (unique (ax_a)),)
979
+ return Eye {eltype(arrayt)} (eye_ax_a) ⊗ similar (A, ax_b)
980
+ end
981
+ function Base. similar (
982
+ arrayt:: Type{<:KroneckerSquareEye{<:Any,A}} ,
983
+ axs:: NTuple{2,CartesianProductUnitRange{<:Integer}} ,
984
+ ) where {A}
985
+ ax_a = map (ax -> ax. product. a, axs)
986
+ ax_b = map (ax -> ax. product. b, axs)
987
+ eye_ax_b = (only (unique (ax_b)),)
988
+ return similar (A, ax_a) ⊗ Eye {eltype(arrayt)} (eye_ax_b)
989
+ end
990
+ function Base. similar (
991
+ arrayt:: Type{<:SquareEyeSquareEye} , axs:: NTuple{2,CartesianProductUnitRange{<:Integer}}
992
+ )
993
+ elt = eltype (arrayt)
994
+ ax_a = map (ax -> ax. product. a, axs)
995
+ ax_b = map (ax -> ax. product. b, axs)
996
+ eye_ax_a = (only (unique (ax_a)),)
997
+ eye_ax_b = (only (unique (ax_b)),)
998
+ return Eye {elt} (eye_ax_a) ⊗ Eye {elt} (eye_ax_b)
999
+ end
1000
+
815
1001
struct SquareEyeAlgorithm{KWargs<: NamedTuple } <: AbstractAlgorithm
816
1002
kwargs:: KWargs
817
1003
end
@@ -884,8 +1070,6 @@ for f in [:left_null!, :right_null!]
884
1070
end
885
1071
end
886
1072
for f in [
887
- :eig_full! ,
888
- :eigh_full! ,
889
1073
:qr_compact! ,
890
1074
:qr_full! ,
891
1075
:left_orth! ,
@@ -900,10 +1084,14 @@ for f in [
900
1084
_initialize_output_squareeye (:: typeof ($ f), a:: SquareEye , alg) = (a, a)
901
1085
end
902
1086
end
1087
+ _initialize_output_squareeye (:: typeof (eig_full!), a:: SquareEye ) = complex .((a, a))
1088
+ _initialize_output_squareeye (:: typeof (eig_full!), a:: SquareEye , alg) = complex .((a, a))
1089
+ _initialize_output_squareeye (:: typeof (eigh_full!), a:: SquareEye ) = (real (a), a)
1090
+ _initialize_output_squareeye (:: typeof (eigh_full!), a:: SquareEye , alg) = (real (a), a)
903
1091
for f in [:svd_compact! , :svd_full! ]
904
1092
@eval begin
905
- _initialize_output_squareeye (:: typeof ($ f), a:: SquareEye ) = (a, a , a)
906
- _initialize_output_squareeye (:: typeof ($ f), a:: SquareEye , alg) = (a, a , a)
1093
+ _initialize_output_squareeye (:: typeof ($ f), a:: SquareEye ) = (a, real (a) , a)
1094
+ _initialize_output_squareeye (:: typeof ($ f), a:: SquareEye , alg) = (a, real (a) , a)
907
1095
end
908
1096
end
909
1097
@@ -987,10 +1175,12 @@ function MatrixAlgebraKit.right_null!(
987
1175
return throw (MethodError (right_null!, (a, F)))
988
1176
end
989
1177
990
- for f in [:eig_vals! , :eigh_vals! , :svd_vals! ]
1178
+ _initialize_output_squareeye (:: typeof (eig_vals!), a:: SquareEye ) = parent (a)
1179
+ _initialize_output_squareeye (:: typeof (eig_vals!), a:: SquareEye , alg) = parent (a)
1180
+ for f in [:eigh_vals! , svd_vals!]
991
1181
@eval begin
992
- _initialize_output_squareeye (:: typeof ($ f), a:: SquareEye ) = parent (a)
993
- _initialize_output_squareeye (:: typeof ($ f), a:: SquareEye , alg) = parent (a)
1182
+ _initialize_output_squareeye (:: typeof ($ f), a:: SquareEye ) = real ( parent (a) )
1183
+ _initialize_output_squareeye (:: typeof ($ f), a:: SquareEye , alg) = real ( parent (a) )
994
1184
end
995
1185
end
996
1186
0 commit comments