@@ -7,8 +7,6 @@ function LinearAlgebra.lu!(A::DMatrix{T}, ::LinearAlgebra.NoPivot; check::Bool=t
7
7
mzone = - one (T)
8
8
Ac = A. chunks
9
9
mt, nt = size (Ac)
10
- iscomplex = T <: Complex
11
- trans = iscomplex ? ' C' : ' T'
12
10
13
11
Dagger. spawn_datadeps () do
14
12
for k in range (1 , min (mt, nt))
@@ -29,3 +27,97 @@ function LinearAlgebra.lu!(A::DMatrix{T}, ::LinearAlgebra.NoPivot; check::Bool=t
29
27
30
28
return LinearAlgebra. LU {T,DMatrix{T},DVector{Int}} (A, ipiv, 0 )
31
29
end
30
+
31
+ function searchmax_pivot! (piv_idx:: AbstractArray{Int} , piv_val:: AbstractArray{T} , A:: AbstractArray{T} , offset:: Int = 0 ) where T
32
+ max_idx = argmax (abs .(A[:]))
33
+ piv_idx[1 ] = offset+ max_idx
34
+ piv_val[1 ] = A[max_idx]
35
+ println (" searchmax_pivot: " , piv_idx[1 ], " \n " , abs (piv_val[1 ]))
36
+ end
37
+
38
+ function update_ipiv! (ipivl, piv_idx:: AbstractArray{Int} , piv_val:: AbstractArray{T} , k:: Int , nb:: Int ) where T
39
+ max_piv_idx = argmax (abs .(piv_val))
40
+ ipivl[1 ] = (max_piv_idx+ k- 2 )* nb + piv_idx[max_piv_idx]
41
+ println (" update_ipiv: " , ipivl[1 ])
42
+ end
43
+
44
+ function swaprows_panel! (A:: AbstractArray{T} , M:: AbstractArray{T} , ipivl:: AbstractVector{Int} , m:: Int , p:: Int , nb:: Int ) where T
45
+ q = div (ipivl[1 ]- 1 ,nb) + 1
46
+ r = (ipivl[1 ]- 1 )% nb+ 1
47
+ if m == q
48
+ A[p,:], M[r,:] = M[r,:], A[p,:]
49
+ println (" swaprows_panel: " , imag .(A[p,:]), " \n " , imag .(M[r,:]))
50
+ end
51
+ end
52
+
53
+ function update_panel! (M:: AbstractArray{T} , A:: AbstractArray{T} , p:: Int ) where T
54
+ Acinv = one (T) / A[p,p]
55
+ LinearAlgebra. BLAS. scal! (Acinv, view (M, :, p))
56
+ LinearAlgebra. BLAS. ger! (- one (T), view (M, :, p), conj .(view (A, p, p+ 1 : size (A,2 ))), view (M, :, p+ 1 : size (M,2 )))
57
+ end
58
+
59
+ function swaprows_trail! (A:: AbstractArray{T} , M:: AbstractArray{T} , ipiv:: AbstractVector{Int} , m:: Int , nb:: Int ) where T
60
+ for p in eachindex (ipiv)
61
+ q = div (ipiv[p]- 1 ,nb) + 1
62
+ r = (ipiv[p]- 1 )% nb+ 1
63
+ if m == q
64
+ A[p,:], M[r,:] = M[r,:], A[p,:]
65
+ println (" swaprows_trail: " , imag .(A[p,:]), " \n " , imag .(M[r,:]))
66
+ end
67
+ end
68
+ end
69
+
70
+ function LinearAlgebra. lu (A:: DMatrix{T} , :: LinearAlgebra.RowMaximum ; check:: Bool = true ) where T
71
+ A_copy = LinearAlgebra. _lucopy (A, LinearAlgebra. lutype (T))
72
+ return LinearAlgebra. lu! (A_copy, LinearAlgebra. RowMaximum (); check= check)
73
+ end
74
+ function LinearAlgebra. lu! (A:: DMatrix{T} , :: LinearAlgebra.RowMaximum ; check:: Bool = true ) where T
75
+ zone = one (T)
76
+ mzone = - one (T)
77
+
78
+ Ac = A. chunks
79
+ mt, nt = size (Ac)
80
+ m, n = size (A)
81
+ mb, nb = A. partitioning. blocksize
82
+
83
+ mb != nb && error (" Unequal block sizes are not supported: mb = $mb , nb = $nb " )
84
+
85
+ ipiv = DVector (collect (1 : min (m, n)), Blocks (mb))
86
+ ipivc = ipiv. chunks
87
+
88
+ max_piv_idx = zeros (Int,mt)
89
+ max_piv_val = zeros (T, mt)
90
+
91
+ Dagger. spawn_datadeps () do
92
+ for k in 1 : min (mt, nt)
93
+ for p in 1 : min (nb, m- (k- 1 )* nb, n- (k- 1 )* nb)
94
+ Dagger. @spawn searchmax_pivot! (Out (view (max_piv_idx, k: k)), Out (view (max_piv_val, k: k)), In (view (Ac[k,k],p: min (nb,m- (k- 1 )* nb),p: p)), p- 1 )
95
+ for i in k+ 1 : mt
96
+ Dagger. @spawn searchmax_pivot! (Out (view (max_piv_idx, i: i)), Out (view (max_piv_val, i: i)), In (view (Ac[i,k],:,p: p)))
97
+ end
98
+ Dagger. @spawn update_ipiv! (InOut (view (ipivc[k],p: p)), In (view (max_piv_idx, k: mt)), In (view (max_piv_val, k: mt)), k, nb)
99
+ for i in k: mt
100
+ Dagger. @spawn swaprows_panel! (InOut (Ac[k, k]), InOut (Ac[i, k]), InOut (view (ipivc[k],p: p)), i, p, nb)
101
+ end
102
+ Dagger. @spawn update_panel! (InOut (view (Ac[k,k],p+ 1 : min (nb,m- (k- 1 )* nb),:)), In (Ac[k,k]), p)
103
+ for i in k+ 1 : mt
104
+ Dagger. @spawn update_panel! (InOut (Ac[i, k]), In (Ac[k,k]), p)
105
+ end
106
+
107
+ end
108
+ for j in Iterators. flatten ((1 : k- 1 , k+ 1 : nt))
109
+ for i in k: mt
110
+ Dagger. @spawn swaprows_trail! (InOut (Ac[k, j]), InOut (Ac[i, j]), In (ipivc[k]), i, mb)
111
+ end
112
+ end
113
+ for j in k+ 1 : nt
114
+ Dagger. @spawn BLAS. trsm! (' L' , ' L' , ' N' , ' U' , zone, In (Ac[k, k]), InOut (Ac[k, j]))
115
+ for i in k+ 1 : mt
116
+ Dagger. @spawn BLAS. gemm! (' N' , ' N' , mzone, In (Ac[i, k]), In (Ac[k, j]), zone, InOut (Ac[i, j]))
117
+ end
118
+ end
119
+ end
120
+ end
121
+
122
+ return LinearAlgebra. LU {T,DMatrix{T},DVector{Int}} (A, ipiv, 0 )
123
+ end
0 commit comments