diff --git a/src/triangular.jl b/src/triangular.jl index b602e082..8b31942b 100644 --- a/src/triangular.jl +++ b/src/triangular.jl @@ -1205,15 +1205,35 @@ end # multiplication generic_trimatmul!(c::StridedVector{T}, uploc, isunitc, tfun::Function, A::StridedMatrix{T}, b::AbstractVector{T}) where {T<:BlasFloat} = BLAS.trmv!(uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', isunitc, A, c === b ? c : copyto!(c, b)) -generic_trimatmul!(C::StridedMatrix{T}, uploc, isunitc, tfun::Function, A::StridedMatrix{T}, B::AbstractMatrix{T}) where {T<:BlasFloat} = - BLAS.trmm!('L', uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', isunitc, one(T), A, C === B ? C : copyto!(C, B)) -generic_mattrimul!(C::StridedMatrix{T}, uploc, isunitc, tfun::Function, A::AbstractMatrix{T}, B::StridedMatrix{T}) where {T<:BlasFloat} = - BLAS.trmm!('R', uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', isunitc, one(T), B, C === A ? C : copyto!(C, A)) +function generic_trimatmul!(C::StridedMatrix{T}, uploc, isunitc, tfun::Function, A::StridedMatrix{T}, B::AbstractMatrix{T}) where {T<:BlasFloat} + if stride(C,1) == stride(A,1) == 1 + BLAS.trmm!('L', uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', isunitc, one(T), A, C === B ? C : copyto!(C, B)) + else # incompatible with BLAS + @invoke generic_trimatmul!(C::AbstractMatrix, uploc, isunitc, tfun::Function, A::AbstractMatrix, B::AbstractMatrix) + end +end +function generic_mattrimul!(C::StridedMatrix{T}, uploc, isunitc, tfun::Function, A::AbstractMatrix{T}, B::StridedMatrix{T}) where {T<:BlasFloat} + if stride(C,1) == stride(B,1) == 1 + BLAS.trmm!('R', uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', isunitc, one(T), B, C === A ? C : copyto!(C, A)) + else # incompatible with BLAS + @invoke generic_mattrimul!(C::AbstractMatrix, uploc, isunitc, tfun::Function, A::AbstractMatrix, B::AbstractMatrix) + end +end # division -generic_trimatdiv!(C::StridedVecOrMat{T}, uploc, isunitc, tfun::Function, A::StridedMatrix{T}, B::AbstractVecOrMat{T}) where {T<:BlasFloat} = - LAPACK.trtrs!(uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', isunitc, A, C === B ? C : copyto!(C, B)) -generic_mattridiv!(C::StridedMatrix{T}, uploc, isunitc, tfun::Function, A::AbstractMatrix{T}, B::StridedMatrix{T}) where {T<:BlasFloat} = - BLAS.trsm!('R', uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', isunitc, one(T), B, C === A ? C : copyto!(C, A)) +function generic_trimatdiv!(C::StridedVecOrMat{T}, uploc, isunitc, tfun::Function, A::StridedMatrix{T}, B::AbstractVecOrMat{T}) where {T<:BlasFloat} + if stride(C,1) == stride(A,1) == 1 + LAPACK.trtrs!(uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', isunitc, A, C === B ? C : copyto!(C, B)) + else # incompatible with LAPACK + @invoke generic_trimatdiv!(C::AbstractVecOrMat, uploc, isunitc, tfun::Function, A::AbstractMatrix, B::AbstractVecOrMat) + end +end +function generic_mattridiv!(C::StridedMatrix{T}, uploc, isunitc, tfun::Function, A::AbstractMatrix{T}, B::StridedMatrix{T}) where {T<:BlasFloat} + if stride(C,1) == stride(B,1) == 1 + BLAS.trsm!('R', uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', isunitc, one(T), B, C === A ? C : copyto!(C, A)) + else # incompatible with BLAS + @invoke generic_mattridiv!(C::AbstractMatrix, uploc, isunitc, tfun::Function, A::AbstractMatrix, B::AbstractMatrix) + end +end errorbounds(A::AbstractTriangular{T}, X::AbstractVecOrMat{T}, B::AbstractVecOrMat{T}) where {T<:Union{BigFloat,Complex{BigFloat}}} = error("not implemented yet! Please submit a pull request.") diff --git a/test/triangular.jl b/test/triangular.jl index e1f1808b..c0a9ecf5 100644 --- a/test/triangular.jl +++ b/test/triangular.jl @@ -1427,4 +1427,17 @@ end end end +@testset "(l/r)mul! and (l/r)div! for non-contiguous matrices" begin + U = UpperTriangular(reshape(collect(3:27.0),5,5)) + B = float.(collect(reshape(1:100, 10,10))) + B2 = copy(B); B2v = view(B2, 1:2:9, 1:5); B2vc = copy(B2v) + @test lmul!(U, B2v) == lmul!(U, B2vc) + B2 = copy(B); B2v = view(B2, 1:2:9, 1:5); B2vc = copy(B2v) + @test rmul!(B2v, U) == rmul!(B2vc, U) + B2 = copy(B); B2v = view(B2, 1:2:9, 1:5); B2vc = copy(B2v) + @test ldiv!(U, B2v) ≈ ldiv!(U, B2vc) + B2 = copy(B); B2v = view(B2, 1:2:9, 1:5); B2vc = copy(B2v) + @test rdiv!(B2v, U) ≈ rdiv!(B2vc, U) +end + end # module TestTriangular