1
1
using LoopVectorization
2
2
using TriangularSolve: ldiv!
3
3
using LinearAlgebra: BlasInt, BlasFloat, LU, UnitLowerTriangular, checknonsingular, BLAS,
4
- LinearAlgebra, Adjoint, Transpose
4
+ LinearAlgebra, Adjoint, Transpose, UpperTriangular, AbstractVecOrMat
5
5
using StrideArraysCore
6
6
using Polyester: @batch
7
7
@@ -41,16 +41,23 @@ init_pivot(::Val{true}, minmn) = Vector{BlasInt}(undef, minmn)
41
41
42
42
if CUSTOMIZABLE_PIVOT && isdefined (LinearAlgebra, :_ipiv_cols! )
43
43
function LinearAlgebra. _ipiv_cols! (:: LU{<:Any, <:Any, NotIPIV} , :: OrdinalRange ,
44
- B:: StridedVecOrMat )
44
+ B:: StridedVecOrMat )
45
45
return B
46
46
end
47
47
end
48
48
if CUSTOMIZABLE_PIVOT && isdefined (LinearAlgebra, :_ipiv_rows! )
49
- function LinearAlgebra. _ipiv_rows! (:: LU{<:Any, <:Any, NotIPIV} , :: OrdinalRange ,
50
- B:: StridedVecOrMat )
49
+ function LinearAlgebra. _ipiv_rows! (:: (LU{T, <:AbstractMatrix{T}, NotIPIV} where {T}) ,
50
+ :: OrdinalRange ,
51
+ B:: StridedVecOrMat )
51
52
return B
52
53
end
53
54
end
55
+ if CUSTOMIZABLE_PIVOT
56
+ function LinearAlgebra. ldiv! (A:: LU{T, <:StridedMatrix, <:NotIPIV} ,
57
+ B:: StridedVecOrMat{T} ) where {T <: BlasFloat }
58
+ ldiv! (UpperTriangular (A. factors), ldiv! (UnitLowerTriangular (A. factors), B))
59
+ end
60
+ end
54
61
55
62
function lu! (A, pivot = Val (true ), thread = Val (true ); check = true , kwargs... )
56
63
m, n = size (A)
@@ -80,11 +87,11 @@ recurse(_) = false
80
87
_ptrarray (ipiv) = PtrArray (ipiv)
81
88
_ptrarray (ipiv:: NotIPIV ) = ipiv
82
89
function lu! (A:: AbstractMatrix{T} , ipiv:: AbstractVector{<:Integer} ,
83
- pivot = Val (true ), thread = Val (true );
84
- check:: Bool = true ,
85
- # the performance is not sensitive wrt blocksize, and 8 is a good default
86
- blocksize:: Integer = length (A) ≥ 40_000 ? 8 : 16 ,
87
- threshold:: Integer = pick_threshold ()) where {T}
90
+ pivot = Val (true ), thread = Val (true );
91
+ check:: Bool = true ,
92
+ # the performance is not sensitive wrt blocksize, and 8 is a good default
93
+ blocksize:: Integer = length (A) ≥ 40_000 ? 8 : 16 ,
94
+ threshold:: Integer = pick_threshold ()) where {T}
88
95
pivot = normalize_pivot (pivot)
89
96
info = zero (BlasInt)
90
97
m, n = size (A)
@@ -94,10 +101,12 @@ function lu!(A::AbstractMatrix{T}, ipiv::AbstractVector{<:Integer},
94
101
end
95
102
if recurse (A) && mnmin > threshold
96
103
if T <: Union{Float32, Float64}
97
- GC. @preserve ipiv A begin info = recurse! (view (PtrArray (A), axes (A)... ), pivot,
98
- m, n, mnmin,
99
- _ptrarray (ipiv), info, blocksize,
100
- thread) end
104
+ GC. @preserve ipiv A begin
105
+ info = recurse! (view (PtrArray (A), axes (A)... ), pivot,
106
+ m, n, mnmin,
107
+ _ptrarray (ipiv), info, blocksize,
108
+ thread)
109
+ end
101
110
else
102
111
info = recurse! (A, pivot, m, n, mnmin, ipiv, info, blocksize, thread)
103
112
end
@@ -109,7 +118,7 @@ function lu!(A::AbstractMatrix{T}, ipiv::AbstractVector{<:Integer},
109
118
end
110
119
111
120
@inline function recurse! (A, :: Val{Pivot} , m, n, mnmin, ipiv, info, blocksize,
112
- :: Val{true} ) where {Pivot}
121
+ :: Val{true} ) where {Pivot}
113
122
if length (A) * _sizeof (eltype (A)) >
114
123
0.92 * LoopVectorization. VectorizationBase. cache_size (Val (2 ))
115
124
_recurse! (A, Val {Pivot} (), m, n, mnmin, ipiv, info, blocksize, Val (true ))
@@ -118,11 +127,11 @@ end
118
127
end
119
128
end
120
129
@inline function recurse! (A, :: Val{Pivot} , m, n, mnmin, ipiv, info, blocksize,
121
- :: Val{false} ) where {Pivot}
130
+ :: Val{false} ) where {Pivot}
122
131
_recurse! (A, Val {Pivot} (), m, n, mnmin, ipiv, info, blocksize, Val (false ))
123
132
end
124
133
@inline function _recurse! (A, :: Val{Pivot} , m, n, mnmin, ipiv, info, blocksize,
125
- :: Val{Thread} ) where {Pivot, Thread}
134
+ :: Val{Thread} ) where {Pivot, Thread}
126
135
info = reckernel! (A, Val (Pivot), m, mnmin, ipiv, info, blocksize, Val (Thread)):: Int
127
136
@inbounds if m < n # fat matrix
128
137
# [AL AR]
@@ -166,7 +175,7 @@ Base.@propagate_inbounds function apply_permutation!(P, A, ::Val{false})
166
175
nothing
167
176
end
168
177
function reckernel! (A:: AbstractMatrix{T} , pivot:: Val{Pivot} , m, n, ipiv, info, blocksize,
169
- thread):: BlasInt where {T, Pivot}
178
+ thread):: BlasInt where {T, Pivot}
170
179
@inbounds begin
171
180
if n <= max (blocksize, 1 )
172
181
info = _generic_lufact! (A, Val (Pivot), ipiv, info)
@@ -262,44 +271,46 @@ end
262
271
function _generic_lufact! (A, :: Val{Pivot} , ipiv, info) where {Pivot}
263
272
m, n = size (A)
264
273
minmn = length (ipiv)
265
- @inbounds begin for k in 1 : minmn
266
- # find index max
267
- kp = k
268
- if Pivot
269
- amax = abs (zero (eltype (A)))
270
- for i in k: m
271
- absi = abs (A[i, k])
272
- if absi > amax
273
- kp = i
274
- amax = absi
274
+ @inbounds begin
275
+ for k in 1 : minmn
276
+ # find index max
277
+ kp = k
278
+ if Pivot
279
+ amax = abs (zero (eltype (A)))
280
+ for i in k: m
281
+ absi = abs (A[i, k])
282
+ if absi > amax
283
+ kp = i
284
+ amax = absi
285
+ end
275
286
end
287
+ ipiv[k] = kp
276
288
end
277
- ipiv[k] = kp
278
- end
279
- if ! iszero (A[kp, k])
280
- if k != kp
281
- # Interchange
282
- @simd for i in 1 : n
283
- tmp = A[k, i]
284
- A[k, i] = A[kp, i]
285
- A[kp, i] = tmp
289
+ if ! iszero (A[kp, k])
290
+ if k != kp
291
+ # Interchange
292
+ @simd for i in 1 : n
293
+ tmp = A[k, i]
294
+ A[k, i] = A[kp, i]
295
+ A[kp, i] = tmp
296
+ end
286
297
end
298
+ # Scale first column
299
+ Akkinv = inv (A[k, k])
300
+ @turbo check_empty= true warn_check_args= false for i in (k + 1 ): m
301
+ A[i, k] *= Akkinv
302
+ end
303
+ elseif info == 0
304
+ info = k
287
305
end
288
- # Scale first column
289
- Akkinv = inv (A[k, k])
290
- @turbo check_empty= true warn_check_args= false for i in (k + 1 ): m
291
- A[i, k] *= Akkinv
292
- end
293
- elseif info == 0
294
- info = k
295
- end
296
- k == minmn && break
297
- # Update the rest
298
- @turbo warn_check_args= false for j in (k + 1 ): n
299
- for i in (k + 1 ): m
300
- A[i, j] -= A[i, k] * A[k, j]
306
+ k == minmn && break
307
+ # Update the rest
308
+ @turbo warn_check_args= false for j in (k + 1 ): n
309
+ for i in (k + 1 ): m
310
+ A[i, j] -= A[i, k] * A[k, j]
311
+ end
301
312
end
302
313
end
303
- end end
314
+ end
304
315
return info
305
316
end
0 commit comments