@@ -539,4 +539,154 @@ function LinearAlgebra.generic_mattridiv!(
539
539
return C
540
540
end
541
541
542
+ # Supports batched factorization
543
+ abstract type GeneralizedFactorization{T} <: Factorization{T} end
544
+
545
+ function LinearAlgebra. TransposeFactorization (f:: GeneralizedFactorization )
546
+ return LinearAlgebra. TransposeFactorization {eltype(f),typeof(f)} (f)
547
+ end
548
+
549
+ function LinearAlgebra. AdjointFactorization (f:: GeneralizedFactorization )
550
+ return LinearAlgebra. AdjointFactorization {eltype(f),typeof(f)} (f)
551
+ end
552
+
553
+ const GeneralizedTransposeFactorization{T} =
554
+ LinearAlgebra. TransposeFactorization{T,<: GeneralizedFactorization{T} } where {T}
555
+ const GeneralizedAdjointFactorization{T} =
556
+ LinearAlgebra. AdjointFactorization{T,<: GeneralizedFactorization{T} } where {T}
557
+
558
+ # LU Factorization
559
+ struct GeneralizedLU{T,S<: AbstractArray ,P<: AbstractArray ,I<: Union{AbstractArray,Number} } < :
560
+ GeneralizedFactorization{T}
561
+ factors:: S
562
+ ipiv:: P
563
+ perm:: P
564
+ info:: I
565
+ end
566
+
567
+ Base. ndims (lu:: GeneralizedLU ) = ndims (lu. factors)
568
+
569
+ function GeneralizedLU (factors:: S , ipiv:: P , perm:: P , info:: I ) where {S,P,I}
570
+ @assert ndims (ipiv) == ndims (perm) == ndims (factors) - 1
571
+ @assert ndims (info) == ndims (factors) - 2
572
+ return GeneralizedLU {eltype(factors),S,P,I} (factors, ipiv, perm, info)
573
+ end
574
+
575
+ # # allow > 2 dimensions as inputs
576
+ function LinearAlgebra. lu (A:: AnyTracedRArray{T,2} , :: RowMaximum ; kwargs... ) where {T}
577
+ return lu! (copy (A), RowMaximum (); kwargs... )
578
+ end
579
+ function LinearAlgebra. lu (
580
+ A:: AnyTracedRArray{T,N} , :: RowMaximum = RowMaximum (); kwargs...
581
+ ) where {T,N}
582
+ return lu! (copy (A), RowMaximum (); kwargs... )
583
+ end
584
+
585
+ function LinearAlgebra. lu! (A:: AnyTracedRArray{T,2} , :: RowMaximum ; kwargs... ) where {T}
586
+ return _lu_overload (A, RowMaximum (); kwargs... )
587
+ end
588
+ function LinearAlgebra. lu! (A:: AnyTracedRArray{T,N} , :: RowMaximum ; kwargs... ) where {T,N}
589
+ return _lu_overload (A, RowMaximum (); kwargs... )
590
+ end
591
+
592
+ function _lu_overload (
593
+ A:: AnyTracedRArray{T,N} , :: RowMaximum ; check:: Bool = false , allowsingular:: Bool = false
594
+ ) where {T,N}
595
+ # TODO : don't ignore the check and allowsingular flags
596
+ # Batching here is in the last dimensions. `Ops.lu` expects the last dimensions
597
+ permdims = vcat (Int64[N - 1 , N], collect (Int64, 1 : (N - 2 )))
598
+ A = Ops. transpose (materialize_traced_array (A), permdims)
599
+ factors, ipiv, perm, info = Reactant. Ops. lu (A)
600
+
601
+ # Permute back to the original dimensions
602
+ perm_perm = vcat (N - 1 , collect (Int64, 1 : (N - 2 )))
603
+ factors = Ops. transpose (factors, invperm (permdims))
604
+ ipiv = Ops. transpose (ipiv, perm_perm)
605
+ perm = Ops. transpose (perm, perm_perm)
606
+ return GeneralizedLU (factors, ipiv, perm, info)
607
+ end
608
+
609
+ function LinearAlgebra. ldiv! (
610
+ lu:: GeneralizedLU{T,<:AbstractArray{T,N},P,I} , B:: AbstractArray{T,M}
611
+ ) where {T,P,I,N,M}
612
+ @assert N == M + 1
613
+ ldiv! (lu, reshape (B, size (B, 1 ), 1 , size (B)[2 : end ]. .. ))
614
+ return B
615
+ end
616
+
617
+ function LinearAlgebra. ldiv! (
618
+ lu:: GeneralizedLU{T,<:AbstractArray{T,2},P,I} , B:: AbstractArray{T,2}
619
+ ) where {T,P,I}
620
+ B .= _lu_solve_core (lu. factors, B, lu. perm)
621
+ return B
622
+ end
623
+
624
+ function LinearAlgebra. ldiv! (
625
+ lu:: GeneralizedLU{T,<:AbstractArray{T,N},P,I} , B:: AbstractArray{T,N}
626
+ ) where {T,P,I,N}
627
+ batch_shape = size (lu. factors)[3 : end ]
628
+ @assert batch_shape == size (B)[3 : end ]
629
+
630
+ permutation = vcat (collect (Int64, 3 : N), 1 , 2 )
631
+
632
+ factors = Ops. transpose (materialize_traced_array (lu. factors), permutation)
633
+ B_permuted = Ops. transpose (materialize_traced_array (B), permutation)
634
+ perm = Ops. transpose (
635
+ materialize_traced_array (lu. perm), vcat (collect (Int64, 2 : (N - 1 )), 1 )
636
+ )
637
+
638
+ res = Ops. transpose (
639
+ only (
640
+ Ops. batch (
641
+ _lu_solve_core, [factors, B_permuted, perm], collect (Int64, batch_shape)
642
+ ),
643
+ ),
644
+ invperm (permutation),
645
+ )
646
+ B .= res
647
+ return B
648
+ end
649
+
650
+ for f_wrapper in (LinearAlgebra. TransposeFactorization, LinearAlgebra. AdjointFactorization),
651
+ aType in (:AbstractVecOrMat , :AbstractArray )
652
+
653
+ @eval function LinearAlgebra. ldiv! (lu:: $ (f_wrapper){<: Any ,<: GeneralizedLU }, B:: $aType )
654
+ # TODO : implement this
655
+ error (" `$(f_wrapper) ` is not supported yet for LU." )
656
+ return nothing
657
+ end
658
+ end
659
+
660
+ function _lu_solve_core (factors:: AbstractMatrix , B:: AbstractMatrix , perm:: AbstractVector )
661
+ permuted_B = B[Int64 .(perm), :]
662
+ return UpperTriangular (factors) \ (UnitLowerTriangular (factors) \ permuted_B)
663
+ end
664
+
665
+ # Overload \ to support batched factorization
666
+ for T in (
667
+ :GeneralizedFactorization ,
668
+ :GeneralizedTransposeFactorization ,
669
+ :GeneralizedAdjointFactorization ,
670
+ ),
671
+ aType in (:AbstractVecOrMat , :AbstractArray )
672
+
673
+ @eval Base.:(\ )(F:: $T , B:: $aType ) = _overloaded_backslash (F, B)
674
+ end
675
+
676
+ function _overloaded_backslash (F:: GeneralizedFactorization , B:: AbstractArray )
677
+ return ldiv! (
678
+ F, LinearAlgebra. copy_similar (B, typeof (oneunit (eltype (F)) \ oneunit (eltype (B))))
679
+ )
680
+ end
681
+
682
+ function _overloaded_backslash (F:: GeneralizedTransposeFactorization , B:: AbstractArray )
683
+ return conj! (adjoint (F. parent) \ conj .(B))
684
+ end
685
+
686
+ function _overloaded_backslash (F:: GeneralizedAdjointFactorization , B:: AbstractArray )
687
+ return ldiv! (
688
+ F, LinearAlgebra. copy_similar (B, typeof (oneunit (eltype (F)) \ oneunit (eltype (B))))
689
+ )
690
+ end
691
+
542
692
end
0 commit comments