@@ -234,6 +234,62 @@ function Base.:*(a::KroneckerArray, b::Number)
234234 return a. a ⊗ (a. b * b)
235235end
236236
237+ function Base.:- (a:: KroneckerArray )
238+ return (- a. a) ⊗ a. b
239+ end
240+ for op in (:+ , :- )
241+ @eval begin
242+ function Base. $op (a:: KroneckerArray , b:: KroneckerArray )
243+ if a. b == b. b
244+ return $ op (a. a, b. a) ⊗ a. b
245+ elseif a. a == b. a
246+ return a. a ⊗ $ op (a. b, b. b)
247+ end
248+ return throw (
249+ ArgumentError (
250+ " KroneckerArray addition is only supported when the first or secord arguments match." ,
251+ ),
252+ )
253+ end
254+ end
255+ end
256+
257+ function Base. map! (:: typeof (identity), dest:: KroneckerArray , a:: KroneckerArray )
258+ dest. a .= a. a
259+ dest. b .= a. b
260+ return dest
261+ end
262+ function Base. map! (:: typeof (+ ), dest:: KroneckerArray , a:: KroneckerArray , b:: KroneckerArray )
263+ if a. b == b. b
264+ map! (+ , dest. a, a. a, b. a)
265+ dest. b .= a. b
266+ elseif a. a == b. a
267+ dest. a .= a. a
268+ map! (+ , dest. b, a. b, b. b)
269+ else
270+ throw (
271+ ArgumentError (
272+ " KroneckerArray addition is only supported when the first or second arguments match." ,
273+ ),
274+ )
275+ end
276+ return dest
277+ end
278+ function Base. map! (
279+ f:: Base.Fix1{typeof(*),<:Number} , dest:: KroneckerArray , a:: KroneckerArray
280+ )
281+ dest. a .= f. f .(f. x, a. a)
282+ dest. b .= a. b
283+ return dest
284+ end
285+ function Base. map! (
286+ f:: Base.Fix2{typeof(*),<:Number} , dest:: KroneckerArray , a:: KroneckerArray
287+ )
288+ dest. a .= a. a
289+ dest. b .= f. f .(a. b, f. x)
290+ return dest
291+ end
292+
237293using LinearAlgebra:
238294 LinearAlgebra,
239295 Diagonal,
@@ -346,67 +402,138 @@ function LinearAlgebra.lq(a::KroneckerArray)
346402 return KroneckerLQ (Fa. L ⊗ Fb. L, Fa. Q ⊗ Fb. Q)
347403end
348404
349- function Base.:- (a:: KroneckerArray )
405+ using DerivableInterfaces: DerivableInterfaces, zero!
406+ function DerivableInterfaces. zero! (a:: KroneckerArray )
407+ zero! (a. a)
408+ zero! (a. b)
409+ return a
410+ end
411+
412+ using FillArrays: Eye
413+ const EyeKronecker{T,A<: Eye{T} ,B<: AbstractMatrix{T} } = KroneckerMatrix{T,A,B}
414+ const KroneckerEye{T,A<: AbstractMatrix{T} ,B<: Eye{T} } = KroneckerMatrix{T,A,B}
415+ const EyeEye{T,A<: Eye{T} ,B<: Eye{T} } = KroneckerMatrix{T,A,B}
416+
417+ function Base.:* (a:: Number , b:: EyeKronecker )
418+ return b. a ⊗ (a * b. b)
419+ end
420+ function Base.:* (a:: Number , b:: KroneckerEye )
421+ return (a * b. a) ⊗ b. b
422+ end
423+ function Base.:* (a:: Number , b:: EyeEye )
424+ return (a * b. a) ⊗ b. b
425+ end
426+ function Base.:* (a:: EyeKronecker , b:: Number )
427+ return a. a ⊗ (a. b * b)
428+ end
429+ function Base.:* (a:: KroneckerEye , b:: Number )
430+ return (a. a * b) ⊗ a. b
431+ end
432+ function Base.:* (a:: EyeEye , b:: Number )
433+ return a. a ⊗ (a. b * b)
434+ end
435+
436+ function Base.:- (a:: EyeKronecker )
437+ return a. a ⊗ (- a. b)
438+ end
439+ function Base.:- (a:: KroneckerEye )
440+ return (- a. a) ⊗ a. b
441+ end
442+ function Base.:- (a:: EyeEye )
350443 return (- a. a) ⊗ a. b
351444end
352445for op in (:+ , :- )
353446 @eval begin
354- function Base. $op (a:: KroneckerArray , b:: KroneckerArray )
355- if a. b == b. b
356- return $ op (a. a, b. a) ⊗ a. b
357- elseif a. a == b. a
358- return a. a ⊗ $ op (a. b, b. b)
447+ function Base. $op (a:: EyeKronecker , b:: EyeKronecker )
448+ if a. a ≠ b. a
449+ return throw (
450+ ArgumentError (
451+ " KroneckerArray addition is only supported when the first or secord arguments match." ,
452+ ),
453+ )
359454 end
360- return throw (
361- ArgumentError (
362- " KroneckerArray addition is only supported when the first or secord arguments match." ,
363- ),
364- )
455+ return a. a ⊗ $ op (a. b, b. b)
456+ end
457+ function Base. $op (a:: KroneckerEye , b:: KroneckerEye )
458+ if a. b ≠ b. b
459+ return throw (
460+ ArgumentError (
461+ " KroneckerArray addition is only supported when the first or secord arguments match." ,
462+ ),
463+ )
464+ end
465+ return $ op (a. a, b. a) ⊗ a. b
466+ end
467+ function Base. $op (a:: EyeEye , b:: EyeEye )
468+ if a. b ≠ b. b
469+ return throw (
470+ ArgumentError (
471+ " KroneckerArray addition is only supported when the first or secord arguments match." ,
472+ ),
473+ )
474+ end
475+ return $ op (a. a, b. a) ⊗ a. b
365476 end
366477 end
367478end
368479
369- function Base. map! (:: typeof (identity), dest:: KroneckerArray , a:: KroneckerArray )
370- dest. a .= a. a
480+ function Base. map! (:: typeof (identity), dest:: EyeKronecker , a:: EyeKronecker )
371481 dest. b .= a. b
372482 return dest
373483end
374- function Base. map! (:: typeof (+ ), dest:: KroneckerArray , a:: KroneckerArray , b:: KroneckerArray )
375- if a. b == b. b
376- map! (+ , dest. a, a. a, b. a)
377- dest. b .= a. b
378- elseif a. a == b. a
379- dest. a .= a. a
380- map! (+ , dest. b, a. b, b. b)
381- else
484+ function Base. map! (:: typeof (identity), dest:: KroneckerEye , a:: KroneckerEye )
485+ dest. a .= a. a
486+ return dest
487+ end
488+ function Base. map! (:: typeof (identity), dest:: EyeEye , a:: EyeEye )
489+ return error (" Can't write in-place." )
490+ end
491+ function Base. map! (f:: typeof (+ ), dest:: EyeKronecker , a:: EyeKronecker , b:: EyeKronecker )
492+ if dest. a ≠ a. a ≠ b. a
382493 throw (
383494 ArgumentError (
384495 " KroneckerArray addition is only supported when the first or second arguments match." ,
385496 ),
386497 )
387498 end
499+ map! (f, dest. b, a. b, b. b)
388500 return dest
389501end
390- function Base. map! (
391- f:: Base.Fix1{typeof(*),<:Number} , dest:: KroneckerArray , a:: KroneckerArray
392- )
393- dest. a .= f. x .* a. a
394- dest. b .= a. b
502+ function Base. map! (f:: typeof (+ ), dest:: KroneckerEye , a:: KroneckerEye , b:: KroneckerEye )
503+ if dest. b ≠ a. b ≠ b. b
504+ throw (
505+ ArgumentError (
506+ " KroneckerArray addition is only supported when the first or second arguments match." ,
507+ ),
508+ )
509+ end
510+ map! (f, dest. a, a. a, b. a)
395511 return dest
396512end
397- function Base. map! (
398- f :: Base.Fix2{typeof(*),<:Number} , dest :: KroneckerArray , a :: KroneckerArray
399- )
400- dest . a . = a . a
401- dest. b .= a . b .* f. x
513+ function Base. map! (f :: typeof ( + ), dest :: EyeEye , a :: EyeEye , b :: EyeEye )
514+ return error ( " Can't write in-place. " )
515+ end
516+ function Base . map! (f :: Base.Fix1{typeof(*),<:Number} , dest :: EyeKronecker , a :: EyeKronecker )
517+ dest. b .= f . f .( f. x, a . b)
402518 return dest
403519end
404-
405- using DerivableInterfaces: DerivableInterfaces, zero!
406- function DerivableInterfaces. zero! (a:: KroneckerArray )
407- zero! (a. a)
408- zero! (a. b)
409- return a
520+ function Base. map! (f:: Base.Fix1{typeof(*),<:Number} , dest:: KroneckerEye , a:: KroneckerEye )
521+ dest. a .= f. f .(f. x, a. a)
522+ return dest
523+ end
524+ function Base. map! (f:: Base.Fix1{typeof(*),<:Number} , dest:: EyeEye , a:: EyeEye )
525+ return error (" Can't write in-place." )
526+ end
527+ function Base. map! (f:: Base.Fix2{typeof(*),<:Number} , dest:: EyeKronecker , a:: EyeKronecker )
528+ dest. b .= f. f .(a. b, f. x)
529+ return dest
530+ end
531+ function Base. map! (f:: Base.Fix2{typeof(*),<:Number} , dest:: KroneckerEye , a:: KroneckerEye )
532+ dest. a .= f. f .(a. a, f. x)
533+ return dest
534+ end
535+ function Base. map! (f:: Base.Fix2{typeof(*),<:Number} , dest:: EyeEye , a:: EyeEye )
536+ return error (" Can't write in-place." )
410537end
411538
412539using MatrixAlgebraKit:
@@ -447,6 +574,38 @@ struct KroneckerAlgorithm{A,B} <: AbstractAlgorithm
447574 b:: B
448575end
449576
577+ using MatrixAlgebraKit:
578+ copy_input,
579+ eig_full,
580+ eigh_full,
581+ qr_compact,
582+ qr_full,
583+ left_polar,
584+ lq_compact,
585+ lq_full,
586+ right_polar,
587+ svd_compact,
588+ svd_full
589+
590+ for f in [
591+ :eig_full ,
592+ :eigh_full ,
593+ :qr_compact ,
594+ :qr_full ,
595+ :left_polar ,
596+ :lq_compact ,
597+ :lq_full ,
598+ :right_polar ,
599+ :svd_compact ,
600+ :svd_full ,
601+ ]
602+ @eval begin
603+ function MatrixAlgebraKit. copy_input (:: typeof ($ f), a:: KroneckerMatrix )
604+ return copy_input ($ f, a. a) ⊗ copy_input ($ f, a. b)
605+ end
606+ end
607+ end
608+
450609for f in (:eig , :eigh , :lq , :qr , :polar , :svd )
451610 ff = Symbol (" default_" , f, " _algorithm" )
452611 @eval begin
@@ -530,4 +689,75 @@ for f in (:left_null!, :right_null!)
530689 end
531690end
532691
692+ # Special case for `FillArrays.Eye` matrices.
693+ struct EyeAlgorithm <: AbstractAlgorithm end
694+
695+ for f in [
696+ :eig_full ,
697+ :eigh_full ,
698+ :qr_compact ,
699+ :qr_full ,
700+ :left_polar ,
701+ :lq_compact ,
702+ :lq_full ,
703+ :right_polar ,
704+ :svd_compact ,
705+ :svd_full ,
706+ ]
707+ @eval begin
708+ MatrixAlgebraKit. copy_input (:: typeof ($ f), a:: Eye ) = a
709+ end
710+ end
711+
712+ for f in (:eig , :eigh , :lq , :qr , :polar , :svd )
713+ ff = Symbol (" default_" , f, " _algorithm" )
714+ @eval begin
715+ function MatrixAlgebraKit. $ff (a:: Eye ; kwargs... )
716+ return EyeAlgorithm ()
717+ end
718+ end
719+ end
720+
721+ for f in (
722+ :eig_full! ,
723+ :eigh_full! ,
724+ :qr_compact! ,
725+ :qr_full! ,
726+ :left_polar! ,
727+ :lq_compact! ,
728+ :lq_full! ,
729+ :right_polar! ,
730+ )
731+ @eval begin
732+ nfactors (:: typeof ($ f)) = 2
733+ end
734+ end
735+ for f in (:svd_compact! , :svd_full! )
736+ @eval begin
737+ nfactors (:: typeof ($ f)) = 3
738+ end
739+ end
740+
741+ for f in (
742+ :eig_full! ,
743+ :eigh_full! ,
744+ :qr_compact! ,
745+ :qr_full! ,
746+ :left_polar! ,
747+ :lq_compact! ,
748+ :lq_full! ,
749+ :right_polar! ,
750+ :svd_compact! ,
751+ :svd_full! ,
752+ )
753+ @eval begin
754+ function MatrixAlgebraKit. initialize_output (:: typeof ($ f), a:: Eye , alg:: EyeAlgorithm )
755+ return ntuple (_ -> a, nfactors ($ f))
756+ end
757+ function MatrixAlgebraKit. $f (a:: Eye , F, alg:: EyeAlgorithm ; kwargs... )
758+ return ntuple (_ -> a, nfactors ($ f))
759+ end
760+ end
761+ end
762+
533763end
0 commit comments