Skip to content

Commit fec19de

Browse files
committed
Respond to comments, update AMD SVD tests
1 parent 19adf31 commit fec19de

File tree

4 files changed

+11
-12
lines changed

4 files changed

+11
-12
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ julia = "1.10"
3535
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
3636
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
3737
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
38+
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
3839
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
3940
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
4041
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
@@ -47,3 +48,4 @@ test = ["Aqua", "JET", "SafeTestsets", "Test", "TestExtras","ChainRulesCore", "C
4748

4849
[sources]
4950
AMDGPU = {url = "https://github.com/kshyatt/AMDGPU.jl", rev = "ksh/lrmul"}
51+
CUDA = {url = "https://github.com/JuliaGPU/CUDA.jl", rev = "master"}

src/common/matrixproperties.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ See also [`isisometry`](@ref) and [`is_right_isometry`](@ref).
4242
""" is_left_isometry
4343

4444
function is_left_isometry(A::AbstractMatrix; isapprox_kwargs...)
45-
iszero(min(size(A)...)) && return true
45+
iszero(size(A, 2)) && return true
4646
return isapprox(A' * A, LinearAlgebra.I; isapprox_kwargs...)
4747
end
4848
@@ -56,6 +56,6 @@ See also [`isisometry`](@ref) and [`is_left_isometry`](@ref).
5656
""" is_right_isometry
5757
5858
function is_right_isometry(A::AbstractMatrix; isapprox_kwargs...)
59-
iszero(min(size(A)...)) && return true
59+
iszero(size(A, 1)) && return true
6060
return isapprox(A * A', LinearAlgebra.I; isapprox_kwargs...)
6161
end

src/implementations/svd.jl

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -352,17 +352,16 @@ function _gpu_gesvd_maybe_transpose!(A::AbstractMatrix, S::AbstractVector, U::Ab
352352
# both CUSOLVER and ROCSOLVER require m ≥ n for gesvd (QR_Iteration)
353353
# if this condition is not met, do the SVD via adjoint
354354
minmn = min(m, n)
355-
At = min(m, n) > 0 ? adjoint!(similar(A'), A)::AbstractMatrix : similar(A')
356-
Ut = similar(U')
357-
Vᴴt = similar(Vᴴ')
355+
Aᴴ = min(m, n) > 0 ? adjoint!(similar(A'), A)::AbstractMatrix : similar(A')
356+
Uᴴ = similar(U')
357+
V = similar(Vᴴ')
358358
if size(U) == (m, m)
359-
_gpu_gesvd!(At, view(S, 1:minmn, 1), Vᴴt, Ut)
359+
_gpu_gesvd!(Aᴴ, view(S, 1:minmn, 1), V, Uᴴ)
360360
else
361-
_gpu_gesvd!(At, S, Vᴴt, Ut)
361+
_gpu_gesvd!(Aᴴ, S, V, Uᴴ)
362362
end
363-
length(U) > 0 ? adjoint!(U, Ut) : one!(U)
364-
length(Vᴴ) > 0 ? adjoint!(Vᴴ, Vᴴt) : one!(Vᴴ)
365-
conj!(S)
363+
length(U) > 0 && adjoint!(U, Uᴴ)
364+
length(Vᴴ) > 0 && adjoint!(Vᴴ, V)
366365
return U, S, Vᴴ
367366
end
368367

test/amd/svd.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@ include(joinpath("..", "utilities.jl"))
1515
k = min(m, n)
1616
algs = (ROCSOLVER_QRIteration(), ROCSOLVER_Jacobi())
1717
@testset "algorithm $alg" for alg in algs
18-
n > m && alg isa ROCSOLVER_QRIteration && continue # not supported
1918
minmn = min(m, n)
2019
A = ROCArray(randn(rng, T, m, n))
2120

@@ -51,7 +50,6 @@ end
5150
@testset "size ($m, $n)" for n in (37, m, 63)
5251
algs = (ROCSOLVER_QRIteration(), ROCSOLVER_Jacobi())
5352
@testset "algorithm $alg" for alg in algs
54-
n > m && alg isa ROCSOLVER_QRIteration && continue # not supported
5553
A = ROCArray(randn(rng, T, m, n))
5654
U, S, Vᴴ = svd_full(A; alg)
5755
@test U isa ROCMatrix{T} && size(U) == (m, m)

0 commit comments

Comments
 (0)