@@ -299,6 +299,70 @@ common_number(a, b) =
299299# # Linear Algebra
300300
301301ArrayInterface. zeromatrix (A:: ArrayPartition ) = ArrayInterface. zeromatrix (reduce (vcat,vec .(A. x)))
302- LinearAlgebra. ldiv! (A:: LinearAlgebra.LU ,b:: ArrayPartition ) = ldiv! (A,Array (b))
303- LinearAlgebra. ldiv! (A:: LinearAlgebra.QR ,b:: ArrayPartition ) = ldiv! (A,Array (b))
304- LinearAlgebra. ldiv! (A:: LinearAlgebra.SVD ,b:: ArrayPartition ) = ldiv! (A,Array (b))
302+
303+ LinearAlgebra. ldiv! (A:: Factorization , b:: ArrayPartition ) = (x = ldiv! (A, Array (b)); copyto! (b, x))
304+ function LinearAlgebra. ldiv! (A:: LU , b:: ArrayPartition )
305+ LinearAlgebra. _ipiv_rows! (A, 1 : length (A. ipiv), b)
306+ ldiv! (UpperTriangular (A. factors), ldiv! (UnitLowerTriangular (A. factors), b))
307+ return b
308+ end
309+
310+ # block matrix indexing
311+ @inbounds function getblock (A, lens, i, j)
312+ ii1 = i == 1 ? 0 : sum (ii-> lens[ii], 1 : i- 1 )
313+ jj1 = j == 1 ? 0 : sum (ii-> lens[ii], 1 : j- 1 )
314+ ij1 = CartesianIndex (ii1, jj1)
315+ cc1 = CartesianIndex ((1 , 1 ))
316+ inc = CartesianIndex (lens[i], lens[j])
317+ return @view A[(ij1+ cc1): (ij1+ inc)]
318+ end
319+ # fast ldiv for UpperTriangular and UnitLowerTriangular
320+ # [U11 U12 U13] [ b1 ]
321+ # [ 0 U22 U23] \ [ b2 ]
322+ # [ 0 0 U33] [ b3 ]
323+ function LinearAlgebra. ldiv! (A:: T , bb:: ArrayPartition ) where T<: Union{UnitUpperTriangular,UpperTriangular}
324+ A = A. data
325+ n = npartitions (bb)
326+ b = bb. x
327+ lens = map (length, b)
328+ @inbounds for j in n: - 1 : 1
329+ Ajj = T (getblock (A, lens, j, j))
330+ xj = ldiv! (Ajj, b[j])
331+ for i in j- 1 : - 1 : 1
332+ Aij = getblock (A, lens, i, j)
333+ # bi = -Aij * xj + bi
334+ mul! (b[i], Aij, xj, - 1 , true )
335+ end
336+ end
337+ return bb
338+ end
339+
340+ function LinearAlgebra. ldiv! (A:: T , bb:: ArrayPartition ) where T<: Union{UnitLowerTriangular,LowerTriangular}
341+ A = A. data
342+ n = npartitions (bb)
343+ b = bb. x
344+ lens = map (length, b)
345+ @inbounds for j in 1 : n
346+ Ajj = T (getblock (A, lens, j, j))
347+ xj = ldiv! (Ajj, b[j])
348+ for i in j+ 1 : n
349+ Aij = getblock (A, lens, i, j)
350+ # bi = -Aij * xj + b[i]
351+ mul! (b[i], Aij, xj, - 1 , true )
352+ end
353+ end
354+ return bb
355+ end
356+ # TODO : optimize
357+ function LinearAlgebra. _ipiv_rows! (A:: LU , order:: OrdinalRange , B:: ArrayPartition )
358+ for i = order
359+ if i != A. ipiv[i]
360+ LinearAlgebra. _swap_rows! (B, i, A. ipiv[i])
361+ end
362+ end
363+ return B
364+ end
365+ function LinearAlgebra. _swap_rows! (B:: ArrayPartition , i:: Integer , j:: Integer )
366+ B[i], B[j] = B[j], B[i]
367+ return B
368+ end
0 commit comments