Skip to content

Commit 595cedf

Browse files
committed
Support new projections on GPU
1 parent ba9867b commit 595cedf

File tree

5 files changed

+359
-0
lines changed

5 files changed

+359
-0
lines changed

ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,4 +52,102 @@ function MatrixAlgebraKit.findtruncated_svd(values::StridedROCVector, strategy::
5252
return MatrixAlgebraKit.findtruncated(values, strategy)
5353
end
5454

55+
function _project_hermitian_offdiag_kernel(Au, Al, Bu, Bl, ::Val{true})
56+
m, n = size(Au)
57+
j = workitemIdx().x + (workgroupIdx().x - 1) * workgroupDim().x
58+
j > n && return
59+
for i in 1:m
60+
@inbounds begin
61+
val = (Au[i, j] - adjoint(Al[j, i])) / 2
62+
Bu[i, j] = val
63+
Bl[j, i] = -adjoint(val)
64+
end
65+
end
66+
return
67+
end
68+
69+
function _project_hermitian_offdiag_kernel(Au, Al, Bu, Bl, ::Val{false})
70+
m, n = size(Au)
71+
j = workitemIdx().x + (workgroupIdx().x - 1) * workgroupDim().x
72+
j > n && return
73+
for i in 1:m
74+
@inbounds begin
75+
val = (Au[i, j] + adjoint(Al[j, i])) / 2
76+
Bu[i, j] = val
77+
Bl[j, i] = adjoint(val)
78+
end
79+
end
80+
return
81+
end
82+
83+
function _project_hermitian_diag_kernel(A, B, ::Val{true})
84+
n = size(A, 1)
85+
j = workitemIdx().x + (workgroupIdx().x - 1) * workgroupDim().x
86+
j > n && return
87+
@inbounds begin
88+
for i in 1:(j - 1)
89+
val = (A[i, j] - adjoint(A[j, i])) / 2
90+
B[i, j] = val
91+
B[j, i] = -adjoint(val)
92+
end
93+
B[j, j] = MatrixAlgebraKit._imimag(A[j, j])
94+
end
95+
return
96+
end
97+
98+
function _project_hermitian_diag_kernel(A, B, ::Val{false})
99+
n = size(A, 1)
100+
j = workitemIdx().x + (workgroupIdx().x - 1) * workgroupDim().x
101+
j > n && return
102+
@inbounds begin
103+
for i in 1:(j - 1)
104+
val = (A[i, j] + adjoint(A[j, i])) / 2
105+
B[i, j] = val
106+
B[j, i] = adjoint(val)
107+
end
108+
B[j, j] = real(A[j, j])
109+
end
110+
return
111+
end
112+
113+
function MatrixAlgebraKit._project_hermitian_offdiag!(
114+
Au::StridedROCMatrix, Al::StridedROCMatrix, Bu::StridedROCMatrix, Bl::StridedROCMatrix, ::Val{anti}
115+
) where {anti}
116+
thread_dim = 512
117+
block_dim = cld(size(Au, 2), thread_dim)
118+
@roc groupsize = thread_dim gridsize = block_dim _project_hermitian_offdiag_kernel(Au, Al, Bu, Bl, Val(anti))
119+
return nothing
120+
end
121+
function MatrixAlgebraKit._project_hermitian_diag!(A::StridedROCMatrix, B::StridedROCMatrix, ::Val{anti}) where {anti}
122+
thread_dim = 512
123+
block_dim = cld(size(A, 1), thread_dim)
124+
@roc groupsize = thread_dim gridsize = block_dim _project_hermitian_diag_kernel(A, B, Val(anti))
125+
return nothing
126+
end
127+
128+
MatrixAlgebraKit.ishermitian_exact(A::StridedROCMatrix) = all(A .== adjoint(A))
129+
MatrixAlgebraKit.ishermitian_exact(A::Diagonal{T, <:StridedROCVector{T}}) where {T} = all(A.diag .== adjoint(A.diag))
130+
131+
MatrixAlgebraKit.isantihermitian_exact(A::StridedROCMatrix) = all(A .== -adjoint(A))
132+
MatrixAlgebraKit.isantihermitian_exact(A::Diagonal{T, <:StridedROCVector{T}}) where {T} = all(A.diag .== -adjoint(A.diag))
133+
134+
function MatrixAlgebraKit._avgdiff!(A::StridedROCMatrix, B::StridedROCMatrix)
135+
axes(A) == axes(B) || throw(DimensionMismatch())
136+
function _avgdiff_kernel(A, B)
137+
j = workitemIdx().x + (workgroupIdx().x - 1) * workgroupDim().x
138+
j > length(A) && return
139+
@inbounds begin
140+
a = A[j]
141+
b = B[j]
142+
A[j] = (a + b) / 2
143+
B[j] = b - a
144+
end
145+
return
146+
end
147+
thread_dim = 512
148+
block_dim = cld(length(A), thread_dim)
149+
@roc groupsize = thread_dim gridsize = block_dim _avgdiff_kernel(A, B)
150+
return A, B
151+
end
152+
55153
end

ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ using MatrixAlgebraKit: default_qr_algorithm, default_lq_algorithm, default_svd_
99
import MatrixAlgebraKit: _gpu_geqrf!, _gpu_ungqr!, _gpu_unmqr!, _gpu_gesvd!, _gpu_Xgesvdp!, _gpu_Xgesvdr!, _gpu_gesvdj!, _gpu_geev!
1010
import MatrixAlgebraKit: _gpu_heevj!, _gpu_heevd!
1111
using CUDA
12+
using CUDA: i32
1213
using LinearAlgebra
1314
using LinearAlgebra: BlasFloat
1415

@@ -58,4 +59,102 @@ function MatrixAlgebraKit.findtruncated_svd(values::StridedCuVector, strategy::T
5859
return MatrixAlgebraKit.findtruncated(values, strategy)
5960
end
6061

62+
function _project_hermitian_offdiag_kernel(Au, Al, Bu, Bl, ::Val{true})
63+
m, n = size(Au)
64+
j = threadIdx().x + (blockIdx().x - 1i32) * blockDim().x
65+
j > n && return
66+
for i in 1:m
67+
@inbounds begin
68+
val = (Au[i, j] - adjoint(Al[j, i])) / 2
69+
Bu[i, j] = val
70+
Bl[j, i] = -adjoint(val)
71+
end
72+
end
73+
return
74+
end
75+
76+
function _project_hermitian_offdiag_kernel(Au, Al, Bu, Bl, ::Val{false})
77+
m, n = size(Au)
78+
j = threadIdx().x + (blockIdx().x - 1) * blockDim().x
79+
j > n && return
80+
for i in 1:m
81+
@inbounds begin
82+
val = (Au[i, j] + adjoint(Al[j, i])) / 2
83+
Bu[i, j] = val
84+
Bl[j, i] = adjoint(val)
85+
end
86+
end
87+
return
88+
end
89+
90+
function _project_hermitian_diag_kernel(A, B, ::Val{true})
91+
n = size(A, 1)
92+
j = threadIdx().x + (blockIdx().x - 1) * blockDim().x
93+
j > n && return
94+
@inbounds begin
95+
for i in 1i32:(j - 1i32)
96+
val = (A[i, j] - adjoint(A[j, i])) / 2
97+
B[i, j] = val
98+
B[j, i] = -adjoint(val)
99+
end
100+
B[j, j] = MatrixAlgebraKit._imimag(A[j, j])
101+
end
102+
return
103+
end
104+
105+
function _project_hermitian_diag_kernel(A, B, ::Val{false})
106+
n = size(A, 1)
107+
j = threadIdx().x + (blockIdx().x - 1i32) * blockDim().x
108+
j > n && return
109+
@inbounds begin
110+
for i in 1i32:(j - 1i32)
111+
val = (A[i, j] + adjoint(A[j, i])) / 2
112+
B[i, j] = val
113+
B[j, i] = adjoint(val)
114+
end
115+
B[j, j] = real(A[j, j])
116+
end
117+
return
118+
end
119+
120+
function MatrixAlgebraKit._project_hermitian_offdiag!(
121+
Au::StridedCuMatrix, Al::StridedCuMatrix, Bu::StridedCuMatrix, Bl::StridedCuMatrix, ::Val{anti}
122+
) where {anti}
123+
thread_dim = 512
124+
block_dim = cld(size(Au, 2), thread_dim)
125+
@cuda threads = thread_dim blocks = block_dim _project_hermitian_offdiag_kernel(Au, Al, Bu, Bl, Val(anti))
126+
return nothing
127+
end
128+
function MatrixAlgebraKit._project_hermitian_diag!(A::StridedCuMatrix, B::StridedCuMatrix, ::Val{anti}) where {anti}
129+
thread_dim = 512
130+
block_dim = cld(size(A, 1), thread_dim)
131+
@cuda threads = thread_dim blocks = block_dim _project_hermitian_diag_kernel(A, B, Val(anti))
132+
return nothing
133+
end
134+
135+
MatrixAlgebraKit.ishermitian_exact(A::StridedCuMatrix) = all(A .== adjoint(A))
136+
MatrixAlgebraKit.ishermitian_exact(A::Diagonal{T, <:StridedCuVector{T}}) where {T} = all(A.diag .== adjoint(A.diag))
137+
138+
MatrixAlgebraKit.isantihermitian_exact(A::StridedCuMatrix) = all(A .== -adjoint(A))
139+
MatrixAlgebraKit.isantihermitian_exact(A::Diagonal{T, <:StridedCuVector{T}}) where {T} = all(A.diag .== -adjoint(A.diag))
140+
141+
function MatrixAlgebraKit._avgdiff!(A::StridedCuMatrix, B::StridedCuMatrix)
142+
axes(A) == axes(B) || throw(DimensionMismatch())
143+
function _avgdiff_kernel(A, B)
144+
j = threadIdx().x + (blockIdx().x - 1i32) * blockDim().x
145+
j > length(A) && return
146+
@inbounds begin
147+
a = A[j]
148+
b = B[j]
149+
A[j] = (a + b) / 2
150+
B[j] = b - a
151+
end
152+
return
153+
end
154+
thread_dim = 512
155+
block_dim = cld(length(A), thread_dim)
156+
@cuda threads = thread_dim blocks = block_dim _avgdiff_kernel(A, B)
157+
return A, B
158+
end
159+
61160
end

test/amd/projections.jl

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
using MatrixAlgebraKit
2+
using Test
3+
using TestExtras
4+
using StableRNGs
5+
using LinearAlgebra: LinearAlgebra, Diagonal, norm
6+
using AMDGPU
7+
8+
const BLASFloats = (Float32, Float64, ComplexF32, ComplexF64)
9+
10+
@testset "project_(anti)hermitian! for T = $T" for T in BLASFloats
11+
rng = StableRNG(123)
12+
m = 54
13+
noisefactor = eps(real(T))^(3 / 4)
14+
for alg in (NativeBlocked(blocksize = 16), NativeBlocked(blocksize = 32), NativeBlocked(blocksize = 64))
15+
A = ROCArray(randn(rng, T, m, m))
16+
Ah = (A + A') / 2
17+
Aa = (A - A') / 2
18+
Ac = copy(A)
19+
20+
Bh = project_hermitian(A, alg)
21+
@test ishermitian(Bh)
22+
@test Bh Ah
23+
@test A == Ac
24+
Bh_approx = Bh + noisefactor * Aa
25+
@test !ishermitian(Bh_approx)
26+
@test ishermitian(Bh_approx; rtol = 10 * noisefactor)
27+
28+
Ba = project_antihermitian(A, alg)
29+
@test isantihermitian(Ba)
30+
@test Ba Aa
31+
@test A == Ac
32+
Ba_approx = Ba + noisefactor * Ah
33+
@test !isantihermitian(Ba_approx)
34+
@test isantihermitian(Ba_approx; rtol = 10 * noisefactor)
35+
36+
Bh = project_hermitian!(Ac, alg)
37+
@test Bh === Ac
38+
@test ishermitian(Bh)
39+
@test Bh Ah
40+
41+
copy!(Ac, A)
42+
Ba = project_antihermitian!(Ac, alg)
43+
@test Ba === Ac
44+
@test isantihermitian(Ba)
45+
@test Ba Aa
46+
end
47+
end
48+
49+
@testset "project_isometric! for T = $T" for T in BLASFloats
50+
rng = StableRNG(123)
51+
m = 54
52+
@testset "size ($m, $n)" for n in (37, m)
53+
k = min(m, n)
54+
svdalgs = (ROCSOLVER_QRIteration(), ROCSOLVER_Jacobi())
55+
algs = (PolarViaSVD.(svdalgs)...,) # PolarNewton()) # TODO
56+
@testset "algorithm $alg" for alg in algs
57+
A = ROCArray(randn(rng, T, m, n))
58+
W = project_isometric(A, alg)
59+
@test isisometric(W)
60+
W2 = project_isometric(W, alg)
61+
@test W2 W # stability of the projection
62+
@test W * (W' * A) ≈ A
63+
64+
Ac = similar(A)
65+
W2 = @constinferred project_isometric!(copy!(Ac, A), W, alg)
66+
@test W2 === W
67+
@test isisometric(W)
68+
69+
# test that W is closer to A then any other isometry
70+
for k in 1:10
71+
δA = ROCArray(randn(rng, T, m, n))
72+
W = project_isometric(A, alg)
73+
W2 = project_isometric(A + δA / 100, alg)
74+
@test norm(A - W2) > norm(A - W)
75+
end
76+
end
77+
end
78+
end

test/cuda/projections.jl

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
using MatrixAlgebraKit
2+
using Test
3+
using TestExtras
4+
using StableRNGs
5+
using LinearAlgebra: LinearAlgebra, Diagonal, norm
6+
using CUDA
7+
8+
const BLASFloats = (Float32, Float64, ComplexF32, ComplexF64)
9+
10+
@testset "project_(anti)hermitian! for T = $T" for T in BLASFloats
11+
rng = StableRNG(123)
12+
m = 54
13+
noisefactor = eps(real(T))^(3 / 4)
14+
for alg in (NativeBlocked(blocksize = 16), NativeBlocked(blocksize = 32), NativeBlocked(blocksize = 64))
15+
A = CuArray(randn(rng, T, m, m))
16+
Ah = (A + A') / 2
17+
Aa = (A - A') / 2
18+
Ac = copy(A)
19+
20+
Bh = project_hermitian(A, alg)
21+
@test ishermitian(Bh)
22+
@test Bh Ah
23+
@test A == Ac
24+
Bh_approx = Bh + noisefactor * Aa
25+
@test !ishermitian(Bh_approx)
26+
@test ishermitian(Bh_approx; rtol = 10 * noisefactor)
27+
28+
Ba = project_antihermitian(A, alg)
29+
@test isantihermitian(Ba)
30+
@test Ba Aa
31+
@test A == Ac
32+
Ba_approx = Ba + noisefactor * Ah
33+
@test !isantihermitian(Ba_approx)
34+
@test isantihermitian(Ba_approx; rtol = 10 * noisefactor)
35+
36+
Bh = project_hermitian!(Ac, alg)
37+
@test Bh === Ac
38+
@test ishermitian(Bh)
39+
@test Bh Ah
40+
41+
copy!(Ac, A)
42+
Ba = project_antihermitian!(Ac, alg)
43+
@test Ba === Ac
44+
@test isantihermitian(Ba)
45+
@test Ba Aa
46+
end
47+
end
48+
49+
@testset "project_isometric! for T = $T" for T in BLASFloats
50+
rng = StableRNG(123)
51+
m = 54
52+
@testset "size ($m, $n)" for n in (37, m)
53+
k = min(m, n)
54+
svdalgs = (CUSOLVER_SVDPolar(), CUSOLVER_QRIteration(), CUSOLVER_Jacobi())
55+
algs = (PolarViaSVD.(svdalgs)...,) # PolarNewton()) # TODO
56+
@testset "algorithm $alg" for alg in algs
57+
A = CuArray(randn(rng, T, m, n))
58+
W = project_isometric(A, alg)
59+
@test isisometric(W)
60+
W2 = project_isometric(W, alg)
61+
@test W2 W # stability of the projection
62+
@test W * (W' * A) ≈ A
63+
64+
Ac = similar(A)
65+
W2 = @constinferred project_isometric!(copy!(Ac, A), W, alg)
66+
@test W2 === W
67+
@test isisometric(W)
68+
69+
# test that W is closer to A then any other isometry
70+
for k in 1:10
71+
δA = CuArray(randn(rng, T, m, n))
72+
W = project_isometric(A, alg)
73+
W2 = project_isometric(A + δA / 100, alg)
74+
@test norm(A - W2) > norm(A - W)
75+
end
76+
end
77+
end
78+
end

test/runtests.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,9 @@ if CUDA.functional()
6363
@safetestset "CUDA LQ" begin
6464
include("cuda/lq.jl")
6565
end
66+
@safetestset "CUDA Projections" begin
67+
include("cuda/projections.jl")
68+
end
6669
@safetestset "CUDA SVD" begin
6770
include("cuda/svd.jl")
6871
end
@@ -82,6 +85,9 @@ if AMDGPU.functional()
8285
@safetestset "AMDGPU LQ" begin
8386
include("amd/lq.jl")
8487
end
88+
@safetestset "AMDGPU Projections" begin
89+
include("amd/projections.jl")
90+
end
8591
@safetestset "AMDGPU SVD" begin
8692
include("amd/svd.jl")
8793
end

0 commit comments

Comments
 (0)