Skip to content

Commit 5761b1f

Browse files
authored
More mul extensions (#2862)
* More generic_matmul extensions * Add more tests
1 parent 8ff92f9 commit 5761b1f

File tree

2 files changed

+23
-2
lines changed

2 files changed

+23
-2
lines changed

lib/cublas/linalg.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -271,6 +271,12 @@ function LinearAlgebra.generic_matmatmul!(C::StridedCuVecOrMat, tA, tB, A::Strid
271271
GPUArrays.generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), alpha, beta)
272272
end
273273

274+
# fallback for weird Complex edge cases
275+
const AdjOrTransOrCuMatrix{T} = Union{StridedCuMatrix{T}, AdjOrTrans{<:T,<:StridedCuMatrix{T}}}
276+
LinearAlgebra.mul!(C::StridedCuVecOrMat{T}, A::AdjOrTransOrCuMatrix{T}, B::Adjoint{T, <:Transpose{T, <:StridedCuMatrix{T}}}, α::Number, β::Number) where {T<:Complex} = mul!(C, A, conj(parent(parent(B))), α, β)
277+
LinearAlgebra.mul!(C::StridedCuVecOrMat{T}, A::AdjOrTransOrCuMatrix{T}, B::Transpose{T, <:Adjoint{T, <:StridedCuMatrix{T}}}, α::Number, β::Number) where {T<:Complex} = mul!(C, A, conj(parent(parent(B))), α, β)
278+
LinearAlgebra.mul!(C::StridedCuVecOrMat{T}, A::Adjoint{T, <:Transpose{T, <:StridedCuMatrix{T}}}, B::AdjOrTransOrCuMatrix{T}, α::Number, β::Number) where {T<:Complex} = mul!(C, conj(parent(parent(A))), B, α, β)
279+
LinearAlgebra.mul!(C::StridedCuVecOrMat{T}, A::Transpose{T, <:Adjoint{T, <:StridedCuMatrix{T}}}, B::AdjOrTransOrCuMatrix{T}, α::Number, β::Number) where {T<:Complex} = mul!(C, conj(parent(parent(A))), B, α, β)
274280

275281
# triangular
276282

@@ -280,7 +286,6 @@ LinearAlgebra.generic_mattrimul!(C::StridedCuMatrix{T}, uploc, isunitc, tfun::Fu
280286
trmm!('R', uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', isunitc, one(T), B, A, C)
281287

282288
## tri-tri-mul!
283-
const AdjOrTransOrCuMatrix{T} = Union{StridedCuMatrix{T}, AdjOrTrans{<:T,<:StridedCuMatrix}}
284289
function LinearAlgebra.generic_trimatmul!(C::StridedCuMatrix{T}, uplocA, isunitcA, tfunA::Function, A::StridedCuMatrix{T}, triB::UpperOrLowerTriangular{T,<:AdjOrTransOrCuMatrix{T}}) where {T<:CublasFloat}
285290
uplocB = LinearAlgebra.uplo_char(triB)
286291
isunitcB = LinearAlgebra.isunit_char(triB)

test/libraries/cublas/level3/gemm.jl

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,30 @@ k = 13
1818
@testset "level 3" begin
1919
@testset for elty in [Float32, Float64, ComplexF32, ComplexF64]
2020

21-
@testset "mul! C = $f(A) * $g(B) * $Ts(a) + C * $Ts(b)" for f in (identity, transpose, adjoint), g in (identity, transpose, adjoint), Ts in (Int, elty)
21+
@testset "mul! C = $f(A) * $g(B) * $Ts(a) + C * $Ts(b)" for f in (identity, transpose, adjoint), g in (identity, transpose, adjoint), Ts in (Int, elty)
2222
C, A, B = rand(elty, 5, 5), rand(elty, 5, 5), rand(elty, 5, 5)
2323
dC, dA, dB = CuArray(C), CuArray(A), CuArray(B)
2424
mul!(dC, f(dA), g(dB), Ts(1), Ts(2))
2525
mul!(C, f(A), g(B), Ts(1), Ts(2))
2626
@test Array(dC) C
2727
end
2828

29+
@testset "mul! C = $f(A) * $f($g(B)) * $Ts(a) + C * $Ts(b)" for f in (identity, transpose, adjoint), g in (identity, transpose, adjoint), Ts in (Int, elty)
30+
C, A, B = rand(elty, 5, 5), rand(elty, 5, 5), rand(elty, 5, 5)
31+
dC, dA, dB = CuArray(C), CuArray(A), CuArray(B)
32+
mul!(dC, f(dA), f(g(dB)), Ts(1), Ts(2))
33+
mul!(C, f(A), f(g(B)), Ts(1), Ts(2))
34+
@test Array(dC) C
35+
end
36+
37+
@testset "mul! C = $g($f(A)) * $g(B) * $Ts(a) + C * $Ts(b)" for f in (identity, transpose, adjoint), g in (identity, transpose, adjoint), Ts in (Int, elty)
38+
C, A, B = rand(elty, 5, 5), rand(elty, 5, 5), rand(elty, 5, 5)
39+
dC, dA, dB = CuArray(C), CuArray(A), CuArray(B)
40+
mul!(dC, g(f(dA)), g(dB), Ts(1), Ts(2))
41+
mul!(C, g(f(A)), g(B), Ts(1), Ts(2))
42+
@test Array(dC) C
43+
end
44+
2945
@testset "hermitian" begin
3046
C, A, B = rand(elty, 5, 5), Hermitian(rand(elty, 5, 5)), rand(elty, 5, 5)
3147
dC, dA, dB = CuArray(C), Hermitian(CuArray(A)), CuArray(B)

0 commit comments

Comments
 (0)