Skip to content

Commit 92cd3c2

Browse files
committed
Fix posdef and copy
1 parent 739fc99 commit 92cd3c2

File tree

3 files changed

+12
-14
lines changed

3 files changed

+12
-14
lines changed

ext/TensorKitCUDAExt/cutensormap.jl

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -121,11 +121,15 @@ function LinearAlgebra.isposdef(t::CuTensorMap)
121121
# do our own hermitian check
122122
isherm = MatrixAlgebraKit.ishermitian(b; atol = eps(real(eltype(b))), rtol = eps(real(eltype(b))))
123123
isherm || return false
124-
isposdef(Hermitian(b)) || return false
124+
isposdef(project_hermitian!(b)) || return false
125125
end
126126
return true
127127
end
128128

129+
function LinearAlgebra.isposdef(t::TensorKit.AdjointTensorMap{T, S, N₁, N₂, <:CuTensorMap}) where {T, S, N₁, N₂}
130+
return isposdef(adjoint(t))
131+
end
132+
129133
function Base.promote_rule(
130134
::Type{<:TT₁},
131135
::Type{<:TT₂}
@@ -164,11 +168,3 @@ for f in (:sqrt, :log, :asin, :acos, :acosh, :atanh, :acoth)
164168
return tf
165169
end
166170
end
167-
168-
function Base.copy!(tdst::CuTensorMap, tsrc::AdjointTensorMap)
169-
space(tdst) == space(tsrc) || throw(SpaceMismatch("$(space(tdst))$(space(tsrc))"))
170-
for ((c, bdst), (_, bsrc)) in zip(blocks(tdst), blocks(tsrc))
171-
copy!(bdst, bsrc)
172-
end
173-
return tdst
174-
end

src/tensors/linalg.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -197,13 +197,15 @@ LinearAlgebra.isdiag(t::AbstractTensorMap) = all(LinearAlgebra.isdiag ∘ last,
197197
# In-place methods
198198
#------------------
199199
# Wrapping the blocks in a StridedView enables multithreading if JULIA_NUM_THREADS > 1
200+
# but can cause problems with underlying array types (like CuArray) that don't yet play
201+
# nicely with StridedViews
200202
# TODO: reconsider this strategy, consider spawning different threads for different blocks
201203

202204
# Copy, adjoint and fill:
203205
function Base.copy!(tdst::AbstractTensorMap, tsrc::AbstractTensorMap)
204206
space(tdst) == space(tsrc) || throw(SpaceMismatch("$(space(tdst))$(space(tsrc))"))
205207
for ((c, bdst), (_, bsrc)) in zip(blocks(tdst), blocks(tsrc))
206-
copy!(StridedView(bdst), StridedView(bsrc))
208+
copy!(bdst, bsrc)
207209
end
208210
return tdst
209211
end

test/cuda/factorizations.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ for V in spacelist
150150
end
151151

152152
@testset "Polar decomposition" begin
153-
for T in eltypes,
153+
@testset for T in eltypes,
154154
t in (
155155
CUDA.rand(T, W, W),
156156
CUDA.rand(T, W, W)',
@@ -163,14 +163,14 @@ for V in spacelist
163163
w, p = @constinferred left_polar(t)
164164
@test w * p t
165165
@test isisometric(w)
166-
@test isposdef(project_hermitian!(p))
166+
@test isposdef(p)
167167

168168
w, p = @constinferred left_orth(t; alg = :polar)
169169
@test w * p t
170170
@test isisometric(w)
171171
end
172172

173-
for T in eltypes,
173+
@testset for T in eltypes,
174174
t in (
175175
CUDA.rand(T, W, W),
176176
CUDA.rand(T, W, W)',
@@ -182,7 +182,7 @@ for V in spacelist
182182
p, wᴴ = @constinferred right_polar(t)
183183
@test p * wᴴ t
184184
@test isisometric(wᴴ; side = :right)
185-
@test isposdef(project_hermitian!(p))
185+
@test isposdef(p)
186186

187187
p, wᴴ = @constinferred right_orth(t; alg = :polar)
188188
@test p * wᴴ t

0 commit comments

Comments
 (0)