You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
253: Fix 5 arg mul! r=haampie a=haampie
Move from 3-arg mul to 5-arg mul to make sure GPUArrays's generic mul does not steal from CuArrays.
(Btw, I could not yet test because cyclops was unreachable :( )
Co-authored-by: Harmen Stoppels <[email protected]>
Copy file name to clipboardExpand all lines: src/host/linalg.jl
+29-11Lines changed: 29 additions & 11 deletions
Original file line number
Diff line number
Diff line change
@@ -66,7 +66,7 @@ end
66
66
67
67
## matrix multiplication
68
68
69
-
functiongeneric_matmatmul!(C::AbstractGPUVecOrMat{R}, A::AbstractGPUVecOrMat{T}, B::AbstractGPUVecOrMat{S}) where {T,S,R}
69
+
functiongeneric_matmatmul!(C::AbstractGPUVecOrMat{R}, A::AbstractGPUVecOrMat{T}, B::AbstractGPUVecOrMat{S}, a::Number, b::Number) where {T,S,R}
70
70
ifsize(A,2) !=size(B,1)
71
71
throw(DimensionMismatch("matrix A has dimensions $(size(A)), matrix B has dimensions $(size(B))"))
72
72
end
@@ -92,7 +92,7 @@ function generic_matmatmul!(C::AbstractGPUVecOrMat{R}, A::AbstractGPUVecOrMat{T}
92
92
for k in1:size(A,2)
93
93
Ctmp += A[i, k]*B[k, j]
94
94
end
95
-
C[i,j] = Ctmp
95
+
C[i,j] = Ctmp*a + C[i,j]*b
96
96
end
97
97
98
98
return
@@ -101,15 +101,33 @@ function generic_matmatmul!(C::AbstractGPUVecOrMat{R}, A::AbstractGPUVecOrMat{T}
101
101
C
102
102
end
103
103
104
-
LinearAlgebra.mul!(C::AbstractGPUVecOrMat, A::AbstractGPUVecOrMat, B::AbstractGPUVecOrMat) =generic_matmatmul!(C, A, B)
105
-
LinearAlgebra.mul!(C::AbstractGPUVecOrMat, A::AbstractGPUVecOrMat, B::LinearAlgebra.Adjoint{<:Any, <:AbstractGPUVecOrMat}) =generic_matmatmul!(C, A, B)
106
-
LinearAlgebra.mul!(C::AbstractGPUVecOrMat, A::AbstractGPUVecOrMat, B::LinearAlgebra.Transpose{<:Any, <:AbstractGPUVecOrMat}) =generic_matmatmul!(C, A, B)
107
-
LinearAlgebra.mul!(C::AbstractGPUVecOrMat, A::LinearAlgebra.Adjoint{<:Any, <:AbstractGPUVecOrMat}, B::AbstractGPUVecOrMat) =generic_matmatmul!(C, A, B)
108
-
LinearAlgebra.mul!(C::AbstractGPUVecOrMat, A::LinearAlgebra.Transpose{<:Any, <:AbstractGPUVecOrMat}, B::AbstractGPUVecOrMat) =generic_matmatmul!(C, A, B)
109
-
LinearAlgebra.mul!(C::AbstractGPUVecOrMat, A::LinearAlgebra.Transpose{<:Any, <:AbstractGPUVecOrMat}, B::LinearAlgebra.Adjoint{<:Any, <:AbstractGPUVecOrMat}) =generic_matmatmul!(C, A, B)
110
-
LinearAlgebra.mul!(C::AbstractGPUVecOrMat, A::LinearAlgebra.Adjoint{<:Any, <:AbstractGPUVecOrMat}, B::LinearAlgebra.Transpose{<:Any, <:AbstractGPUVecOrMat}) =generic_matmatmul!(C, A, B)
111
-
LinearAlgebra.mul!(C::AbstractGPUVecOrMat, A::LinearAlgebra.Adjoint{<:Any, <:AbstractGPUVecOrMat}, B::LinearAlgebra.Adjoint{<:Any, <:AbstractGPUVecOrMat}) =generic_matmatmul!(C, A, B)
112
-
LinearAlgebra.mul!(C::AbstractGPUVecOrMat, A::LinearAlgebra.Transpose{<:Any, <:AbstractGPUVecOrMat}, B::LinearAlgebra.Transpose{<:Any, <:AbstractGPUVecOrMat}) =generic_matmatmul!(C, A, B)
104
+
LinearAlgebra.mul!(C::AbstractGPUVecOrMat, A::AbstractGPUVecOrMat, B::AbstractGPUVecOrMat, a::Number, b::Number) =generic_matmatmul!(C, A, B, a, b)
105
+
LinearAlgebra.mul!(C::AbstractGPUVecOrMat, A::AbstractGPUVecOrMat, B::LinearAlgebra.Adjoint{<:Any, <:AbstractGPUVecOrMat}, a::Number, b::Number) =generic_matmatmul!(C, A, B, a, b)
106
+
LinearAlgebra.mul!(C::AbstractGPUVecOrMat, A::AbstractGPUVecOrMat, B::LinearAlgebra.Transpose{<:Any, <:AbstractGPUVecOrMat}, a::Number, b::Number) =generic_matmatmul!(C, A, B, a, b)
107
+
LinearAlgebra.mul!(C::AbstractGPUVecOrMat, A::LinearAlgebra.Adjoint{<:Any, <:AbstractGPUVecOrMat}, B::AbstractGPUVecOrMat, a::Number, b::Number) =generic_matmatmul!(C, A, B, a, b)
108
+
LinearAlgebra.mul!(C::AbstractGPUVecOrMat, A::LinearAlgebra.Transpose{<:Any, <:AbstractGPUVecOrMat}, B::AbstractGPUVecOrMat, a::Number, b::Number) =generic_matmatmul!(C, A, B, a, b)
109
+
LinearAlgebra.mul!(C::AbstractGPUVecOrMat, A::LinearAlgebra.Transpose{<:Any, <:AbstractGPUVecOrMat}, B::LinearAlgebra.Adjoint{<:Any, <:AbstractGPUVecOrMat}, a::Number, b::Number) =generic_matmatmul!(C, A, B, a, b)
110
+
LinearAlgebra.mul!(C::AbstractGPUVecOrMat, A::LinearAlgebra.Adjoint{<:Any, <:AbstractGPUVecOrMat}, B::LinearAlgebra.Transpose{<:Any, <:AbstractGPUVecOrMat}, a::Number, b::Number) =generic_matmatmul!(C, A, B, a, b)
111
+
LinearAlgebra.mul!(C::AbstractGPUVecOrMat, A::LinearAlgebra.Adjoint{<:Any, <:AbstractGPUVecOrMat}, B::LinearAlgebra.Adjoint{<:Any, <:AbstractGPUVecOrMat}, a::Number, b::Number) =generic_matmatmul!(C, A, B, a, b)
112
+
LinearAlgebra.mul!(C::AbstractGPUVecOrMat, A::LinearAlgebra.Transpose{<:Any, <:AbstractGPUVecOrMat}, B::LinearAlgebra.Transpose{<:Any, <:AbstractGPUVecOrMat}, a::Number, b::Number) =generic_matmatmul!(C, A, B, a, b)
113
+
114
+
@staticifv"1.3.0"<=VERSION<=v"1.3.1"
115
+
LinearAlgebra.mul!(C::AbstractGPUVecOrMat{T}, A::AbstractGPUVecOrMat{T}, B::AbstractGPUVecOrMat{T}, a::Union{Bool,T}, b::Union{Bool,T}) where {T<:LinearAlgebra.BLAS.BlasFloat} =generic_matmatmul!(C, A, B, a, b)
116
+
LinearAlgebra.mul!(C::AbstractGPUVecOrMat{T}, A::AbstractGPUVecOrMat{T}, B::LinearAlgebra.Adjoint{<:Any, <:AbstractGPUVecOrMat{T}}, a::Union{Bool,T}, b::Union{Bool,T}) where {T<:LinearAlgebra.BLAS.BlasReal} =generic_matmatmul!(C, A, B, a, b)
117
+
LinearAlgebra.mul!(C::AbstractGPUVecOrMat{T}, A::AbstractGPUVecOrMat{T}, B::LinearAlgebra.Adjoint{<:Any, <:AbstractGPUVecOrMat{T}}, a::Union{Bool,T}, b::Union{Bool,T}) where {T<:LinearAlgebra.BLAS.BlasComplex} =generic_matmatmul!(C, A, B, a, b)
118
+
LinearAlgebra.mul!(C::AbstractGPUVecOrMat{T}, A::AbstractGPUVecOrMat{T}, B::LinearAlgebra.Transpose{<:Any, <:AbstractGPUVecOrMat{T}}, a::Union{Bool,T}, b::Union{Bool,T}) where {T<:LinearAlgebra.BLAS.BlasFloat} =generic_matmatmul!(C, A, B, a, b)
119
+
LinearAlgebra.mul!(C::AbstractGPUVecOrMat{T}, A::LinearAlgebra.Adjoint{<:Any, <:AbstractGPUVecOrMat{T}}, B::AbstractGPUVecOrMat{T}, a::Union{Bool,T}, b::Union{Bool,T}) where {T<:LinearAlgebra.BLAS.BlasReal} =generic_matmatmul!(C, A, B, a, b)
120
+
LinearAlgebra.mul!(C::AbstractGPUVecOrMat{T}, A::LinearAlgebra.Adjoint{<:Any, <:AbstractGPUVecOrMat{T}}, B::AbstractGPUVecOrMat{T}, a::Union{Bool,T}, b::Union{Bool,T}) where {T<:LinearAlgebra.BLAS.BlasComplex} =generic_matmatmul!(C, A, B, a, b)
121
+
LinearAlgebra.mul!(C::AbstractGPUVecOrMat{T}, A::LinearAlgebra.Transpose{<:Any, <:AbstractGPUVecOrMat{T}}, B::AbstractGPUVecOrMat{T}, a::Union{Bool,T}, b::Union{Bool,T}) where {T<:LinearAlgebra.BLAS.BlasFloat} =generic_matmatmul!(C, A, B, a, b)
122
+
LinearAlgebra.mul!(C::AbstractGPUVecOrMat{T}, A::LinearAlgebra.Transpose{<:Any, <:AbstractGPUVecOrMat{T}}, B::LinearAlgebra.Adjoint{<:Any, <:AbstractGPUVecOrMat{T}}, a::Union{Bool,T}, b::Union{Bool,T}) where {T<:LinearAlgebra.BLAS.BlasReal} =generic_matmatmul!(C, A, B, a, b)
123
+
LinearAlgebra.mul!(C::AbstractGPUVecOrMat{T}, A::LinearAlgebra.Transpose{<:Any, <:AbstractGPUVecOrMat{T}}, B::LinearAlgebra.Adjoint{<:Any, <:AbstractGPUVecOrMat{T}}, a::Union{Bool,T}, b::Union{Bool,T}) where {T<:LinearAlgebra.BLAS.BlasComplex} =generic_matmatmul!(C, A, B, a, b)
124
+
LinearAlgebra.mul!(C::AbstractGPUVecOrMat{T}, A::LinearAlgebra.Adjoint{<:Any, <:AbstractGPUVecOrMat{T}}, B::LinearAlgebra.Transpose{<:Any, <:AbstractGPUVecOrMat{T}}, a::Union{Bool,T}, b::Union{Bool,T}) where {T<:LinearAlgebra.BLAS.BlasReal} =generic_matmatmul!(C, A, B, a, b)
125
+
LinearAlgebra.mul!(C::AbstractGPUVecOrMat{T}, A::LinearAlgebra.Adjoint{<:Any, <:AbstractGPUVecOrMat{T}}, B::LinearAlgebra.Transpose{<:Any, <:AbstractGPUVecOrMat{T}}, a::Union{Bool,T}, b::Union{Bool,T}) where {T<:LinearAlgebra.BLAS.BlasComplex} =generic_matmatmul!(C, A, B, a, b)
126
+
LinearAlgebra.mul!(C::AbstractGPUVecOrMat{T}, A::LinearAlgebra.Adjoint{<:Any, <:AbstractGPUVecOrMat{T}}, B::LinearAlgebra.Adjoint{<:Any, <:AbstractGPUVecOrMat{T}}, a::Union{Bool,T}, b::Union{Bool,T}) where {T<:LinearAlgebra.BLAS.BlasReal} =generic_matmatmul!(C, A, B, a, b)
127
+
LinearAlgebra.mul!(C::AbstractGPUVecOrMat{T}, A::LinearAlgebra.Adjoint{<:Any, <:AbstractGPUVecOrMat{T}}, B::LinearAlgebra.Adjoint{<:Any, <:AbstractGPUVecOrMat{T}}, a::Union{Bool,T}, b::Union{Bool,T}) where {T<:LinearAlgebra.BLAS.BlasComplex} =generic_matmatmul!(C, A, B, a, b)
128
+
LinearAlgebra.mul!(C::AbstractGPUVecOrMat{T}, A::LinearAlgebra.Transpose{<:Any, <:AbstractGPUVecOrMat{T}}, B::LinearAlgebra.Transpose{<:Any, <:AbstractGPUVecOrMat{T}}, a::Union{Bool,T}, b::Union{Bool,T}) where {T<:LinearAlgebra.BLAS.BlasFloat} =generic_matmatmul!(C, A, B, a, b)
0 commit comments