Skip to content

Commit 20926be

Browse files
authored
Merge pull request #3 from YingboMa/myb/generic
Make LU more generic
2 parents ed26db8 + 5973687 commit 20926be

File tree

4 files changed

+33
-12
lines changed

4 files changed

+33
-12
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
*.jl.cov
22
*.jl.*.cov
33
*.jl.mem
4+
Manifest.toml

Project.toml

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
name = "RecursiveFactorization"
2+
uuid = "f2c3362d-daeb-58d1-803e-2bc74f2840b4"
3+
authors = ["Yingbo Ma <[email protected]>"]
4+
version = "0.1.0"
5+
6+
[deps]
7+
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
8+
9+
[compat]
10+
julia = "1"
11+
12+
[extras]
13+
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
14+
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
15+
16+
[targets]
17+
test = ["Test", "Random"]

src/lu.jl

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,23 @@
1-
using LinearAlgebra: BlasInt, LU, UnitLowerTriangular, ldiv!, BLAS, checknonsingular
1+
using LinearAlgebra: BlasInt, BlasFloat, LU, UnitLowerTriangular, ldiv!, BLAS, checknonsingular
22

3-
lu(A::AbstractMatrix, pivot::Union{Val{false}, Val{true}} = Val(true);
4-
check::Bool = true, blocksize::Integer = 16) = lu!(copy(A), pivot;
5-
check = check, blocksize = blocksize)
3+
function lu(A::AbstractMatrix, pivot::Union{Val{false}, Val{true}} = Val(true);
4+
check::Bool = true, blocksize::Integer = 16)
5+
lu!(copy(A), pivot; check = check, blocksize = blocksize)
6+
end
67

7-
lu!(A, pivot::Union{Val{false}, Val{true}} = Val(true);
8-
check::Bool = true, blocksize::Integer = 16) = lu!(copy(A), Vector{BlasInt}(undef, min(size(A)...)), pivot;
9-
check = check, blocksize = blocksize)
8+
function lu!(A, pivot::Union{Val{false}, Val{true}} = Val(true);
9+
check::Bool = true, blocksize::Integer = 16)
10+
lu!(A, Vector{BlasInt}(undef, min(size(A)...)), pivot;
11+
check = check, blocksize = blocksize)
12+
end
1013

1114
function lu!(A::AbstractMatrix{T}, ipiv::AbstractVector{<:Integer},
1215
pivot::Union{Val{false}, Val{true}} = Val(true);
1316
check::Bool=true, blocksize::Integer=16) where T
1417
info = Ref(zero(BlasInt))
1518
m, n = size(A)
1619
mnmin = min(m, n)
17-
if isconcretetype(T)
20+
if T <: BlasFloat && A isa StridedArray
1821
reckernel!(A, pivot, m, mnmin, ipiv, info, blocksize)
1922
if m < n # fat matrix
2023
# [AL AR]
@@ -115,15 +118,15 @@ end
115118
Modified from https://github.com/JuliaLang/julia/blob/b56a9f07948255dfbe804eef25bdbada06ec2a57/stdlib/LinearAlgebra/src/lu.jl
116119
License is MIT: https://julialang.org/license
117120
=#
118-
function _generic_lufact!(A::StridedMatrix{T}, ::Val{Pivot}, ipiv, info) where {Pivot,T}
121+
function _generic_lufact!(A, ::Val{Pivot}, ipiv, info) where Pivot
119122
m, n = size(A)
120123
minmn = length(ipiv)
121124
@inbounds begin
122125
for k = 1:minmn
123126
# find index max
124127
kp = k
125128
if Pivot
126-
amax = abs(zero(T))
129+
amax = abs(zero(eltype(A)))
127130
for i = k:m
128131
absi = abs(A[i,k])
129132
if absi > amax

test/runtests.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@ using Random
66

77
Random.seed!(12)
88

9-
baselu = LinearAlgebra.lu
10-
mylu = RecursiveFactorization.lu
9+
const baselu = LinearAlgebra.lu
10+
const mylu = RecursiveFactorization.lu
1111

1212
function testlu(A, MF, BF)
1313
@test MF.info == BF.info

0 commit comments

Comments
 (0)