Skip to content

Commit ba538e7

Browse files
committed
Working adjoint
1 parent a918790 commit ba538e7

File tree

3 files changed

+32
-21
lines changed

3 files changed

+32
-21
lines changed

ext/TensorKitCUDAExt/cutensormap.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,3 +164,11 @@ for f in (:sqrt, :log, :asin, :acos, :acosh, :atanh, :acoth)
164164
return tf
165165
end
166166
end
167+
168+
function Base.copy!(tdst::CuTensorMap, tsrc::AdjointTensorMap)
169+
space(tdst) == space(tsrc) || throw(SpaceMismatch("$(space(tdst))$(space(tsrc))"))
170+
for ((c, bdst), (_, bsrc)) in zip(blocks(tdst), blocks(tsrc))
171+
copy!(bdst, bsrc)
172+
end
173+
return tdst
174+
end

src/factorizations/adjoint.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@ _adjoint(alg::MAK.LAPACK_HouseholderRQ) = MAK.LAPACK_HouseholderQL(; alg.kwargs.
99
_adjoint(alg::MAK.PolarViaSVD) = MAK.PolarViaSVD(_adjoint(alg.svd_alg))
1010
_adjoint(alg::AbstractAlgorithm) = alg
1111

12+
_adjoint(alg::MAK.CUSOLVER_HouseholderQR) = MAK.LQViaTransposedQR(alg)
13+
_adjoint(alg::MAK.LQViaTransposedQR) = alg.qr_alg
14+
1215
for f in
1316
[
1417
:svd_compact, :svd_full, :svd_vals,

test/cuda/factorizations.jl

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -49,9 +49,9 @@ for V in spacelist
4949
for T in eltypes,
5050
t in (
5151
CUDA.rand(T, W, W),
52-
#CUDA.rand(T, W, W)',
52+
CUDA.rand(T, W, W)',
5353
CUDA.rand(T, W, V4),
54-
#CUDA.rand(T, V4, W)',
54+
CUDA.rand(T, V4, W)',
5555
DiagonalTensorMap(CUDA.rand(T, reduceddim(V1)), V1),
5656
)
5757

@@ -105,9 +105,9 @@ for V in spacelist
105105
for T in eltypes,
106106
t in (
107107
CUDA.rand(T, W, W),
108-
#CUDA.rand(T, W, W)',
108+
CUDA.rand(T, W, W)',
109109
CUDA.rand(T, W, V4),
110-
#CUDA.rand(T, V4, W)',
110+
CUDA.rand(T, V4, W)',
111111
DiagonalTensorMap(CUDA.rand(T, reduceddim(V1)), V1),
112112
)
113113

@@ -157,17 +157,17 @@ for V in spacelist
157157
for T in eltypes,
158158
t in (
159159
CUDA.rand(T, W, W),
160-
#CUDA.rand(T, W, W)',
160+
CUDA.rand(T, W, W)',
161161
CUDA.rand(T, W, V4),
162-
#CUDA.rand(T, V4, W)',
162+
CUDA.rand(T, V4, W)',
163163
DiagonalTensorMap(CUDA.rand(T, reduceddim(V1)), V1),
164164
)
165165

166166
@assert domain(t) codomain(t)
167167
w, p = @constinferred left_polar(t)
168168
@test w * p t
169169
@test isisometric(w)
170-
@test isposdef(p)
170+
@test isposdef(project_hermitian!(p))
171171

172172
w, p = @constinferred left_orth(t; alg = :polar)
173173
@test w * p t
@@ -177,16 +177,16 @@ for V in spacelist
177177
for T in eltypes,
178178
t in (
179179
CUDA.rand(T, W, W),
180-
#CUDA.rand(T, W, W)',
180+
CUDA.rand(T, W, W)',
181181
CUDA.rand(T, V4, W),
182-
#CUDA.rand(T, W, V4)',
182+
CUDA.rand(T, W, V4)',
183183
)
184184

185185
@assert codomain(t) domain(t)
186186
p, wᴴ = @constinferred right_polar(t)
187187
@test p * wᴴ t
188188
@test isisometric(wᴴ; side = :right)
189-
@test isposdef(p)
189+
@test isposdef(project_hermitian!(p))
190190

191191
p, wᴴ = @constinferred right_orth(t; alg = :polar)
192192
@test p * wᴴ t
@@ -198,11 +198,11 @@ for V in spacelist
198198
for T in eltypes,
199199
t in (
200200
CUDA.rand(T, W, W),
201-
#CUDA.rand(T, W, W)',
201+
CUDA.rand(T, W, W)',
202202
CUDA.rand(T, W, V4),
203203
CUDA.rand(T, V4, W),
204-
#CUDA.rand(T, W, V4)',
205-
#CUDA.rand(T, V4, W)',
204+
CUDA.rand(T, W, V4)',
205+
CUDA.rand(T, V4, W)',
206206
DiagonalTensorMap(CUDA.rand(T, reduceddim(V1)), V1),
207207
)
208208

@@ -338,7 +338,7 @@ for V in spacelist
338338
t in (
339339
CUDA.rand(T, V1, V1),
340340
CUDA.rand(T, W, W),
341-
#CUDA.rand(T, W, W)',
341+
CUDA.rand(T, W, W)',
342342
DiagonalTensorMap(CUDA.rand(T, reduceddim(V1)), V1),
343343
)
344344

@@ -395,11 +395,11 @@ for V in spacelist
395395
for T in eltypes,
396396
t in (
397397
CUDA.rand(T, W, W),
398-
#CUDA.rand(T, W, W)',
398+
CUDA.rand(T, W, W)',
399399
CUDA.rand(T, W, V4),
400400
CUDA.rand(T, V4, W),
401-
#CUDA.rand(T, W, V4)',
402-
#CUDA.rand(T, V4, W)',
401+
CUDA.rand(T, W, V4)',
402+
CUDA.rand(T, V4, W)',
403403
DiagonalTensorMap(CUDA.rand(T, reduceddim(V1)), V1),
404404
)
405405

@@ -425,7 +425,7 @@ for V in spacelist
425425
end
426426
for T in eltypes, t in (
427427
CUDA.rand(T, W, W),
428-
#CUDA.rand(T, W, W)',
428+
CUDA.rand(T, W, W)',
429429
)
430430
project_hermitian!(t)
431431
vals = @constinferred LinearAlgebra.eigvals(t)
@@ -440,7 +440,7 @@ for V in spacelist
440440
t in (
441441
CUDA.rand(T, V1, V1),
442442
CUDA.rand(T, W, W),
443-
#CUDA.rand(T, W, W)',
443+
CUDA.rand(T, W, W)',
444444
DiagonalTensorMap(CUDA.rand(T, reduceddim(V1)), V1),
445445
)
446446
normalize!(t)
@@ -472,9 +472,9 @@ for V in spacelist
472472
for T in eltypes,
473473
t in (
474474
CUDA.randn(T, W, W),
475-
#CUDA.randn(T, W, W)',
475+
CUDA.randn(T, W, W)',
476476
CUDA.randn(T, W, V4),
477-
#CUDA.randn(T, V4, W)',
477+
CUDA.randn(T, V4, W)',
478478
)
479479
t2 = project_isometric(t)
480480
@test isisometric(t2)

0 commit comments

Comments
 (0)