diff --git a/src/generic.jl b/src/generic.jl index 2b03b249..49195585 100644 --- a/src/generic.jl +++ b/src/generic.jl @@ -280,6 +280,7 @@ julia> rmul!([NaN], 0.0) ``` """ function rmul!(X::AbstractArray, s::Number) + isone(s) && return X @simd for I in eachindex(X) @inbounds X[I] *= s end @@ -318,6 +319,7 @@ julia> lmul!(0.0, [Inf]) ``` """ function lmul!(s::Number, X::AbstractArray) + isone(s) && return X @simd for I in eachindex(X) @inbounds X[I] = s*X[I] end diff --git a/src/matmul.jl b/src/matmul.jl index a7838df8..58a1609e 100644 --- a/src/matmul.jl +++ b/src/matmul.jl @@ -488,14 +488,13 @@ end # THE one big BLAS dispatch. This is split into two methods to improve latency Base.@constprop :aggressive function generic_matmatmul_wrapper!(C::StridedMatrix{T}, tA, tB, A::StridedVecOrMat{T}, B::StridedVecOrMat{T}, - α::Number, β::Number, val::BlasFlag.SyrkHerkGemm) where {T<:BlasFloat} + α::Number, β::Number, val::BlasFlag.SyrkHerkGemm) where {T<:Number} mA, nA = lapack_size(tA, A) mB, nB = lapack_size(tB, B) if any(iszero, size(A)) || any(iszero, size(B)) || iszero(α) matmul_size_check(size(C), (mA, nA), (mB, nB)) return _rmul_or_fill!(C, β) end - matmul2x2or3x3_nonzeroalpha!(C, tA, tB, A, B, α, β) && return C _syrk_herk_gemm_wrapper!(C, tA, tB, A, B, α, β, val) return C end @@ -570,6 +569,70 @@ Base.@constprop :aggressive function _symm_hemm_generic!(C, tA, tB, A, B, alpha, _generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), alpha, beta) end +""" + generic_syrk!(C::StridedMatrix{T}, A::StridedVecOrMat{T}, conjugate::Bool, aat::Bool, α, β) where {T<:Number} + +Computes syrk/herk for generic number types. If `conjugate` is false computes syrk, i.e., +``A transpose(A) α + C β`` if `aat` is true, and ``transpose(A) A α + C β`` otherwise. +If `conjugate` is true computes herk, i.e., ``A A' α + C β`` if `aat` is true, and +``A' A α + C β`` otherwise. Only the upper triangular is computed. +""" +function generic_syrk!(C::StridedMatrix{T}, A::StridedVecOrMat{T}, conjugate::Bool, aat::Bool, α, β) where {T<:Number} + require_one_based_indexing(C, A) + nC = checksquare(C) + m, n = size(A, 1), size(A, 2) + mA = aat ? m : n + if nC != mA + throw(DimensionMismatch(lazy"output matrix has size: $(size(C)), but should have size $((mA, mA))")) + end + + _rmul_or_fill!(C, β) + @inbounds if !conjugate + if aat + for k ∈ 1:n, j ∈ 1:m + αA_jk = A[j, k] * α + for i ∈ 1:j + C[i, j] += A[i, k] * αA_jk + end + end + else + for j ∈ 1:n, i ∈ 1:j + temp = A[1, i] * A[1, j] + for k ∈ 2:m + temp += A[k, i] * A[k, j] + end + C[i, j] += temp * α + end + end + else + if aat + for k ∈ 1:n, j ∈ 1:m + αA_jk_bar = conj(A[j, k]) * α + for i ∈ 1:j-1 + C[i, j] += A[i, k] * αA_jk_bar + end + C[j, j] += abs2(A[j, k]) * α + end + else + for j ∈ 1:n + for i ∈ 1:j-1 + temp = conj(A[1, i]) * A[1, j] + for k ∈ 2:m + temp += conj(A[k, i]) * A[k, j] + end + C[i, j] += temp * α + end + temp = abs2(A[1, j]) + for k ∈ 2:m + temp += abs2(A[k, j]) + end + C[j, j] += temp * α + end + end + end + return C +end + # legacy method Base.@constprop :aggressive generic_matmatmul!(C::StridedMatrix{T}, tA, tB, A::StridedVecOrMat{T}, B::StridedVecOrMat{T}, _add::MulAddMul = MulAddMul()) where {T<:BlasFloat} = @@ -713,12 +776,27 @@ Base.@constprop :aggressive function syrk_wrapper!(C::StridedMatrix{T}, tA::Abst if (alpha isa Union{Bool,T} && beta isa Union{Bool,T} && stride(A, 1) == stride(C, 1) == 1 && - _fullstride2(A) && _fullstride2(C)) - return copytri!(BLAS.syrk!('U', tA, alpha, A, beta, C), 'U') + _fullstride2(A) && _fullstride2(C)) && + max(nA, mA) ≥ 4 + BLAS.syrk!('U', tA, alpha, A, beta, C) + else + generic_syrk!(C, A, false, tA_uc == 'N', alpha, beta) end + return copytri!(C, 'U') end return gemm_wrapper!(C, tA, tAt, A, A, α, β) end +Base.@constprop :aggressive function syrk_wrapper!(C::StridedMatrix{T}, tA::AbstractChar, A::StridedVecOrMat{T}, + α::Number, β::Number) where {T<:Number} + + tA_uc = uppercase(tA) # potentially strip a WrapperChar + aat = (tA_uc == 'N') + if T <: Union{Real,Complex} && (iszero(β) || issymmetric(C)) + return copytri!(generic_syrk!(C, A, false, aat, α, β), 'U') + end + tAt = aat ? 'T' : 'N' + return _generic_matmatmul!(C, wrap(A, tA), wrap(A, tAt), α, β) +end # legacy method syrk_wrapper!(C::StridedMatrix{T}, tA::AbstractChar, A::StridedVecOrMat{T}, _add::MulAddMul = MulAddMul()) where {T<:BlasFloat} = syrk_wrapper!(C, tA, A, _add.alpha, _add.beta) @@ -748,12 +826,27 @@ Base.@constprop :aggressive function herk_wrapper!(C::Union{StridedMatrix{T}, St if (alpha isa Union{Bool,T} && beta isa Union{Bool,T} && stride(A, 1) == stride(C, 1) == 1 && - _fullstride2(A) && _fullstride2(C)) - return copytri!(BLAS.herk!('U', tA, alpha, A, beta, C), 'U', true) + _fullstride2(A) && _fullstride2(C)) && + max(nA, mA) ≥ 4 + BLAS.herk!('U', tA, alpha, A, beta, C) + else + generic_syrk!(C, A, true, tA_uc == 'N', alpha, beta) end + return copytri!(C, 'U', true) end return gemm_wrapper!(C, tA, tAt, A, A, α, β) end +Base.@constprop :aggressive function herk_wrapper!(C::StridedMatrix{T}, tA::AbstractChar, A::StridedVecOrMat{T}, + α::Number, β::Number) where {T<:Number} + + tA_uc = uppercase(tA) # potentially strip a WrapperChar + aat = (tA_uc == 'N') + if isreal(α) && isreal(β) && (iszero(β) || ishermitian(C)) + return copytri!(generic_syrk!(C, A, true, aat, α, β), 'U', true) + end + tAt = aat ? 'C' : 'N' + return _generic_matmatmul!(C, wrap(A, tA), wrap(A, tAt), α, β) +end # legacy method herk_wrapper!(C::Union{StridedMatrix{T}, StridedMatrix{Complex{T}}}, tA::AbstractChar, A::Union{StridedVecOrMat{T}, StridedVecOrMat{Complex{T}}}, _add::MulAddMul = MulAddMul()) where {T<:BlasReal} = @@ -787,6 +880,7 @@ Base.@constprop :aggressive function gemm_wrapper!(C::StridedVecOrMat{T}, tA::Ab mB, nB = lapack_size(tB, B) matmul_size_check(size(C), (mA, nA), (mB, nB)) + matmul2x2or3x3_nonzeroalpha!(C, tA, tB, A, B, α, β) && return C if C === A || B === C throw(ArgumentError("output matrix must not be aliased with input matrix")) @@ -805,6 +899,13 @@ end gemm_wrapper!(C::StridedVecOrMat{T}, tA::AbstractChar, tB::AbstractChar, A::StridedVecOrMat{T}, B::StridedVecOrMat{T}, _add::MulAddMul = MulAddMul()) where {T<:BlasFloat} = gemm_wrapper!(C, tA, tB, A, B, _add.alpha, _add.beta) +# fallback for generic types +Base.@constprop :aggressive function gemm_wrapper!(C::StridedVecOrMat{T}, tA::AbstractChar, tB::AbstractChar, + A::StridedVecOrMat{T}, B::StridedVecOrMat{T}, + α::Number, β::Number) where {T<:Number} + matmul2x2or3x3_nonzeroalpha!(C, tA, tB, A, B, α, β) && return C + return _generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), α, β) +end # Aggressive constprop helps propagate the values of tA and tB into wrap, which # makes the calls concretely inferred diff --git a/test/generic.jl b/test/generic.jl index 6d11ec82..dc9eabc8 100644 --- a/test/generic.jl +++ b/test/generic.jl @@ -123,6 +123,42 @@ end @test_throws DimensionMismatch axpy!(α, x, Vector(1:3), y, Vector(1:5)) end +@testset "generic syrk & herk" begin + for T ∈ (BigFloat, Complex{BigFloat}, Quaternion{Float64}) + α = randn(T) + a = randn(T, 3, 4) + csmall = similar(a, 3, 3) + csmall_fallback = similar(a, 3, 3) + cbig = similar(a, 4, 4) + cbig_fallback = similar(a, 4, 4) + mul!(csmall, a, a', real(α), false) + LinearAlgebra._generic_matmatmul!(csmall_fallback, a, a', real(α), false) + @test ishermitian(csmall) + @test csmall ≈ csmall_fallback + mul!(cbig, a', a, real(α), false) + LinearAlgebra._generic_matmatmul!(cbig_fallback, a', a, real(α), false) + @test ishermitian(cbig) + @test cbig ≈ cbig_fallback + mul!(csmall, a, transpose(a), α, false) + LinearAlgebra._generic_matmatmul!(csmall_fallback, a, transpose(a), α, false) + @test csmall ≈ csmall_fallback + mul!(cbig, transpose(a), a, α, false) + LinearAlgebra._generic_matmatmul!(cbig_fallback, transpose(a), a, α, false) + @test cbig ≈ cbig_fallback + if T <: Union{Real, Complex} + @test issymmetric(csmall) + @test issymmetric(cbig) + end + #make sure generic herk is not called for non-real α + mul!(csmall, a, a', α, false) + LinearAlgebra._generic_matmatmul!(csmall_fallback, a, a', α, false) + @test csmall ≈ csmall_fallback + mul!(cbig, a', a, α, false) + LinearAlgebra._generic_matmatmul!(cbig_fallback, a', a, α, false) + @test cbig ≈ cbig_fallback + end +end + @test !issymmetric(fill(1,5,3)) @test !ishermitian(fill(1,5,3)) @test (x = fill(1,3); cross(x,x) == zeros(3))