Skip to content

Commit b1be744

Browse files
bors[bot]haampie
andauthored
Merge #253
253: Fix 5 arg mul! r=haampie a=haampie Move from 3-arg mul to 5-arg mul to make sure GPUArrays's generic mul does not steal from CuArrays. (Btw, I could not yet test because cyclops was unreachable :( ) Co-authored-by: Harmen Stoppels <[email protected]>
2 parents 59f4887 + caeb995 commit b1be744

File tree

2 files changed

+50
-29
lines changed

2 files changed

+50
-29
lines changed

src/host/linalg.jl

Lines changed: 29 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ end
6666

6767
## matrix multiplication
6868

69-
function generic_matmatmul!(C::AbstractGPUVecOrMat{R}, A::AbstractGPUVecOrMat{T}, B::AbstractGPUVecOrMat{S}) where {T,S,R}
69+
function generic_matmatmul!(C::AbstractGPUVecOrMat{R}, A::AbstractGPUVecOrMat{T}, B::AbstractGPUVecOrMat{S}, a::Number, b::Number) where {T,S,R}
7070
if size(A,2) != size(B,1)
7171
throw(DimensionMismatch("matrix A has dimensions $(size(A)), matrix B has dimensions $(size(B))"))
7272
end
@@ -92,7 +92,7 @@ function generic_matmatmul!(C::AbstractGPUVecOrMat{R}, A::AbstractGPUVecOrMat{T}
9292
for k in 1:size(A,2)
9393
Ctmp += A[i, k]*B[k, j]
9494
end
95-
C[i,j] = Ctmp
95+
C[i,j] = Ctmp*a + C[i,j]*b
9696
end
9797

9898
return
@@ -101,15 +101,33 @@ function generic_matmatmul!(C::AbstractGPUVecOrMat{R}, A::AbstractGPUVecOrMat{T}
101101
C
102102
end
103103

104-
LinearAlgebra.mul!(C::AbstractGPUVecOrMat, A::AbstractGPUVecOrMat, B::AbstractGPUVecOrMat) = generic_matmatmul!(C, A, B)
105-
LinearAlgebra.mul!(C::AbstractGPUVecOrMat, A::AbstractGPUVecOrMat, B::LinearAlgebra.Adjoint{<:Any, <:AbstractGPUVecOrMat}) = generic_matmatmul!(C, A, B)
106-
LinearAlgebra.mul!(C::AbstractGPUVecOrMat, A::AbstractGPUVecOrMat, B::LinearAlgebra.Transpose{<:Any, <:AbstractGPUVecOrMat}) = generic_matmatmul!(C, A, B)
107-
LinearAlgebra.mul!(C::AbstractGPUVecOrMat, A::LinearAlgebra.Adjoint{<:Any, <:AbstractGPUVecOrMat}, B::AbstractGPUVecOrMat) = generic_matmatmul!(C, A, B)
108-
LinearAlgebra.mul!(C::AbstractGPUVecOrMat, A::LinearAlgebra.Transpose{<:Any, <:AbstractGPUVecOrMat}, B::AbstractGPUVecOrMat) = generic_matmatmul!(C, A, B)
109-
LinearAlgebra.mul!(C::AbstractGPUVecOrMat, A::LinearAlgebra.Transpose{<:Any, <:AbstractGPUVecOrMat}, B::LinearAlgebra.Adjoint{<:Any, <:AbstractGPUVecOrMat}) = generic_matmatmul!(C, A, B)
110-
LinearAlgebra.mul!(C::AbstractGPUVecOrMat, A::LinearAlgebra.Adjoint{<:Any, <:AbstractGPUVecOrMat}, B::LinearAlgebra.Transpose{<:Any, <:AbstractGPUVecOrMat}) = generic_matmatmul!(C, A, B)
111-
LinearAlgebra.mul!(C::AbstractGPUVecOrMat, A::LinearAlgebra.Adjoint{<:Any, <:AbstractGPUVecOrMat}, B::LinearAlgebra.Adjoint{<:Any, <:AbstractGPUVecOrMat}) = generic_matmatmul!(C, A, B)
112-
LinearAlgebra.mul!(C::AbstractGPUVecOrMat, A::LinearAlgebra.Transpose{<:Any, <:AbstractGPUVecOrMat}, B::LinearAlgebra.Transpose{<:Any, <:AbstractGPUVecOrMat}) = generic_matmatmul!(C, A, B)
104+
LinearAlgebra.mul!(C::AbstractGPUVecOrMat, A::AbstractGPUVecOrMat, B::AbstractGPUVecOrMat, a::Number, b::Number) = generic_matmatmul!(C, A, B, a, b)
105+
LinearAlgebra.mul!(C::AbstractGPUVecOrMat, A::AbstractGPUVecOrMat, B::LinearAlgebra.Adjoint{<:Any, <:AbstractGPUVecOrMat}, a::Number, b::Number) = generic_matmatmul!(C, A, B, a, b)
106+
LinearAlgebra.mul!(C::AbstractGPUVecOrMat, A::AbstractGPUVecOrMat, B::LinearAlgebra.Transpose{<:Any, <:AbstractGPUVecOrMat}, a::Number, b::Number) = generic_matmatmul!(C, A, B, a, b)
107+
LinearAlgebra.mul!(C::AbstractGPUVecOrMat, A::LinearAlgebra.Adjoint{<:Any, <:AbstractGPUVecOrMat}, B::AbstractGPUVecOrMat, a::Number, b::Number) = generic_matmatmul!(C, A, B, a, b)
108+
LinearAlgebra.mul!(C::AbstractGPUVecOrMat, A::LinearAlgebra.Transpose{<:Any, <:AbstractGPUVecOrMat}, B::AbstractGPUVecOrMat, a::Number, b::Number) = generic_matmatmul!(C, A, B, a, b)
109+
LinearAlgebra.mul!(C::AbstractGPUVecOrMat, A::LinearAlgebra.Transpose{<:Any, <:AbstractGPUVecOrMat}, B::LinearAlgebra.Adjoint{<:Any, <:AbstractGPUVecOrMat}, a::Number, b::Number) = generic_matmatmul!(C, A, B, a, b)
110+
LinearAlgebra.mul!(C::AbstractGPUVecOrMat, A::LinearAlgebra.Adjoint{<:Any, <:AbstractGPUVecOrMat}, B::LinearAlgebra.Transpose{<:Any, <:AbstractGPUVecOrMat}, a::Number, b::Number) = generic_matmatmul!(C, A, B, a, b)
111+
LinearAlgebra.mul!(C::AbstractGPUVecOrMat, A::LinearAlgebra.Adjoint{<:Any, <:AbstractGPUVecOrMat}, B::LinearAlgebra.Adjoint{<:Any, <:AbstractGPUVecOrMat}, a::Number, b::Number) = generic_matmatmul!(C, A, B, a, b)
112+
LinearAlgebra.mul!(C::AbstractGPUVecOrMat, A::LinearAlgebra.Transpose{<:Any, <:AbstractGPUVecOrMat}, B::LinearAlgebra.Transpose{<:Any, <:AbstractGPUVecOrMat}, a::Number, b::Number) = generic_matmatmul!(C, A, B, a, b)
113+
114+
@static if v"1.3.0" <= VERSION <= v"1.3.1"
115+
LinearAlgebra.mul!(C::AbstractGPUVecOrMat{T}, A::AbstractGPUVecOrMat{T}, B::AbstractGPUVecOrMat{T}, a::Union{Bool,T}, b::Union{Bool,T}) where {T<:LinearAlgebra.BLAS.BlasFloat} = generic_matmatmul!(C, A, B, a, b)
116+
LinearAlgebra.mul!(C::AbstractGPUVecOrMat{T}, A::AbstractGPUVecOrMat{T}, B::LinearAlgebra.Adjoint{<:Any, <:AbstractGPUVecOrMat{T}}, a::Union{Bool,T}, b::Union{Bool,T}) where {T<:LinearAlgebra.BLAS.BlasReal} = generic_matmatmul!(C, A, B, a, b)
117+
LinearAlgebra.mul!(C::AbstractGPUVecOrMat{T}, A::AbstractGPUVecOrMat{T}, B::LinearAlgebra.Adjoint{<:Any, <:AbstractGPUVecOrMat{T}}, a::Union{Bool,T}, b::Union{Bool,T}) where {T<:LinearAlgebra.BLAS.BlasComplex} = generic_matmatmul!(C, A, B, a, b)
118+
LinearAlgebra.mul!(C::AbstractGPUVecOrMat{T}, A::AbstractGPUVecOrMat{T}, B::LinearAlgebra.Transpose{<:Any, <:AbstractGPUVecOrMat{T}}, a::Union{Bool,T}, b::Union{Bool,T}) where {T<:LinearAlgebra.BLAS.BlasFloat} = generic_matmatmul!(C, A, B, a, b)
119+
LinearAlgebra.mul!(C::AbstractGPUVecOrMat{T}, A::LinearAlgebra.Adjoint{<:Any, <:AbstractGPUVecOrMat{T}}, B::AbstractGPUVecOrMat{T}, a::Union{Bool,T}, b::Union{Bool,T}) where {T<:LinearAlgebra.BLAS.BlasReal} = generic_matmatmul!(C, A, B, a, b)
120+
LinearAlgebra.mul!(C::AbstractGPUVecOrMat{T}, A::LinearAlgebra.Adjoint{<:Any, <:AbstractGPUVecOrMat{T}}, B::AbstractGPUVecOrMat{T}, a::Union{Bool,T}, b::Union{Bool,T}) where {T<:LinearAlgebra.BLAS.BlasComplex} = generic_matmatmul!(C, A, B, a, b)
121+
LinearAlgebra.mul!(C::AbstractGPUVecOrMat{T}, A::LinearAlgebra.Transpose{<:Any, <:AbstractGPUVecOrMat{T}}, B::AbstractGPUVecOrMat{T}, a::Union{Bool,T}, b::Union{Bool,T}) where {T<:LinearAlgebra.BLAS.BlasFloat} = generic_matmatmul!(C, A, B, a, b)
122+
LinearAlgebra.mul!(C::AbstractGPUVecOrMat{T}, A::LinearAlgebra.Transpose{<:Any, <:AbstractGPUVecOrMat{T}}, B::LinearAlgebra.Adjoint{<:Any, <:AbstractGPUVecOrMat{T}}, a::Union{Bool,T}, b::Union{Bool,T}) where {T<:LinearAlgebra.BLAS.BlasReal} = generic_matmatmul!(C, A, B, a, b)
123+
LinearAlgebra.mul!(C::AbstractGPUVecOrMat{T}, A::LinearAlgebra.Transpose{<:Any, <:AbstractGPUVecOrMat{T}}, B::LinearAlgebra.Adjoint{<:Any, <:AbstractGPUVecOrMat{T}}, a::Union{Bool,T}, b::Union{Bool,T}) where {T<:LinearAlgebra.BLAS.BlasComplex} = generic_matmatmul!(C, A, B, a, b)
124+
LinearAlgebra.mul!(C::AbstractGPUVecOrMat{T}, A::LinearAlgebra.Adjoint{<:Any, <:AbstractGPUVecOrMat{T}}, B::LinearAlgebra.Transpose{<:Any, <:AbstractGPUVecOrMat{T}}, a::Union{Bool,T}, b::Union{Bool,T}) where {T<:LinearAlgebra.BLAS.BlasReal} = generic_matmatmul!(C, A, B, a, b)
125+
LinearAlgebra.mul!(C::AbstractGPUVecOrMat{T}, A::LinearAlgebra.Adjoint{<:Any, <:AbstractGPUVecOrMat{T}}, B::LinearAlgebra.Transpose{<:Any, <:AbstractGPUVecOrMat{T}}, a::Union{Bool,T}, b::Union{Bool,T}) where {T<:LinearAlgebra.BLAS.BlasComplex} = generic_matmatmul!(C, A, B, a, b)
126+
LinearAlgebra.mul!(C::AbstractGPUVecOrMat{T}, A::LinearAlgebra.Adjoint{<:Any, <:AbstractGPUVecOrMat{T}}, B::LinearAlgebra.Adjoint{<:Any, <:AbstractGPUVecOrMat{T}}, a::Union{Bool,T}, b::Union{Bool,T}) where {T<:LinearAlgebra.BLAS.BlasReal} = generic_matmatmul!(C, A, B, a, b)
127+
LinearAlgebra.mul!(C::AbstractGPUVecOrMat{T}, A::LinearAlgebra.Adjoint{<:Any, <:AbstractGPUVecOrMat{T}}, B::LinearAlgebra.Adjoint{<:Any, <:AbstractGPUVecOrMat{T}}, a::Union{Bool,T}, b::Union{Bool,T}) where {T<:LinearAlgebra.BLAS.BlasComplex} = generic_matmatmul!(C, A, B, a, b)
128+
LinearAlgebra.mul!(C::AbstractGPUVecOrMat{T}, A::LinearAlgebra.Transpose{<:Any, <:AbstractGPUVecOrMat{T}}, B::LinearAlgebra.Transpose{<:Any, <:AbstractGPUVecOrMat{T}}, a::Union{Bool,T}, b::Union{Bool,T}) where {T<:LinearAlgebra.BLAS.BlasFloat} = generic_matmatmul!(C, A, B, a, b)
129+
end
130+
113131

114132
function generic_rmul!(X::AbstractGPUArray, s::Number)
115133
gpu_call(X, s) do ctx, X, s

test/testsuite/linalg.jl

Lines changed: 21 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -61,28 +61,31 @@ function test_linalg(AT)
6161
@test f(A, d) == Array(f!(AT(A), d))
6262
end
6363

64-
@testset "matrix multiplication" begin
65-
for (a,b) in [((3,4),(4,3)), ((3,), (1,3)), ((1,3), (3))], T in supported_eltypes()
66-
@test compare(*, AT, rand(T, a), rand(T, b))
64+
@testset "$T gemv y := $f(A) * x * a + y * b" for f in (identity, transpose, adjoint), T in supported_eltypes()
65+
y, A, x = rand(T, 4), rand(T, 4, 4), rand(T, 4)
6766

68-
if length(a) > 1
69-
@test compare(*, AT, transpose(rand(T, reverse(a))), rand(T, b))
70-
@test compare(*, AT, adjoint(rand(T, reverse(a))), rand(T, b))
71-
end
67+
# workaround for https://github.com/JuliaLang/julia/issues/35163#issue-584248084
68+
T <: Integer && (y .%= T(10); A .%= T(10); x .%= T(10))
69+
70+
@test compare(*, AT, f(A), x)
71+
@test compare(mul!, AT, y, f(A), x)
72+
@test compare(mul!, AT, y, f(A), x, Ref(T(4)), Ref(T(5)))
73+
end
7274

73-
if length(b) > 1
74-
@test compare(*, AT, rand(T, a), transpose(rand(T, reverse(b))))
75-
@test compare(*, AT, rand(T, a), adjoint(rand(T, reverse(b))))
76-
end
75+
@testset "$T gemm C := $f(A) * $g(B) * a + C * b" for f in (identity, transpose, adjoint), g in (identity, transpose, adjoint), T in supported_eltypes()
76+
A, B, C = rand(T, 4, 4), rand(T, 4, 4), rand(T, 4, 4)
7777

78-
if length(a) > 1 && length(b) > 1
79-
@test compare(*, AT, transpose(rand(T, reverse(a))), transpose(rand(T, reverse(b))))
80-
@test compare(*, AT, adjoint(rand(T, reverse(a))), adjoint(rand(T, reverse(b))))
81-
end
78+
# workaround for https://github.com/JuliaLang/julia/issues/35163#issue-584248084
79+
T <: Integer && (A .%= T(10); B .%= T(10); C .%= T(10))
80+
81+
@test compare(*, AT, f(A), g(B))
82+
@test compare(mul!, AT, C, f(A), g(B))
83+
@test compare(mul!, AT, C, f(A), g(B), Ref(T(4)), Ref(T(5)))
84+
end
8285

83-
@test compare(rmul!, AT, rand(T, a), Ref(rand(T)))
84-
@test compare(lmul!, AT, Ref(rand(T)), rand(T, b))
85-
end
86+
@testset "lmul! and rmul!" for (a,b) in [((3,4),(4,3)), ((3,), (1,3)), ((1,3), (3))], T in supported_eltypes()
87+
@test compare(rmul!, AT, rand(T, a), Ref(rand(T)))
88+
@test compare(lmul!, AT, Ref(rand(T)), rand(T, b))
8689
end
8790
end
8891
end

0 commit comments

Comments
 (0)