1
1
using LoopVectorization
2
2
using TriangularSolve: ldiv!
3
3
using LinearAlgebra: BlasInt, BlasFloat, LU, UnitLowerTriangular, checknonsingular, BLAS,
4
- LinearAlgebra, Adjoint, Transpose
4
+ LinearAlgebra, Adjoint, Transpose, UpperTriangular
5
5
using StrideArraysCore
6
6
using Polyester: @batch
7
7
@@ -41,16 +41,22 @@ init_pivot(::Val{true}, minmn) = Vector{BlasInt}(undef, minmn)
41
41
42
42
if CUSTOMIZABLE_PIVOT && isdefined (LinearAlgebra, :_ipiv_cols! )
43
43
function LinearAlgebra. _ipiv_cols! (:: LU{<:Any, <:Any, NotIPIV} , :: OrdinalRange ,
44
- B:: StridedVecOrMat )
44
+ B:: StridedVecOrMat )
45
45
return B
46
46
end
47
47
end
48
48
if CUSTOMIZABLE_PIVOT && isdefined (LinearAlgebra, :_ipiv_rows! )
49
49
function LinearAlgebra. _ipiv_rows! (:: LU{<:Any, <:Any, NotIPIV} , :: OrdinalRange ,
50
- B:: StridedVecOrMat )
50
+ B:: StridedVecOrMat )
51
51
return B
52
52
end
53
53
end
54
+ if CUSTOMIZABLE_PIVOT
55
+ function LinearAlgebra. ldiv! (A:: LU{T, <:StridedMatrix, <:NotIPIV} ,
56
+ B:: StridedVecOrMat{T} ) where {T <: BlasFloat }
57
+ ldiv! (UpperTriangular (A. factors), ldiv! (UnitLowerTriangular (A. factors), B))
58
+ end
59
+ end
54
60
55
61
function lu! (A, pivot = Val (true ), thread = Val (true ); check = true , kwargs... )
56
62
m, n = size (A)
@@ -80,11 +86,11 @@ recurse(_) = false
80
86
_ptrarray (ipiv) = PtrArray (ipiv)
81
87
_ptrarray (ipiv:: NotIPIV ) = ipiv
82
88
function lu! (A:: AbstractMatrix{T} , ipiv:: AbstractVector{<:Integer} ,
83
- pivot = Val (true ), thread = Val (true );
84
- check:: Bool = true ,
85
- # the performance is not sensitive wrt blocksize, and 8 is a good default
86
- blocksize:: Integer = length (A) ≥ 40_000 ? 8 : 16 ,
87
- threshold:: Integer = pick_threshold ()) where {T}
89
+ pivot = Val (true ), thread = Val (true );
90
+ check:: Bool = true ,
91
+ # the performance is not sensitive wrt blocksize, and 8 is a good default
92
+ blocksize:: Integer = length (A) ≥ 40_000 ? 8 : 16 ,
93
+ threshold:: Integer = pick_threshold ()) where {T}
88
94
pivot = normalize_pivot (pivot)
89
95
info = zero (BlasInt)
90
96
m, n = size (A)
@@ -94,10 +100,12 @@ function lu!(A::AbstractMatrix{T}, ipiv::AbstractVector{<:Integer},
94
100
end
95
101
if recurse (A) && mnmin > threshold
96
102
if T <: Union{Float32, Float64}
97
- GC. @preserve ipiv A begin info = recurse! (view (PtrArray (A), axes (A)... ), pivot,
98
- m, n, mnmin,
99
- _ptrarray (ipiv), info, blocksize,
100
- thread) end
103
+ GC. @preserve ipiv A begin
104
+ info = recurse! (view (PtrArray (A), axes (A)... ), pivot,
105
+ m, n, mnmin,
106
+ _ptrarray (ipiv), info, blocksize,
107
+ thread)
108
+ end
101
109
else
102
110
info = recurse! (A, pivot, m, n, mnmin, ipiv, info, blocksize, thread)
103
111
end
@@ -109,7 +117,7 @@ function lu!(A::AbstractMatrix{T}, ipiv::AbstractVector{<:Integer},
109
117
end
110
118
111
119
@inline function recurse! (A, :: Val{Pivot} , m, n, mnmin, ipiv, info, blocksize,
112
- :: Val{true} ) where {Pivot}
120
+ :: Val{true} ) where {Pivot}
113
121
if length (A) * _sizeof (eltype (A)) >
114
122
0.92 * LoopVectorization. VectorizationBase. cache_size (Val (2 ))
115
123
_recurse! (A, Val {Pivot} (), m, n, mnmin, ipiv, info, blocksize, Val (true ))
@@ -118,11 +126,11 @@ end
118
126
end
119
127
end
120
128
@inline function recurse! (A, :: Val{Pivot} , m, n, mnmin, ipiv, info, blocksize,
121
- :: Val{false} ) where {Pivot}
129
+ :: Val{false} ) where {Pivot}
122
130
_recurse! (A, Val {Pivot} (), m, n, mnmin, ipiv, info, blocksize, Val (false ))
123
131
end
124
132
@inline function _recurse! (A, :: Val{Pivot} , m, n, mnmin, ipiv, info, blocksize,
125
- :: Val{Thread} ) where {Pivot, Thread}
133
+ :: Val{Thread} ) where {Pivot, Thread}
126
134
info = reckernel! (A, Val (Pivot), m, mnmin, ipiv, info, blocksize, Val (Thread)):: Int
127
135
@inbounds if m < n # fat matrix
128
136
# [AL AR]
@@ -166,7 +174,7 @@ Base.@propagate_inbounds function apply_permutation!(P, A, ::Val{false})
166
174
nothing
167
175
end
168
176
function reckernel! (A:: AbstractMatrix{T} , pivot:: Val{Pivot} , m, n, ipiv, info, blocksize,
169
- thread):: BlasInt where {T, Pivot}
177
+ thread):: BlasInt where {T, Pivot}
170
178
@inbounds begin
171
179
if n <= max (blocksize, 1 )
172
180
info = _generic_lufact! (A, Val (Pivot), ipiv, info)
@@ -262,44 +270,46 @@ end
262
270
function _generic_lufact! (A, :: Val{Pivot} , ipiv, info) where {Pivot}
263
271
m, n = size (A)
264
272
minmn = length (ipiv)
265
- @inbounds begin for k in 1 : minmn
266
- # find index max
267
- kp = k
268
- if Pivot
269
- amax = abs (zero (eltype (A)))
270
- for i in k: m
271
- absi = abs (A[i, k])
272
- if absi > amax
273
- kp = i
274
- amax = absi
273
+ @inbounds begin
274
+ for k in 1 : minmn
275
+ # find index max
276
+ kp = k
277
+ if Pivot
278
+ amax = abs (zero (eltype (A)))
279
+ for i in k: m
280
+ absi = abs (A[i, k])
281
+ if absi > amax
282
+ kp = i
283
+ amax = absi
284
+ end
275
285
end
286
+ ipiv[k] = kp
276
287
end
277
- ipiv[k] = kp
278
- end
279
- if ! iszero (A[kp, k])
280
- if k != kp
281
- # Interchange
282
- @simd for i in 1 : n
283
- tmp = A[k, i]
284
- A[k, i] = A[kp, i]
285
- A[kp, i] = tmp
288
+ if ! iszero (A[kp, k])
289
+ if k != kp
290
+ # Interchange
291
+ @simd for i in 1 : n
292
+ tmp = A[k, i]
293
+ A[k, i] = A[kp, i]
294
+ A[kp, i] = tmp
295
+ end
286
296
end
297
+ # Scale first column
298
+ Akkinv = inv (A[k, k])
299
+ @turbo check_empty= true warn_check_args= false for i in (k + 1 ): m
300
+ A[i, k] *= Akkinv
301
+ end
302
+ elseif info == 0
303
+ info = k
287
304
end
288
- # Scale first column
289
- Akkinv = inv (A[k, k])
290
- @turbo check_empty= true warn_check_args= false for i in (k + 1 ): m
291
- A[i, k] *= Akkinv
292
- end
293
- elseif info == 0
294
- info = k
295
- end
296
- k == minmn && break
297
- # Update the rest
298
- @turbo warn_check_args= false for j in (k + 1 ): n
299
- for i in (k + 1 ): m
300
- A[i, j] -= A[i, k] * A[k, j]
305
+ k == minmn && break
306
+ # Update the rest
307
+ @turbo warn_check_args= false for j in (k + 1 ): n
308
+ for i in (k + 1 ): m
309
+ A[i, j] -= A[i, k] * A[k, j]
310
+ end
301
311
end
302
312
end
303
- end end
313
+ end
304
314
return info
305
315
end
0 commit comments