From db7046c24e4e3cff782cefcc9affe9b74b4e3087 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Thu, 9 Oct 2025 07:39:18 -0400 Subject: [PATCH 1/2] More diag mul methods --- src/host/linalg.jl | 4 ++-- test/testsuite/linalg.jl | 5 +++++ 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/src/host/linalg.jl b/src/host/linalg.jl index 8f871868b..4ef6d73d9 100644 --- a/src/host/linalg.jl +++ b/src/host/linalg.jl @@ -266,8 +266,8 @@ function LinearAlgebra.mul!(C::Diagonal{<:Any, <:AbstractGPUArray}, end function LinearAlgebra.mul!(C::Diagonal{<:Any, <:AbstractGPUArray}, - A::AbstractGPUArray, - B::AbstractGPUArray) + A::Union{AbstractGPUArray, Adjoint{T,<:AbstractGPUArray{T}}, Transpose{T,<:AbstractGPUArray{T}}}, + B::Union{AbstractGPUArray, Adjoint{T,<:AbstractGPUArray{T}}, Transpose{T,<:AbstractGPUArray{T}}}) where {T} dc = C.diag d = length(dc) m, n = size(A, 1), size(A, 2) diff --git a/test/testsuite/linalg.jl b/test/testsuite/linalg.jl index f9ac5d924..93c3549fa 100644 --- a/test/testsuite/linalg.jl +++ b/test/testsuite/linalg.jl @@ -259,6 +259,11 @@ C = Diagonal(d) mul!(C, a, b) @test collect(C) ≈ Diagonal(collect(a) * collect(b)) + a = transpose(AT(diagm(rand(elty, n)))) + b = adjoint(AT(diagm(rand(elty, n)))) + C = Diagonal(d) + mul!(C, a, b) + @test collect(C) ≈ Diagonal(collect(a) * collect(b)) end end From 085a6baf3ba4f4c63a61a6cd2a9cd239a2476aeb Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Thu, 9 Oct 2025 08:40:00 -0400 Subject: [PATCH 2/2] Even more --- src/host/linalg.jl | 7 +++---- test/testsuite/linalg.jl | 3 +++ 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/src/host/linalg.jl b/src/host/linalg.jl index 4ef6d73d9..0002fdd54 100644 --- a/src/host/linalg.jl +++ b/src/host/linalg.jl @@ -282,7 +282,7 @@ end function LinearAlgebra.mul!(B::AbstractGPUVecOrMat, D::Diagonal{<:Any, <:AbstractGPUArray}, - A::AbstractGPUVecOrMat) + A::Union{AbstractGPUArray, Adjoint{T,<:AbstractGPUArray{T}}, Transpose{T,<:AbstractGPUArray{T}}}) where {T} dd = D.diag d = length(dd) m, n = size(A, 1), size(A, 2) @@ -290,15 +290,14 @@ function LinearAlgebra.mul!(B::AbstractGPUVecOrMat, m == d || throw(DimensionMismatch("right hand side has $m rows but D is $d by $d")) (m, n) == (m′, n′) || throw(DimensionMismatch("expect output to be $m by $n, but got $m′ by $n′")) @. B = dd * A - B end function LinearAlgebra.mul!(B::AbstractGPUVecOrMat, D::Diagonal{<:Any, <:AbstractGPUArray}, - A::AbstractGPUVecOrMat, + A::Union{AbstractGPUArray, Adjoint{T,<:AbstractGPUArray{T}}, Transpose{T,<:AbstractGPUArray{T}}}, α::Number, - β::Number) + β::Number) where {T} dd = D.diag d = length(dd) m, n = size(A, 1), size(A, 2) diff --git a/test/testsuite/linalg.jl b/test/testsuite/linalg.jl index 93c3549fa..d77a2bb50 100644 --- a/test/testsuite/linalg.jl +++ b/test/testsuite/linalg.jl @@ -238,6 +238,9 @@ mul!(X, D, B) mul!(Y, Diagonal(collect(d)), collect(B)) @test collect(X) ≈ Y + mul!(X, D, adjoint(B)) + mul!(Y, Diagonal(collect(d)), collect(adjoint(B))) + @test collect(X) ≈ Y mul!(X, D, B, α, β) mul!(Y, Diagonal(collect(d)), collect(B), α, β) @test collect(X) ≈ Y