@@ -176,14 +176,16 @@ Base.all(f, A::ArrayPartition) = all(f, (all(f, x) for x in A.x))
176176Base. all (f:: Function , A:: ArrayPartition ) = all ((all (f, x) for x in A. x))
177177Base. all (A:: ArrayPartition ) = all (identity, A)
178178
179- function Base. copyto! (dest:: AbstractArray , A:: ArrayPartition )
180- @assert length (dest) == length (A)
181- cur = 1
182- @inbounds for i in 1 : length (A. x)
183- dest[cur: (cur + length (A. x[i]) - 1 )] .= vec (A. x[i])
184- cur += length (A. x[i])
179+ for type in [AbstractArray, SparseArrays. AbstractCompressedVector, PermutedDimsArray]
180+ @eval function Base. copyto! (dest:: $ (type), A:: ArrayPartition )
181+ @assert length (dest) == length (A)
182+ cur = 1
183+ @inbounds for i in 1 : length (A. x)
184+ dest[cur: (cur + length (A. x[i]) - 1 )] .= vec (A. x[i])
185+ cur += length (A. x[i])
186+ end
187+ dest
185188 end
186- dest
187189end
188190
189191function Base. copyto! (A:: ArrayPartition , src:: ArrayPartition )
@@ -419,30 +421,38 @@ end
419421
420422ArrayInterface. zeromatrix (A:: ArrayPartition ) = ArrayInterface. zeromatrix (Vector (A))
421423
422- function LinearAlgebra. ldiv! (A:: Factorization , b:: ArrayPartition )
423- (x = ldiv! (A, Array (b)); copyto! (b, x))
424+ function __get_subtypes_in_module (mod, supertype; include_supertype = true , all= false , except= [])
425+ return filter ([getproperty (mod, name) for name in names (mod; all) if ! in (name, except)]) do value
426+ return value isa Type && (value <: supertype ) && (include_supertype || value != supertype) && ! in (value, except)
427+ end
424428end
425429
426- @static if VERSION >= v " 1.9"
427- function LinearAlgebra. ldiv! (A:: LinearAlgebra.SVD{T, Tr, M} ,
428- b:: ArrayPartition ) where {Tr, T, M <: AbstractArray{T} }
430+ for factorization in vcat (__get_subtypes_in_module (LinearAlgebra, Factorization; include_supertype = false , all= true , except= [:LU , :LAPACKFactorizations ]), LDLt{T,<: SymTridiagonal{T,V} where {V<: AbstractVector{T} }} where {T})
431+ @eval function LinearAlgebra. ldiv! (A:: T , b:: ArrayPartition ) where {T<: $factorization }
429432 (x = ldiv! (A, Array (b)); copyto! (b, x))
430433 end
434+ end
431435
432- function LinearAlgebra. ldiv! (A:: LinearAlgebra.QRCompactWY{T, M, C} ,
433- b:: ArrayPartition ) where {
434- T <: Union{Float32, Float64, ComplexF64, ComplexF32} ,
435- M <: AbstractMatrix{T} ,
436- C <: AbstractMatrix{T} ,
437- }
438- (x = ldiv! (A, Array (b)); copyto! (b, x))
439- end
436+ function LinearAlgebra. ldiv! (A:: LinearAlgebra.SVD{T, Tr, M} ,
437+ b:: ArrayPartition ) where {Tr, T, M <: AbstractArray{T} }
438+ (x = ldiv! (A, Array (b)); copyto! (b, x))
440439end
441440
442- function LinearAlgebra. ldiv! (A:: LU , b:: ArrayPartition )
443- LinearAlgebra. _ipiv_rows! (A, 1 : length (A. ipiv), b)
444- ldiv! (UpperTriangular (A. factors), ldiv! (UnitLowerTriangular (A. factors), b))
445- return b
441+ function LinearAlgebra. ldiv! (A:: LinearAlgebra.QRCompactWY{T, M, C} ,
442+ b:: ArrayPartition ) where {
443+ T <: Union{Float32, Float64, ComplexF64, ComplexF32} ,
444+ M <: AbstractMatrix{T} ,
445+ C <: AbstractMatrix{T} ,
446+ }
447+ (x = ldiv! (A, Array (b)); copyto! (b, x))
448+ end
449+
450+ for type in [LU, LU{T,Tridiagonal{T,V}} where {T,V}]
451+ @eval function LinearAlgebra. ldiv! (A:: $type , b:: ArrayPartition )
452+ LinearAlgebra. _ipiv_rows! (A, 1 : length (A. ipiv), b)
453+ ldiv! (UpperTriangular (A. factors), ldiv! (UnitLowerTriangular (A. factors), b))
454+ return b
455+ end
446456end
447457
448458# block matrix indexing
@@ -458,78 +468,31 @@ end
458468# [U11 U12 U13] [ b1 ]
459469# [ 0 U22 U23] \ [ b2 ]
460470# [ 0 0 U33] [ b3 ]
461- function LinearAlgebra. ldiv! (A:: UnitUpperTriangular , bb:: ArrayPartition )
462- A = A. data
463- n = npartitions (bb)
464- b = bb. x
465- lens = map (length, b)
466- @inbounds for j in n: - 1 : 1
467- Ajj = UnitUpperTriangular (getblock (A, lens, j, j))
468- xj = ldiv! (Ajj, vec (b[j]))
469- for i in (j - 1 ): - 1 : 1
470- Aij = getblock (A, lens, i, j)
471- # bi = -Aij * xj + bi
472- mul! (vec (b[i]), Aij, xj, - 1 , true )
473- end
474- end
475- return bb
476- end
477-
478- function LinearAlgebra. ldiv! (A:: UpperTriangular , bb:: ArrayPartition )
479- A = A. data
480- n = npartitions (bb)
481- b = bb. x
482- lens = map (length, b)
483- @inbounds for j in n: - 1 : 1
484- Ajj = UpperTriangular (getblock (A, lens, j, j))
485- xj = ldiv! (Ajj, vec (b[j]))
486- for i in (j - 1 ): - 1 : 1
487- Aij = getblock (A, lens, i, j)
488- # bi = -Aij * xj + bi
489- mul! (vec (b[i]), Aij, xj, - 1 , true )
490- end
491- end
492- return bb
493- end
494-
495- function LinearAlgebra. ldiv! (A:: UnitLowerTriangular , bb:: ArrayPartition )
496- A = A. data
497- n = npartitions (bb)
498- b = bb. x
499- lens = map (length, b)
500- @inbounds for j in 1 : n
501- Ajj = UnitLowerTriangular (getblock (A, lens, j, j))
502- xj = ldiv! (Ajj, vec (b[j]))
503- for i in (j + 1 ): n
504- Aij = getblock (A, lens, i, j)
505- # bi = -Aij * xj + b[i]
506- mul! (vec (b[i]), Aij, xj, - 1 , true )
471+ for basetype in [UnitUpperTriangular, UpperTriangular, UnitLowerTriangular, LowerTriangular]
472+ for type in [basetype, basetype{T, <: Adjoint{T} } where {T}, basetype{T, <: Transpose{T} } where {T}]
473+ j_iter, i_iter = if basetype <: UnitUpperTriangular || basetype <: UpperTriangular
474+ (:(n: - 1 : 1 ), :(j- 1 : - 1 : 1 ))
475+ else
476+ (:(1 : n), :((j+ 1 ): n))
507477 end
508- end
509- return bb
510- end
511- function _ldiv! (A :: LowerTriangular , bb :: ArrayPartition )
512- A = A . data
513- n = npartitions (bb)
514- b = bb . x
515- lens = map (length, b )
516- @inbounds for j in 1 : n
517- Ajj = LowerTriangular ( getblock (A, lens, j , j) )
518- xj = ldiv! (Ajj, vec (b[j]))
519- for i in (j + 1 ) : n
520- Aij = getblock (A, lens, i, j)
521- # bi = -Aij * xj + b[i]
522- mul! ( vec (b[i]), Aij, xj, - 1 , true )
478+ @eval function LinearAlgebra . ldiv! (A :: $type , bb :: ArrayPartition )
479+ A = A . data
480+ n = npartitions (bb)
481+ b = bb . x
482+ lens = map (length, b)
483+ @inbounds for j in $ j_iter
484+ Ajj = $ basetype ( getblock (A, lens, j, j))
485+ xj = ldiv! (Ajj, vec (b[j]) )
486+ for i in $ i_iter
487+ Aij = getblock (A, lens, i , j)
488+ # bi = -Aij * xj + bi
489+ mul! ( vec (b[i]), Aij, xj, - 1 , true )
490+ end
491+ end
492+ return bb
523493 end
524494 end
525- return bb
526- end
527-
528- function LinearAlgebra. ldiv! (A:: LowerTriangular{T, <:LinearAlgebra.Adjoint{T}} ,
529- bb:: ArrayPartition ) where {T}
530- _ldiv! (A, bb)
531495end
532- LinearAlgebra. ldiv! (A:: LowerTriangular , bb:: ArrayPartition ) = _ldiv! (A, bb)
533496
534497# TODO : optimize
535498function LinearAlgebra. _ipiv_rows! (A:: LU , order:: OrdinalRange , B:: ArrayPartition )
0 commit comments