Skip to content

Commit fac25d3

Browse files
kshyattlkdvosJutho
authored
Initial support for CUDA + factorizations (#336)
* Initial support for CUDA + factorizations * Working adjoint * Update test/cuda/factorizations.jl Co-authored-by: Lukas Devos <[email protected]> * Update test/cuda/factorizations.jl Co-authored-by: Jutho <[email protected]> * Fix posdef and copy * Comments --------- Co-authored-by: Lukas Devos <[email protected]> Co-authored-by: Jutho <[email protected]>
1 parent 1cb66b8 commit fac25d3

File tree

5 files changed

+505
-4
lines changed

5 files changed

+505
-4
lines changed

ext/TensorKitCUDAExt/cutensormap.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ function LinearAlgebra.isposdef(t::CuTensorMap)
119119
InnerProductStyle(spacetype(t)) === EuclideanInnerProduct() || return false
120120
for (c, b) in blocks(t)
121121
# do our own hermitian check
122-
isherm = MatrixAlgebraKit.ishermitian(b; atol = eps(real(eltype(b))), rtol = eps(real(eltype(b))))
122+
isherm = MatrixAlgebraKit.ishermitian(b)
123123
isherm || return false
124124
isposdef(Hermitian(b)) || return false
125125
end

src/factorizations/adjoint.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@ _adjoint(alg::MAK.LAPACK_HouseholderRQ) = MAK.LAPACK_HouseholderQL(; alg.kwargs.
99
_adjoint(alg::MAK.PolarViaSVD) = MAK.PolarViaSVD(_adjoint(alg.svd_alg))
1010
_adjoint(alg::AbstractAlgorithm) = alg
1111

12+
_adjoint(alg::MAK.CUSOLVER_HouseholderQR) = MAK.LQViaTransposedQR(alg)
13+
_adjoint(alg::MAK.LQViaTransposedQR) = alg.qr_alg
14+
1215
for f in
1316
[
1417
:svd_compact, :svd_full, :svd_vals,
@@ -108,3 +111,7 @@ function MAK.svd_compact!(t::AdjointTensorMap, F, alg::DiagonalAlgorithm)
108111
F′ = svd_compact!(adjoint(t), reverse(adjoint.(F)), _adjoint(alg))
109112
return reverse(adjoint.(F′))
110113
end
114+
115+
function LinearAlgebra.isposdef(t::AdjointTensorMap)
116+
return isposdef(adjoint(t))
117+
end

src/tensors/linalg.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -196,14 +196,12 @@ LinearAlgebra.isdiag(t::AbstractTensorMap) = all(LinearAlgebra.isdiag ∘ last,
196196

197197
# In-place methods
198198
#------------------
199-
# Wrapping the blocks in a StridedView enables multithreading if JULIA_NUM_THREADS > 1
200-
# TODO: reconsider this strategy, consider spawning different threads for different blocks
201199

202200
# Copy, adjoint and fill:
203201
function Base.copy!(tdst::AbstractTensorMap, tsrc::AbstractTensorMap)
204202
space(tdst) == space(tsrc) || throw(SpaceMismatch("$(space(tdst))$(space(tsrc))"))
205203
for ((c, bdst), (_, bsrc)) in zip(blocks(tdst), blocks(tsrc))
206-
copy!(StridedView(bdst), StridedView(bsrc))
204+
copy!(bdst, bsrc)
207205
end
208206
return tdst
209207
end

0 commit comments

Comments
 (0)