Skip to content

Commit b724f6b

Browse files
authored
Merge branch 'master' into sds/rng_wavefront64
2 parents 7561d29 + 4fcec17 commit b724f6b

File tree

4 files changed

+94
-1
lines changed

4 files changed

+94
-1
lines changed

src/blas/highlevel.jl

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -330,3 +330,41 @@ if VERSION ≥ v"1.12-"
330330
LinearAlgebra.copytrito!(B::Matrix{T}, A::ROCMatrix{T}, uplo::AbstractChar) where {T <: ROCBLASFloat} =
331331
invoke(LinearAlgebra.copytrito!, Tuple{AbstractMatrix, AbstractMatrix, AbstractChar}, B, A, uplo)
332332
end
333+
334+
function LinearAlgebra.lmul!(A::Diagonal{T,<:ROCVector{T}}, B::ROCMatrix{T}) where {T<:ROCBLASFloat}
335+
return dgmm!('L', B, A.diag, B)
336+
end
337+
338+
function LinearAlgebra.rmul!(A::ROCMatrix{T}, B::Diagonal{T,<:ROCVector{T}}) where {T<:ROCBLASFloat}
339+
return dgmm!('R', A, B.diag, A)
340+
end
341+
342+
# eltypes do not match
343+
function LinearAlgebra.lmul!(A::Diagonal{T,<:ROCVector{T}}, B::ROCMatrix) where {T<:ROCBLASFloat}
344+
@. B = A.diag * B
345+
return B
346+
end
347+
function LinearAlgebra.lmul!(A::Diagonal{Td,<:ROCVector{Td}}, B::Transpose{Tt, <:ROCMatrix{Tt}}) where {Td<:ROCBLASFloat, Tt<:ROCBLASFloat}
348+
@. B = A.diag * B
349+
return B
350+
end
351+
function LinearAlgebra.lmul!(A::Diagonal{Td,<:ROCVector{Td}}, B::Adjoint{Tt, <:ROCMatrix{Tt}}) where {Td<:ROCBLASFloat, Tt<:ROCBLASFloat}
352+
@. B = A.diag * B
353+
return B
354+
end
355+
# eltypes do not match
356+
function LinearAlgebra.rmul!(A::ROCMatrix, B::Diagonal{T,<:ROCVector{T}}) where {T<:ROCBLASFloat}
357+
At = transpose(A)
358+
@. At = B.diag * At
359+
return A
360+
end
361+
function LinearAlgebra.rmul!(A::Transpose{Tt, <:ROCMatrix{Tt}}, B::Diagonal{Td,<:ROCVector{Td}}) where {Td<:ROCBLASFloat, Tt<:ROCBLASFloat}
362+
At = parent(A)
363+
@. At = B.diag * At
364+
return transpose(At)
365+
end
366+
function LinearAlgebra.rmul!(A::Adjoint{Tt, <:ROCMatrix{Tt}}, B::Diagonal{Td,<:ROCVector{Td}}) where {Td<:ROCBLASFloat, Tt<:ROCBLASFloat}
367+
At = parent(A)
368+
@. At = adjoint(B.diag) * At
369+
return adjoint(At)
370+
end

test/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
77
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
88
GPUCompiler = "61eb1bfa-7361-4325-ad38-22787b887f55"
99
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
10+
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
1011
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
1112
LLVM = "929cbde3-209d-540e-8aea-75f648917ca0"
1213
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

test/device/launch.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ end
131131
# end
132132
end
133133

134-
if !iszero(AMDGPU.HIP.properties(AMDGPU.device()).cooperativeLaunch)
134+
if VERSION >= v"1.12-" && !iszero(AMDGPU.HIP.properties(AMDGPU.device()).cooperativeLaunch)
135135
@testset "Cooperative Groups" begin
136136
function test_kernel!(x)
137137
block_row, block_col = workgroupIdx().x, workgroupIdx().y
@@ -151,6 +151,8 @@ if !iszero(AMDGPU.HIP.properties(AMDGPU.device()).cooperativeLaunch)
151151
end
152152
AMDGPU.Device.sync_grid()
153153
end
154+
155+
return nothing
154156
end
155157

156158
n = 4

test/rocarray/blas.jl

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -288,4 +288,56 @@ end
288288
end
289289
end
290290

291+
@testset "Extension" begin
292+
for TA in (Float32, Float64, ComplexF32, ComplexF64), TB in (Float32, Float64)
293+
x = rand(TB, m)
294+
d_x = ROCArray(x)
295+
XA = rand(TA, m, n)
296+
d_XA = ROCArray(XA)
297+
d_X = Diagonal(d_x)
298+
lmul!(d_X, d_XA)
299+
@test Array(d_XA) Diagonal(x) * XA
300+
301+
x = rand(TB, m)
302+
d_x = ROCArray(x)
303+
XA = rand(TA, n, m)
304+
d_AX = transpose(ROCArray(XA))
305+
d_X = Diagonal(d_x)
306+
lmul!(d_X, d_AX)
307+
@test Array(d_AX) Diagonal(x) * transpose(XA)
308+
309+
x = rand(TB, m)
310+
d_x = ROCArray(x)
311+
XA = rand(TA, n, m)
312+
d_AX = adjoint(ROCArray(XA))
313+
d_X = Diagonal(d_x)
314+
lmul!(d_X, d_AX)
315+
@test Array(d_AX) Diagonal(x) * adjoint(XA)
316+
317+
y = rand(TB, n)
318+
d_y = ROCArray(y)
319+
AY = rand(TA, m, n)
320+
d_AY = ROCArray(AY)
321+
d_Y = Diagonal(d_y)
322+
rmul!(d_AY, d_Y)
323+
@test Array(d_AY) AY * Diagonal(y)
324+
325+
y = rand(TB, n)
326+
d_y = ROCArray(y)
327+
AY = rand(TA, n, m)
328+
d_YA = transpose(ROCArray(AY))
329+
d_Y = Diagonal(d_y)
330+
d_YA = rmul!(d_YA, d_Y)
331+
@test Array(d_YA) transpose(AY) * Diagonal(y)
332+
333+
y = rand(TB, n)
334+
d_y = ROCArray(y)
335+
AY = rand(TA, n, m)
336+
d_YA = adjoint(ROCArray(AY))
337+
d_Y = Diagonal(d_y)
338+
d_YA = rmul!(d_YA, d_Y)
339+
@test Array(d_YA) adjoint(AY) * Diagonal(y)
340+
end
341+
end
342+
291343
end

0 commit comments

Comments
 (0)