
Commit 8ff92f9

Support norm for Diagonal (#2860)
* Support norm for Diagonal
* Fix type sig
1 parent 3da87de commit 8ff92f9

File tree

2 files changed, +9 -0 lines changed


lib/cublas/linalg.jl

Lines changed: 1 addition & 0 deletions
@@ -121,6 +121,7 @@ function LinearAlgebra.norm(x::DenseCuArray{<:Union{Float16, ComplexF16, CublasF
         return invoke(norm, Tuple{AbstractGPUArray, Real}, x, p)
     end
 end
+LinearAlgebra.norm(x::Diagonal{T, <:StridedCuVector{T}}, p::Real=2) where {T<:Union{Float16, ComplexF16, CublasFloat}} = norm(x.diag, p)
 LinearAlgebra.norm2(x::DenseCuArray{<:Union{Float16, ComplexF16, CublasFloat}}) = nrm2(x)
 
 LinearAlgebra.BLAS.asum(x::StridedCuArray{<:CublasFloat}) = asum(length(x), x)
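
A minimal usage sketch of the method added above (assuming a working CUDA.jl setup; the element type and length below are arbitrary examples):

    using CUDA, LinearAlgebra

    d = CUDA.rand(Float32, 4)   # diagonal entries stored on the GPU
    D = Diagonal(d)             # Diagonal wrapping a CuVector
    norm(D)                     # p = 2 by default; forwards to norm(D.diag, 2)
    norm(D, 1)                  # likewise forwards to norm(D.diag, 1)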

test/libraries/cublas/level1.jl

Lines changed: 8 additions & 0 deletions
@@ -152,6 +152,14 @@ k = 13
         CUBLAS.nrm2(dx, result)
         @test norm(x) ≈ result[]
     end
+    @testset "norm of Diagonal" begin
+        x = rand(T, m)
+        dDx = Diagonal(CuArray(x))
+        Dx = Diagonal(x)
+        @test norm(dDx, 1) ≈ norm(Dx, 1)
+        @test norm(dDx, 2) ≈ norm(Dx, 2)
+        @test norm(dDx, Inf) ≈ norm(Dx, Inf)
+    end
 end # level 1 testset
 @testset for T in [Float16, ComplexF16]
     A = CuVector(rand(T, m)) # CUDA.rand doesn't work with 16 bit types yet
