Skip to content

Commit a478c1f

Browse files
authored
Merge pull request #27 from YingboMa/myb/1.7
1.7 compat
2 parents b14bb60 + 1c97106 commit a478c1f

File tree

4 files changed

+39
-17
lines changed

4 files changed

+39
-17
lines changed

.github/workflows/ci.yml

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,16 @@ on:
88
- master
99
jobs:
1010
test:
11-
runs-on: ubuntu-latest
11+
runs-on: ${{ matrix.os }}
12+
strategy:
13+
matrix:
14+
julia-version: ['1', '^1.7.0-0']
15+
os: [ubuntu-latest, windows-latest, macOS-latest]
1216
steps:
1317
- uses: actions/checkout@v2
1418
- uses: julia-actions/setup-julia@v1
1519
with:
16-
version: 1
20+
version: ${{ matrix.julia-version }}
1721
- uses: actions/cache@v1
1822
env:
1923
cache-name: cache-artifacts

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "RecursiveFactorization"
22
uuid = "f2c3362d-daeb-58d1-803e-2bc74f2840b4"
33
authors = ["Yingbo Ma <[email protected]>"]
4-
version = "0.1.12"
4+
version = "0.1.13"
55

66
[deps]
77
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

src/lu.jl

Lines changed: 30 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,27 @@
11
using LoopVectorization
22
using LinearAlgebra: BlasInt, BlasFloat, LU, UnitLowerTriangular, ldiv!, checknonsingular, BLAS, LinearAlgebra
33

4-
function lu(A::AbstractMatrix, pivot::Union{Val{false}, Val{true}} = Val(true); kwargs...)
5-
return lu!(copy(A), pivot; kwargs...)
4+
# 1.7 compat
5+
normalize_pivot(t::Val{T}) where T = t
6+
to_stdlib_pivot(t) = t
7+
if VERSION >= v"1.7.0-DEV.1188"
8+
normalize_pivot(::LinearAlgebra.RowMaximum) = Val(true)
9+
normalize_pivot(::LinearAlgebra.NoPivot) = Val(false)
10+
to_stdlib_pivot(::Val{true}) = LinearAlgebra.RowMaximum()
11+
to_stdlib_pivot(::Val{false}) = LinearAlgebra.NoPivot()
612
end
713

8-
function lu!(A, pivot::Union{Val{false}, Val{true}} = Val(true); check=true, kwargs...)
14+
function lu(A::AbstractMatrix, pivot = Val(true); kwargs...)
15+
return lu!(copy(A), normalize_pivot(pivot); kwargs...)
16+
end
17+
18+
function lu!(A, pivot = Val(true); check=true, kwargs...)
919
m, n = size(A)
1020
minmn = min(m, n)
1121
F = if minmn < 10 # avx introduces small performance degradation
12-
LinearAlgebra.generic_lufact!(A, pivot; check=check)
22+
LinearAlgebra.generic_lufact!(A, to_stdlib_pivot(pivot); check=check)
1323
else
14-
lu!(A, Vector{BlasInt}(undef, minmn), pivot; check=check, kwargs...)
24+
lu!(A, Vector{BlasInt}(undef, minmn), normalize_pivot(pivot); check=check, kwargs...)
1525
end
1626
return F
1727
end
@@ -20,21 +30,28 @@ const RECURSION_THRESHOLD = Ref(-1)
2030

2131
# AVX512 needs a smaller recursion limit
2232
function pick_threshold()
23-
blasvendor = BLAS.vendor()
2433
RECURSION_THRESHOLD[] >= 0 && return RECURSION_THRESHOLD[]
34+
blasvendor = @static if VERSION >= v"1.7.0-DEV.610"
35+
:openblas64
36+
else
37+
BLAS.vendor()
38+
end
2539
if blasvendor === :openblas || blasvendor === :openblas64
2640
LoopVectorization.register_size() == 64 ? 110 : 72
2741
else
2842
LoopVectorization.register_size() == 64 ? 48 : 72
2943
end
3044
end
3145

32-
function lu!(A::AbstractMatrix{T}, ipiv::AbstractVector{<:Integer},
33-
pivot::Union{Val{false}, Val{true}} = Val(true);
34-
check::Bool=true,
35-
# the performance is not sensitive wrt blocksize, and 16 is a good default
36-
blocksize::Integer=16,
37-
threshold::Integer=pick_threshold()) where T
46+
function lu!(
47+
A::AbstractMatrix{T}, ipiv::AbstractVector{<:Integer},
48+
pivot = Val(true);
49+
check::Bool=true,
50+
# the performance is not sensitive wrt blocksize, and 16 is a good default
51+
blocksize::Integer=16,
52+
threshold::Integer=pick_threshold()
53+
) where T
54+
pivot = normalize_pivot(pivot)
3855
info = zero(BlasInt)
3956
m, n = size(A)
4057
mnmin = min(m, n)
@@ -187,7 +204,7 @@ function _generic_lufact!(A, ::Val{Pivot}, ipiv, info) where Pivot
187204
elseif info == 0
188205
info = k
189206
end
190-
k == minmn && break
207+
k == minmn && break
191208
# Update the rest
192209
@avx for j = k+1:n
193210
for i = k+1:m

test/runtests.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@ end
1818
@testset "Test LU factorization" begin
1919
for _p in (true, false), T in (Float64, Float32, ComplexF64, ComplexF32, Real)
2020
p = Val(_p)
21-
for s in [1:10; 50:80:200; 300]
21+
for (i, s) in enumerate([1:10; 50:80:200; 300])
22+
iseven(i) && (p = RecursiveFactorization.to_stdlib_pivot(p))
2223
siz = (s, s+2)
2324
@info("size: $(siz[1]) × $(siz[2]), T = $T, p = $_p")
2425
if isconcretetype(T)

0 commit comments

Comments
 (0)