Skip to content

Commit c51bad0

Browse files
dkarraschamontoison
andcommitted
Use BLAS.trsm! instead of LAPACK.trtrs! in left-triangular solves
Co-authored-by: Alexis Montoison <[email protected]>
1 parent e7da19f commit c51bad0

File tree

2 files changed

+11
-4
lines changed

2 files changed

+11
-4
lines changed

src/triangular.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1223,11 +1223,13 @@ function generic_mattrimul!(C::StridedMatrix{T}, uploc, isunitc, tfun::Function,
12231223
end
12241224
end
12251225
# division
1226-
function generic_trimatdiv!(C::StridedVecOrMat{T}, uploc, isunitc, tfun::Function, A::StridedMatrix{T}, B::AbstractVecOrMat{T}) where {T<:BlasFloat}
1226+
generic_trimatdiv!(C::StridedVector{T}, uploc, isunitc, tfun::Function, A::StridedMatrix{T}, B::AbstractVector{T}) where {T<:BlasFloat} =
1227+
BLAS.trsv!(uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', isunitc, A, C === B ? C : copyto!(C, B))
1228+
function generic_trimatdiv!(C::StridedVecOrMat{T}, uploc, isunitc, tfun::Function, A::StridedMatrix{T}, B::AbstractMatrix{T}) where {T<:BlasFloat}
12271229
if stride(C,1) == stride(A,1) == 1
1228-
LAPACK.trtrs!(uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', isunitc, A, C === B ? C : copyto!(C, B))
1230+
BLAS.trsm!('L', uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', isunitc, one(T), A, C === B ? C : copyto!(C, B))
12291231
else # incompatible with LAPACK
1230-
@invoke generic_trimatdiv!(C::AbstractVecOrMat, uploc, isunitc, tfun::Function, A::AbstractMatrix, B::AbstractVecOrMat)
1232+
@invoke generic_trimatdiv!(C::AbstractVecOrMat, uploc, isunitc, tfun::Function, A::AbstractMatrix, B::AbstractMatrix)
12311233
end
12321234
end
12331235
function generic_mattridiv!(C::StridedMatrix{T}, uploc, isunitc, tfun::Function, A::AbstractMatrix{T}, B::StridedMatrix{T}) where {T<:BlasFloat}

test/triangular.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -886,8 +886,13 @@ end
886886
end
887887
end
888888

889-
@testset "(l/r)mul! and (l/r)div! for non-contiguous matrices" begin
889+
@testset "(l/r)mul! and (l/r)div! for non-contiguous arrays" begin
890890
U = UpperTriangular(reshape(collect(3:27.0),5,5))
891+
b = float.(1:10)
892+
b2 = copy(b); b2v = view(b2, 1:2:9); b2vc = copy(b2v)
893+
@test lmul!(U, b2v) == lmul!(U, b2vc)
894+
b2 = copy(b); b2v = view(b2, 1:2:9); b2vc = copy(b2v)
895+
@test ldiv!(U, b2v) ldiv!(U, b2vc)
891896
B = float.(collect(reshape(1:100, 10,10)))
892897
B2 = copy(B); B2v = view(B2, 1:2:9, 1:5); B2vc = copy(B2v)
893898
@test lmul!(U, B2v) == lmul!(U, B2vc)

0 commit comments

Comments
 (0)