Skip to content

Commit 0dce4de

Browse files
author
Katharine Hyatt
committed
Runic updates
1 parent 64954ef commit 0dce4de

File tree

9 files changed

+282
-208
lines changed

9 files changed

+282
-208
lines changed

ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -21,31 +21,31 @@ function MatrixAlgebraKit.default_lq_algorithm(::Type{T}; kwargs...) where {T <:
2121
qr_alg = ROCSOLVER_HouseholderQR(; kwargs...)
2222
return LQViaTransposedQR(qr_alg)
2323
end
24-
function MatrixAlgebraKit.default_svd_algorithm(::Type{T}; kwargs...) where {T<:StridedROCMatrix}
24+
function MatrixAlgebraKit.default_svd_algorithm(::Type{T}; kwargs...) where {T <: StridedROCMatrix}
2525
return ROCSOLVER_QRIteration(; kwargs...)
2626
end
27-
function MatrixAlgebraKit.default_eig_algorithm(::Type{T}; kwargs...) where {T<:StridedROCMatrix}
28-
throw(ErrorException("AMDGPU has no support for general eigensolving"))
27+
function MatrixAlgebraKit.default_eig_algorithm(::Type{T}; kwargs...) where {T <: StridedROCMatrix}
28+
throw(ErrorException("AMDGPU has no support for general eigensolving"))
2929
end
3030
function MatrixAlgebraKit.default_eigh_algorithm(::Type{T}; kwargs...) where {T <: StridedROCMatrix}
3131
return ROCSOLVER_DivideAndConquer(; kwargs...)
3232
end
3333

3434
# include for block sector support
35-
function MatrixAlgebraKit.default_qr_algorithm(::Type{Base.ReshapedArray{T,2,SubArray{T,1,A,Tuple{UnitRange{Int}},true},Tuple{}}}; kwargs...) where {T<:BlasFloat, A<:ROCVecOrMat{T}}
35+
function MatrixAlgebraKit.default_qr_algorithm(::Type{Base.ReshapedArray{T, 2, SubArray{T, 1, A, Tuple{UnitRange{Int}}, true}, Tuple{}}}; kwargs...) where {T <: BlasFloat, A <: ROCVecOrMat{T}}
3636
return ROCSOLVER_HouseholderQR(; kwargs...)
3737
end
38-
function MatrixAlgebraKit.default_lq_algorithm(::Type{Base.ReshapedArray{T,2,SubArray{T,1,A,Tuple{UnitRange{Int}},true},Tuple{}}}; kwargs...) where {T<:BlasFloat, A<:ROCVecOrMat{T}}
38+
function MatrixAlgebraKit.default_lq_algorithm(::Type{Base.ReshapedArray{T, 2, SubArray{T, 1, A, Tuple{UnitRange{Int}}, true}, Tuple{}}}; kwargs...) where {T <: BlasFloat, A <: ROCVecOrMat{T}}
3939
qr_alg = ROCSOLVER_HouseholderQR(; kwargs...)
4040
return LQViaTransposedQR(qr_alg)
4141
end
42-
function MatrixAlgebraKit.default_svd_algorithm(::Type{Base.ReshapedArray{T,2,SubArray{T,1,A,Tuple{UnitRange{Int}},true},Tuple{}}}; kwargs...) where {T<:BlasFloat, A<:ROCVecOrMat{T}}
42+
function MatrixAlgebraKit.default_svd_algorithm(::Type{Base.ReshapedArray{T, 2, SubArray{T, 1, A, Tuple{UnitRange{Int}}, true}, Tuple{}}}; kwargs...) where {T <: BlasFloat, A <: ROCVecOrMat{T}}
4343
return ROCSOLVER_Jacobi(; kwargs...)
4444
end
45-
function MatrixAlgebraKit.default_eig_algorithm(::Type{Base.ReshapedArray{T,2,SubArray{T,1,A,Tuple{UnitRange{Int}},true},Tuple{}}}; kwargs...) where {T<:BlasFloat, A<:ROCVecOrMat{T}}
46-
throw(ErrorException("AMDGPU has no support for general eigensolving"))
45+
function MatrixAlgebraKit.default_eig_algorithm(::Type{Base.ReshapedArray{T, 2, SubArray{T, 1, A, Tuple{UnitRange{Int}}, true}, Tuple{}}}; kwargs...) where {T <: BlasFloat, A <: ROCVecOrMat{T}}
46+
throw(ErrorException("AMDGPU has no support for general eigensolving"))
4747
end
48-
function MatrixAlgebraKit.default_eigh_algorithm(::Type{Base.ReshapedArray{T,2,SubArray{T,1,A,Tuple{UnitRange{Int}},true},Tuple{}}}; kwargs...) where {T<:BlasFloat, A<:ROCVecOrMat{T}}
48+
function MatrixAlgebraKit.default_eigh_algorithm(::Type{Base.ReshapedArray{T, 2, SubArray{T, 1, A, Tuple{UnitRange{Int}}, true}, Tuple{}}}; kwargs...) where {T <: BlasFloat, A <: ROCVecOrMat{T}}
4949
return ROCSOLVER_DivideAndConquer(; kwargs...)
5050
end
5151

@@ -107,7 +107,7 @@ function _project_hermitian_diag_kernel(A, B, ::Val{true})
107107
j > n && return
108108
@inbounds begin
109109
for i in 1:(j - 1)
110-
val = (A[i, j] - adjoint(A[j, i])) /2
110+
val = (A[i, j] - adjoint(A[j, i])) / 2
111111
B[i, j] = val
112112
B[j, i] = -adjoint(val)
113113
end
@@ -135,22 +135,22 @@ function MatrixAlgebraKit._project_hermitian_offdiag!(
135135
Au::StridedROCMatrix, Al::StridedROCMatrix, Bu::StridedROCMatrix, Bl::StridedROCMatrix, ::Val{anti}
136136
) where {anti}
137137
thread_dim = 512
138-
block_dim = cld(size(Au, 2), thread_dim)
139-
@roc groupsize=thread_dim gridsize=block_dim _project_hermitian_offdiag_kernel(Au, Al, Bu, Bl, Val(anti))
138+
block_dim = cld(size(Au, 2), thread_dim)
139+
@roc groupsize = thread_dim gridsize = block_dim _project_hermitian_offdiag_kernel(Au, Al, Bu, Bl, Val(anti))
140140
return nothing
141141
end
142142
function MatrixAlgebraKit._project_hermitian_diag!(A::StridedROCMatrix, B::StridedROCMatrix, ::Val{anti}) where {anti}
143143
thread_dim = 512
144-
block_dim = cld(size(A, 1), thread_dim)
145-
@roc groupsize=thread_dim gridsize=block_dim _project_hermitian_diag_kernel(A, B, Val(anti))
144+
block_dim = cld(size(A, 1), thread_dim)
145+
@roc groupsize = thread_dim gridsize = block_dim _project_hermitian_diag_kernel(A, B, Val(anti))
146146
return nothing
147147
end
148148

149-
MatrixAlgebraKit.ishermitian_exact(A::StridedROCMatrix) = all( A .== adjoint(A))
150-
MatrixAlgebraKit.ishermitian_exact(A::Diagonal{T, <:StridedROCVector{T}}) where {T} = all( A.diag .== adjoint(A.diag))
149+
MatrixAlgebraKit.ishermitian_exact(A::StridedROCMatrix) = all(A .== adjoint(A))
150+
MatrixAlgebraKit.ishermitian_exact(A::Diagonal{T, <:StridedROCVector{T}}) where {T} = all(A.diag .== adjoint(A.diag))
151151

152-
MatrixAlgebraKit.isantihermitian_exact(A::StridedROCMatrix) = all( A .== -adjoint(A))
153-
MatrixAlgebraKit.isantihermitian_exact(A::Diagonal{T, <:StridedROCVector{T}}) where {T} = all( A.diag .== -adjoint(A.diag))
152+
MatrixAlgebraKit.isantihermitian_exact(A::StridedROCMatrix) = all(A .== -adjoint(A))
153+
MatrixAlgebraKit.isantihermitian_exact(A::Diagonal{T, <:StridedROCVector{T}}) where {T} = all(A.diag .== -adjoint(A.diag))
154154

155155
function MatrixAlgebraKit._avgdiff!(A::StridedROCMatrix, B::StridedROCMatrix)
156156
axes(A) == axes(B) || throw(DimensionMismatch())
@@ -160,14 +160,14 @@ function MatrixAlgebraKit._avgdiff!(A::StridedROCMatrix, B::StridedROCMatrix)
160160
@inbounds begin
161161
a = A[j]
162162
b = B[j]
163-
A[j] = (a+b)/2
163+
A[j] = (a + b) / 2
164164
B[j] = b - a
165165
end
166166
return
167167
end
168168
thread_dim = 512
169-
block_dim = cld(length(A), thread_dim)
170-
@roc groupsize=thread_dim gridsize=block_dim _avgdiff_kernel(A, B)
169+
block_dim = cld(length(A), thread_dim)
170+
@roc groupsize = thread_dim gridsize = block_dim _avgdiff_kernel(A, B)
171171
return A, B
172172
end
173173
end

ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -16,38 +16,38 @@ using CUDA: i32
1616

1717
include("yacusolver.jl")
1818

19-
function MatrixAlgebraKit.default_qr_algorithm(::Type{T}; kwargs...) where {TT<:BlasFloat, T<:StridedCuMatrix{TT}}
19+
function MatrixAlgebraKit.default_qr_algorithm(::Type{T}; kwargs...) where {TT <: BlasFloat, T <: StridedCuMatrix{TT}}
2020
return CUSOLVER_HouseholderQR(; kwargs...)
2121
end
22-
function MatrixAlgebraKit.default_lq_algorithm(::Type{T}; kwargs...) where {TT<:BlasFloat, T<:StridedCuMatrix{TT}}
22+
function MatrixAlgebraKit.default_lq_algorithm(::Type{T}; kwargs...) where {TT <: BlasFloat, T <: StridedCuMatrix{TT}}
2323
qr_alg = CUSOLVER_HouseholderQR(; kwargs...)
2424
return LQViaTransposedQR(qr_alg)
2525
end
26-
function MatrixAlgebraKit.default_svd_algorithm(::Type{T}; kwargs...) where {TT<:BlasFloat, T<:StridedCuMatrix{TT}}
26+
function MatrixAlgebraKit.default_svd_algorithm(::Type{T}; kwargs...) where {TT <: BlasFloat, T <: StridedCuMatrix{TT}}
2727
return CUSOLVER_QRIteration(; kwargs...)
2828
end
29-
function MatrixAlgebraKit.default_eig_algorithm(::Type{T}; kwargs...) where {TT<:BlasFloat, T<:StridedCuMatrix{TT}}
29+
function MatrixAlgebraKit.default_eig_algorithm(::Type{T}; kwargs...) where {TT <: BlasFloat, T <: StridedCuMatrix{TT}}
3030
return CUSOLVER_Simple(; kwargs...)
3131
end
32-
function MatrixAlgebraKit.default_eigh_algorithm(::Type{T}; kwargs...) where {TT<:BlasFloat, T<:StridedCuMatrix{TT}}
32+
function MatrixAlgebraKit.default_eigh_algorithm(::Type{T}; kwargs...) where {TT <: BlasFloat, T <: StridedCuMatrix{TT}}
3333
return CUSOLVER_DivideAndConquer(; kwargs...)
3434
end
3535

3636
# include for block sector support
37-
function MatrixAlgebraKit.default_qr_algorithm(::Type{Base.ReshapedArray{T,2,SubArray{T,1,A,Tuple{UnitRange{Int}},true},Tuple{}}}; kwargs...) where {T<:BlasFloat, A<:CuVecOrMat{T}}
37+
function MatrixAlgebraKit.default_qr_algorithm(::Type{Base.ReshapedArray{T, 2, SubArray{T, 1, A, Tuple{UnitRange{Int}}, true}, Tuple{}}}; kwargs...) where {T <: BlasFloat, A <: CuVecOrMat{T}}
3838
return CUSOLVER_HouseholderQR(; kwargs...)
3939
end
40-
function MatrixAlgebraKit.default_lq_algorithm(::Type{Base.ReshapedArray{T,2,SubArray{T,1,A,Tuple{UnitRange{Int}},true},Tuple{}}}; kwargs...) where {T<:BlasFloat, A<:CuVecOrMat{T}}
40+
function MatrixAlgebraKit.default_lq_algorithm(::Type{Base.ReshapedArray{T, 2, SubArray{T, 1, A, Tuple{UnitRange{Int}}, true}, Tuple{}}}; kwargs...) where {T <: BlasFloat, A <: CuVecOrMat{T}}
4141
qr_alg = CUSOLVER_HouseholderQR(; kwargs...)
4242
return LQViaTransposedQR(qr_alg)
4343
end
44-
function MatrixAlgebraKit.default_svd_algorithm(::Type{Base.ReshapedArray{T,2,SubArray{T,1,A,Tuple{UnitRange{Int}},true},Tuple{}}}; kwargs...) where {T<:BlasFloat, A<:CuVecOrMat{T}}
44+
function MatrixAlgebraKit.default_svd_algorithm(::Type{Base.ReshapedArray{T, 2, SubArray{T, 1, A, Tuple{UnitRange{Int}}, true}, Tuple{}}}; kwargs...) where {T <: BlasFloat, A <: CuVecOrMat{T}}
4545
return CUSOLVER_Jacobi(; kwargs...)
4646
end
47-
function MatrixAlgebraKit.default_eig_algorithm(::Type{Base.ReshapedArray{T,2,SubArray{T,1,A,Tuple{UnitRange{Int}},true},Tuple{}}}; kwargs...) where {T<:BlasFloat, A<:CuVecOrMat{T}}
47+
function MatrixAlgebraKit.default_eig_algorithm(::Type{Base.ReshapedArray{T, 2, SubArray{T, 1, A, Tuple{UnitRange{Int}}, true}, Tuple{}}}; kwargs...) where {T <: BlasFloat, A <: CuVecOrMat{T}}
4848
return CUSOLVER_Simple(; kwargs...)
4949
end
50-
function MatrixAlgebraKit.default_eigh_algorithm(::Type{Base.ReshapedArray{T,2,SubArray{T,1,A,Tuple{UnitRange{Int}},true},Tuple{}}}; kwargs...) where {T<:BlasFloat, A<:CuVecOrMat{T}}
50+
function MatrixAlgebraKit.default_eigh_algorithm(::Type{Base.ReshapedArray{T, 2, SubArray{T, 1, A, Tuple{UnitRange{Int}}, true}, Tuple{}}}; kwargs...) where {T <: BlasFloat, A <: CuVecOrMat{T}}
5151
return CUSOLVER_DivideAndConquer(; kwargs...)
5252
end
5353

@@ -112,7 +112,7 @@ function _project_hermitian_diag_kernel(A, B, ::Val{true})
112112
j > n && return
113113
@inbounds begin
114114
for i in 1i32:(j - 1i32)
115-
val = (A[i, j] - adjoint(A[j, i])) /2
115+
val = (A[i, j] - adjoint(A[j, i])) / 2
116116
B[i, j] = val
117117
B[j, i] = -adjoint(val)
118118
end
@@ -140,22 +140,22 @@ function MatrixAlgebraKit._project_hermitian_offdiag!(
140140
Au::StridedCuMatrix, Al::StridedCuMatrix, Bu::StridedCuMatrix, Bl::StridedCuMatrix, ::Val{anti}
141141
) where {anti}
142142
thread_dim = 512
143-
block_dim = cld(size(Au, 2), thread_dim)
144-
@cuda threads=thread_dim blocks=block_dim _project_hermitian_offdiag_kernel(Au, Al, Bu, Bl, Val(anti))
143+
block_dim = cld(size(Au, 2), thread_dim)
144+
@cuda threads = thread_dim blocks = block_dim _project_hermitian_offdiag_kernel(Au, Al, Bu, Bl, Val(anti))
145145
return nothing
146146
end
147147
function MatrixAlgebraKit._project_hermitian_diag!(A::StridedCuMatrix, B::StridedCuMatrix, ::Val{anti}) where {anti}
148148
thread_dim = 512
149-
block_dim = cld(size(A, 1), thread_dim)
150-
@cuda threads=thread_dim blocks=block_dim _project_hermitian_diag_kernel(A, B, Val(anti))
149+
block_dim = cld(size(A, 1), thread_dim)
150+
@cuda threads = thread_dim blocks = block_dim _project_hermitian_diag_kernel(A, B, Val(anti))
151151
return nothing
152152
end
153153

154-
MatrixAlgebraKit.ishermitian_exact(A::StridedCuMatrix) = all( A .== adjoint(A))
155-
MatrixAlgebraKit.ishermitian_exact(A::Diagonal{T, <:StridedCuVector{T}}) where {T} = all( A.diag .== adjoint(A.diag))
154+
MatrixAlgebraKit.ishermitian_exact(A::StridedCuMatrix) = all(A .== adjoint(A))
155+
MatrixAlgebraKit.ishermitian_exact(A::Diagonal{T, <:StridedCuVector{T}}) where {T} = all(A.diag .== adjoint(A.diag))
156156

157-
MatrixAlgebraKit.isantihermitian_exact(A::StridedCuMatrix) = all( A .== -adjoint(A))
158-
MatrixAlgebraKit.isantihermitian_exact(A::Diagonal{T, <:StridedCuVector{T}}) where {T} = all( A.diag .== -adjoint(A.diag))
157+
MatrixAlgebraKit.isantihermitian_exact(A::StridedCuMatrix) = all(A .== -adjoint(A))
158+
MatrixAlgebraKit.isantihermitian_exact(A::Diagonal{T, <:StridedCuVector{T}}) where {T} = all(A.diag .== -adjoint(A.diag))
159159

160160
function MatrixAlgebraKit._avgdiff!(A::StridedCuMatrix, B::StridedCuMatrix)
161161
axes(A) == axes(B) || throw(DimensionMismatch())
@@ -165,14 +165,14 @@ function MatrixAlgebraKit._avgdiff!(A::StridedCuMatrix, B::StridedCuMatrix)
165165
@inbounds begin
166166
a = A[j]
167167
b = B[j]
168-
A[j] = (a+b)/2
168+
A[j] = (a + b) / 2
169169
B[j] = b - a
170170
end
171171
return
172172
end
173173
thread_dim = 512
174-
block_dim = cld(length(A), thread_dim)
175-
@cuda threads=thread_dim blocks=block_dim _avgdiff_kernel(A, B)
174+
block_dim = cld(length(A), thread_dim)
175+
@cuda threads = thread_dim blocks = block_dim _avgdiff_kernel(A, B)
176176
return A, B
177177
end
178178

ext/MatrixAlgebraKitCUDAExt/yacusolver.jl

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -192,14 +192,16 @@ for (bname, fname, elty, relty) in
192192
)
193193
@eval begin
194194
#! format: off
195-
function gesvdj!(A::StridedCuMatrix{$elty},
196-
S::StridedCuVector{$relty}=similar(A, $relty, min(size(A)...)),
197-
U::StridedCuMatrix{$elty}=similar(A, $elty, size(A, 1), min(size(A)...)),
198-
Vᴴ::StridedCuMatrix{$elty}=similar(A, $elty, min(size(A)...), size(A, 2));
199-
tol::$relty=eps($relty),
200-
max_sweeps::Int=100,
201-
kwargs...)
202-
#! format: on
195+
function gesvdj!(
196+
A::StridedCuMatrix{$elty},
197+
S::StridedCuVector{$relty} = similar(A, $relty, min(size(A)...)),
198+
U::StridedCuMatrix{$elty} = similar(A, $elty, size(A, 1), min(size(A)...)),
199+
Vᴴ::StridedCuMatrix{$elty} = similar(A, $elty, min(size(A)...), size(A, 2));
200+
tol::$relty = eps($relty),
201+
max_sweeps::Int = 100,
202+
kwargs...
203+
)
204+
#! format: on
203205
chkstride1(A, U, Vᴴ, S)
204206
m, n = size(A)
205207
minmn = min(m, n)

src/implementations/svd.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -354,7 +354,7 @@ function _gpu_gesvd_maybe_transpose!(A::AbstractMatrix, S::AbstractVector, U::Ab
354354
minmn = min(m, n)
355355
Aᴴ = min(m, n) > 0 ? adjoint!(similar(A'), A)::AbstractMatrix : similar(A')
356356
Uᴴ = similar(U')
357-
V = similar(Vᴴ')
357+
V = similar(Vᴴ')
358358
if size(U) == (m, m)
359359
_gpu_gesvd!(Aᴴ, view(S, 1:minmn, 1), V, Uᴴ)
360360
else

0 commit comments

Comments
 (0)