@@ -7,8 +7,6 @@ function LinearAlgebra.lu!(A::DMatrix{T}, ::LinearAlgebra.NoPivot; check::Bool=t
77 mzone = - one (T)
88 Ac = A. chunks
99 mt, nt = size (Ac)
10- iscomplex = T <: Complex
11- trans = iscomplex ? ' C' : ' T'
1210
1311 Dagger. spawn_datadeps () do
1412 for k in range (1 , min (mt, nt))
@@ -29,3 +27,97 @@ function LinearAlgebra.lu!(A::DMatrix{T}, ::LinearAlgebra.NoPivot; check::Bool=t
2927
3028 return LinearAlgebra. LU {T,DMatrix{T},DVector{Int}} (A, ipiv, 0 )
3129end
30+
"""
    searchmax_pivot!(piv_idx, piv_val, A, offset=0)

Find the entry of `A` with the largest absolute value and record it.

`A` is scanned in linear (column-major) order; on ties the first maximal entry wins.
With `i` the linear index of that entry, the function stores `offset + i` in
`piv_idx[1]` and the entry itself (signed / complex, not its magnitude) in `piv_val[1]`.

# Arguments
- `piv_idx`: output; only element `1` is written.
- `piv_val`: output; only element `1` is written.
- `A`: candidates to scan. Must be non-empty (an empty `A` makes `findmax` throw).
- `offset`: added to the linear index before it is stored.

Returns `nothing`.
"""
function searchmax_pivot!(
    piv_idx::AbstractArray{Int},
    piv_val::AbstractArray{T},
    A::AbstractArray{T},
    offset::Int=0,
) where T
    # `vec` yields a linear index without copying; `findmax(abs, ...)` avoids
    # materialising the `abs.(A[:])` temporaries.
    _, max_idx = findmax(abs, vec(A))
    piv_idx[1] = offset + max_idx
    piv_val[1] = A[max_idx]
    return nothing
end
37+
"""
    update_ipiv!(ipivl, piv_idx, piv_val, k, nb)

Pick the winning pivot among per-block candidates and store its index in `ipivl[1]`.

With `j` the position of the entry of `piv_val` that has the largest absolute value
(first one on ties), the stored value is `(j + k - 2) * nb + piv_idx[j]`.

# Arguments
- `ipivl`: output; only element `1` is written.
- `piv_idx`: candidate indices, one per candidate value.
- `piv_val`: candidate values, compared by `abs`.
- `k`, `nb`: integers entering the index formula above. Presumably `k` is the block row
  that `piv_val[1]` belongs to and `nb` the block size, so that the result is a global
  row number — confirm against the caller.

Returns `nothing`.
"""
function update_ipiv!(
    ipivl,
    piv_idx::AbstractArray{Int},
    piv_val::AbstractArray{T},
    k::Int,
    nb::Int,
) where T
    # Compare magnitudes without allocating an `abs.(piv_val)` temporary.
    _, max_piv_idx = findmax(abs, piv_val)
    ipivl[1] = (max_piv_idx + k - 2) * nb + piv_idx[max_piv_idx]
    return nothing
end
43+
"""
    swaprows_panel!(A, M, ipivl, m, p, nb)

Swap row `p` of `A` with the row of `M` selected by `ipivl[1]`, if that row lives in
block `m`.

`ipivl[1]` is decoded as a 1-based block number `q = div(ipivl[1] - 1, nb) + 1` and a
1-based row inside that block `r = (ipivl[1] - 1) % nb + 1`. When `m == q`, `A[p, :]`
and `M[r, :]` are exchanged in place; otherwise nothing is modified. `A` and `M` may be
the same array.

# Throws
- `DimensionMismatch` if a swap is requested and `A` and `M` differ in their second axis.

Returns `nothing`.
"""
function swaprows_panel!(
    A::AbstractArray{T},
    M::AbstractArray{T},
    ipivl::AbstractVector{Int},
    m::Int,
    p::Int,
    nb::Int,
) where T
    q, r = divrem(ipivl[1] - 1, nb) .+ 1
    if m == q
        axes(A, 2) == axes(M, 2) || throw(
            DimensionMismatch("cannot swap rows: A and M differ in their second axis"),
        )
        # Element-wise exchange: no temporary row copies, and safe when `A === M`.
        for j in axes(A, 2)
            A[p, j], M[r, j] = M[r, j], A[p, j]
        end
    end
    return nothing
end
52+
"""
    update_panel!(M, A, p)

Apply elimination step `p` to the rows held in `M`, using row `p` of `A` as pivot row.

Column `p` of `M` is scaled by `one(T) / A[p, p]`, then the columns of `M` to the right
of `p` receive the rank-one update `M[:, p+1:end] -= M[:, p] * transpose(A[p, p+1:end])`.
The pivot row is conjugated before being handed to `BLAS.ger!` because `ger!` applies
the conjugate transpose of its second vector.

Returns the updated trailing-column view of `M`, as returned by `BLAS.ger!`.
"""
function update_panel!(M::AbstractArray{T}, A::AbstractArray{T}, p::Int) where T
    scale = one(T) / A[p, p]
    multipliers = view(M, :, p)
    LinearAlgebra.BLAS.scal!(scale, multipliers)

    pivot_row = conj.(view(A, p, (p + 1):size(A, 2)))
    trailing = view(M, :, (p + 1):size(M, 2))
    return LinearAlgebra.BLAS.ger!(-one(T), multipliers, pivot_row, trailing)
end
58+
"""
    swaprows_trail!(A, M, ipiv, m, nb)

Replay the row exchanges listed in `ipiv` on `A` and `M`, restricted to block `m`.

For each position `p` of `ipiv`, the entry is decoded as a 1-based block number
`q = div(ipiv[p] - 1, nb) + 1` and a 1-based row inside that block
`r = (ipiv[p] - 1) % nb + 1`. Whenever `m == q`, `A[p, :]` and `M[r, :]` are exchanged
in place; entries pointing at other blocks are skipped. Exchanges are applied in the
iteration order of `ipiv`. `A` and `M` may be the same array.

# Throws
- `DimensionMismatch` if a swap is requested and `A` and `M` differ in their second axis.

Returns `nothing`.
"""
function swaprows_trail!(
    A::AbstractArray{T},
    M::AbstractArray{T},
    ipiv::AbstractVector{Int},
    m::Int,
    nb::Int,
) where T
    for p in eachindex(ipiv)
        q, r = divrem(ipiv[p] - 1, nb) .+ 1
        m == q || continue
        axes(A, 2) == axes(M, 2) || throw(
            DimensionMismatch("cannot swap rows: A and M differ in their second axis"),
        )
        # Element-wise exchange: no temporary row copies, and safe when `A === M`.
        for j in axes(A, 2)
            A[p, j], M[r, j] = M[r, j], A[p, j]
        end
    end
    return nothing
end
69+
"""
    lu(A::DMatrix, ::RowMaximum; check=true)

Out-of-place LU factorization with `RowMaximum` pivoting.

`A` is first copied via `LinearAlgebra._lucopy` (with element type
`LinearAlgebra.lutype(T)`), and the copy is then factorized by the matching `lu!`
method; `check` is forwarded unchanged.
"""
function LinearAlgebra.lu(
    A::DMatrix{T}, pivot::LinearAlgebra.RowMaximum; check::Bool=true
) where T
    workspace = LinearAlgebra._lucopy(A, LinearAlgebra.lutype(T))
    return LinearAlgebra.lu!(workspace, pivot; check=check)
end
"""
    lu!(A::DMatrix, ::RowMaximum; check=true)

In-place blocked LU factorization of `A` with `RowMaximum` pivoting, expressed as tasks
inside a `Dagger.spawn_datadeps` region.

For every diagonal block `k` the panel (block column `k`) is factorized one column at a
time: each block row proposes its largest-magnitude candidate, the winner is written to
the pivot vector, the pivot row is swapped into place and the panel is eliminated. The
recorded swaps are then replayed on every other block column, and the block columns to
the right are updated with `trsm!` / `gemm!`.

Row and column block sizes must be equal; otherwise an error is raised.

Returns `LU{T,DMatrix{T},DVector{Int}}(A, ipiv, 0)`. Note that `check` is accepted but
not consulted in this method, and the `info` field is always `0`.
"""
function LinearAlgebra.lu!(A::DMatrix{T}, ::LinearAlgebra.RowMaximum; check::Bool=true) where T
    plus_one = one(T)
    minus_one = -one(T)

    blocks = A.chunks
    nblkrows, nblkcols = size(blocks)
    nrows, ncols = size(A)
    rowblk, colblk = A.partitioning.blocksize

    if rowblk != colblk
        error("Unequal block sizes are not supported: mb = $rowblk, nb = $colblk")
    end

    ipiv = DVector(collect(1:min(nrows, ncols)), Blocks(rowblk))
    ipiv_blocks = ipiv.chunks

    # One pivot candidate (index and value) per block row, reused for every column.
    cand_idx = zeros(Int, nblkrows)
    cand_val = zeros(T, nblkrows)

    Dagger.spawn_datadeps() do
        for k in 1:min(nblkrows, nblkcols)
            done = (k - 1) * colblk
            # Number of rows actually present in the diagonal block.
            diag_rows = min(colblk, nrows - done)

            # --- Panel factorization, one column at a time ---
            for p in 1:min(diag_rows, ncols - done)
                # Candidate from the diagonal block: only rows p and below are eligible.
                Dagger.@spawn searchmax_pivot!(
                    Out(view(cand_idx, k:k)),
                    Out(view(cand_val, k:k)),
                    In(view(blocks[k, k], p:diag_rows, p:p)),
                    p - 1
                )
                # Candidates from the blocks below the diagonal.
                for i in (k + 1):nblkrows
                    Dagger.@spawn searchmax_pivot!(
                        Out(view(cand_idx, i:i)),
                        Out(view(cand_val, i:i)),
                        In(view(blocks[i, k], :, p:p))
                    )
                end
                # Choose the overall winner and record it as pivot p of this block.
                Dagger.@spawn update_ipiv!(
                    InOut(view(ipiv_blocks[k], p:p)),
                    In(view(cand_idx, k:nblkrows)),
                    In(view(cand_val, k:nblkrows)),
                    k,
                    colblk
                )
                # Bring the pivot row into row p of the diagonal block.
                for i in k:nblkrows
                    Dagger.@spawn swaprows_panel!(
                        InOut(blocks[k, k]),
                        InOut(blocks[i, k]),
                        InOut(view(ipiv_blocks[k], p:p)),
                        i,
                        p,
                        colblk
                    )
                end
                # Eliminate below the pivot: rest of the diagonal block, then lower blocks.
                Dagger.@spawn update_panel!(
                    InOut(view(blocks[k, k], (p + 1):diag_rows, :)),
                    In(blocks[k, k]),
                    p
                )
                for i in (k + 1):nblkrows
                    Dagger.@spawn update_panel!(InOut(blocks[i, k]), In(blocks[k, k]), p)
                end
            end

            # --- Replay this panel's row swaps on every other block column ---
            for j in 1:nblkcols
                j == k && continue
                for i in k:nblkrows
                    Dagger.@spawn swaprows_trail!(
                        InOut(blocks[k, j]),
                        InOut(blocks[i, j]),
                        In(ipiv_blocks[k]),
                        i,
                        rowblk
                    )
                end
            end

            # --- Update the block columns to the right of the panel ---
            for j in (k + 1):nblkcols
                Dagger.@spawn BLAS.trsm!(
                    'L', 'L', 'N', 'U', plus_one, In(blocks[k, k]), InOut(blocks[k, j])
                )
                for i in (k + 1):nblkrows
                    Dagger.@spawn BLAS.gemm!(
                        'N',
                        'N',
                        minus_one,
                        In(blocks[i, k]),
                        In(blocks[k, j]),
                        plus_one,
                        InOut(blocks[i, j])
                    )
                end
            end
        end
    end

    return LinearAlgebra.LU{T,DMatrix{T},DVector{Int}}(A, ipiv, 0)
end
0 commit comments