From 619ab3b145c0b0980c4c0c7d522e67a9f7b1ea58 Mon Sep 17 00:00:00 2001 From: Daniel Karrasch Date: Wed, 28 Feb 2024 19:06:18 +0100 Subject: [PATCH 1/2] Adapt to new LinearAlgebra.generic_*mul! interface --- src/host/linalg.jl | 38 ++++++++++++++++++++++++-------------- 1 file changed, 24 insertions(+), 14 deletions(-) diff --git a/src/host/linalg.jl b/src/host/linalg.jl index a2a99019..27b379e7 100644 --- a/src/host/linalg.jl +++ b/src/host/linalg.jl @@ -338,8 +338,10 @@ end ## matrix multiplication - -function generic_matmatmul!(C::AbstractArray{R}, A::AbstractArray{T}, B::AbstractArray{S}, a::Number, b::Number) where {T,S,R} +# legacy method +generic_matmatmul!(C::AbstractArray, A::AbstractArray, B::AbstractArray, a::Number, b::Number) where {T,S,R} = + generic_matmatmul!(C, A, B, MulAddMul(a, b)) +function generic_matmatmul!(C::AbstractArray{R}, A::AbstractArray{T}, B::AbstractArray{S}, add::MulAddMul) where {T,S,R} if size(A,2) != size(B,1) throw(DimensionMismatch("matrix A has dimensions $(size(A)), matrix B has dimensions $(size(B))")) end @@ -350,8 +352,6 @@ function generic_matmatmul!(C::AbstractArray{R}, A::AbstractArray{T}, B::Abstrac return fill!(C, zero(R)) end - add = MulAddMul(a, b) - gpu_call(C, A, B; name="matmatmul!") do ctx, C, A, B idx = @linearidx C assume.(size(C) .> 0) @@ -372,42 +372,52 @@ function generic_matmatmul!(C::AbstractArray{R}, A::AbstractArray{T}, B::Abstrac C end +if VERSION < v"1.12.0-" function LinearAlgebra.generic_matvecmul!(C::AbstractGPUVector, tA::AbstractChar, A::AbstractGPUMatrix, B::AbstractGPUVector, _add::MulAddMul = MulAddMul()) - generic_matmatmul!(C, wrap(A, tA), B, _add.alpha, _add.beta) + generic_matmatmul!(C, wrap(A, tA), B, _add) end function LinearAlgebra.generic_matmatmul!(C::AbstractGPUVecOrMat, tA, tB, A::AbstractGPUVecOrMat, B::AbstractGPUVecOrMat, _add::MulAddMul=MulAddMul()) - generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), _add.alpha, _add.beta) + generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), _add) +end +else +function LinearAlgebra.generic_matvecmul!(C::AbstractGPUVector, tA::AbstractChar, A::AbstractGPUMatrix, B::AbstractGPUVector, a::Number, b::Number) + LinearAlgebra.@stable_muladdmul generic_matmatmul!(C, wrap(A, tA), B, MulAddMul(a, b)) +end + +function LinearAlgebra.generic_matmatmul!(C::AbstractGPUVecOrMat, tA, tB, A::AbstractGPUVecOrMat, B::AbstractGPUVecOrMat, a::Number, b::Number) + LinearAlgebra.@stable_muladdmul generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), MulAddMul(a, b)) end +end if VERSION < v"1.10.0-DEV.1365" # catch other functions that are called by LinearAlgebra's mul! function LinearAlgebra.gemv!(C::AbstractGPUVector, tA::AbstractChar, A::AbstractGPUMatrix, B::AbstractGPUVector, a::Number, b::Number) - generic_matmatmul!(C, wrap(A, tA), B, a, b) + generic_matmatmul!(C, wrap(A, tA), B, MulAddMul(a, b)) end # disambiguation function LinearAlgebra.gemv!(C::AbstractGPUVector{T}, tA::AbstractChar, A::AbstractGPUMatrix{T}, B::AbstractGPUVector{T}, a::Number, b::Number) where {T<:LinearAlgebra.BlasFloat} - generic_matmatmul!(C, wrap(A, tA), B, a, b) + generic_matmatmul!(C, wrap(A, tA), B, MulAddMul(a, b)) end LinearAlgebra.gemm_wrapper!(C::AbstractGPUVecOrMat, tA::AbstractChar, tB::AbstractChar, A::AbstractGPUVecOrMat, B::AbstractGPUVecOrMat, _add::MulAddMul) = - LinearAlgebra.generic_matmatmul!(C, tA, tB, A, B, _add) + generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), _add) # disambiguation LinearAlgebra.gemm_wrapper!(C::AbstractGPUVecOrMat{T}, tA::AbstractChar, tB::AbstractChar, A::AbstractGPUVecOrMat{T}, B::AbstractGPUVecOrMat{T}, _add::MulAddMul) where {T<:LinearAlgebra.BlasFloat} = - LinearAlgebra.generic_matmatmul!(C, tA, tB, A, B, _add) + generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), _add) function LinearAlgebra.syrk_wrapper!(C::AbstractGPUMatrix, tA::AbstractChar, A::AbstractGPUVecOrMat, _add::MulAddMul = MulAddMul()) if tA == 'T' - LinearAlgebra.generic_matmatmul!(C, 'T', 'N', A, A, _add) + generic_matmatmul!(C, wrap(A, 'T'), A, _add) else # tA == 'N' - LinearAlgebra.generic_matmatmul!(C, 'N', 'T', A, A, _add) + generic_matmatmul!(C, A, wrap(A, 'T'), _add) end end function LinearAlgebra.herk_wrapper!(C::AbstractGPUMatrix, tA::AbstractChar, A::AbstractGPUVecOrMat, _add::MulAddMul = MulAddMul()) if tA == 'C' - LinearAlgebra.generic_matmatmul!(C, 'C', 'N', A, A, _add) + generic_matmatmul!(C, wrap(A, 'C'), A, _add) else # tA == 'N' - LinearAlgebra.generic_matmatmul!(C, 'N', 'C', A, A, _add) + generic_matmatmul!(C, A, wrap(A, 'C'), _add) end end end # VERSION From 11229a87f246571993b950d7ffa40ddba65ea006 Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Thu, 23 May 2024 09:59:14 +0200 Subject: [PATCH 2/2] Make version check static. --- src/host/linalg.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/host/linalg.jl b/src/host/linalg.jl index 27b379e7..255b8a09 100644 --- a/src/host/linalg.jl +++ b/src/host/linalg.jl @@ -372,7 +372,7 @@ function generic_matmatmul!(C::AbstractArray{R}, A::AbstractArray{T}, B::Abstrac C end -if VERSION < v"1.12.0-" +@static if VERSION < v"1.12.0-" function LinearAlgebra.generic_matvecmul!(C::AbstractGPUVector, tA::AbstractChar, A::AbstractGPUMatrix, B::AbstractGPUVector, _add::MulAddMul = MulAddMul()) generic_matmatmul!(C, wrap(A, tA), B, _add) end