Skip to content

Commit 7838bc8

Browse files
committed
clean up handling AdjointTensorMap
1 parent 113f9a1 commit 7838bc8

File tree

3 files changed

+94
-41
lines changed

3 files changed

+94
-41
lines changed
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
# AdjointTensorMap
2+
# ----------------
3+
# 1-arg functions
4+
function initialize_output(::typeof(left_null!), t::AdjointTensorMap,
5+
alg::AbstractAlgorithm)
6+
return adjoint(initialize_output(right_null!, adjoint(t), alg))
7+
end
8+
function initialize_output(::typeof(right_null!), t::AdjointTensorMap,
9+
alg::AbstractAlgorithm)
10+
return adjoint(initialize_output(left_null!, adjoint(t), alg))
11+
end
12+
13+
function left_null!(t::AdjointTensorMap, N::AdjointTensorMap, alg::AbstractAlgorithm)
14+
right_null!(adjoint(t), adjoint(N), alg)
15+
return N
16+
end
17+
function right_null!(t::AdjointTensorMap, N::AdjointTensorMap, alg::AbstractAlgorithm)
18+
left_null!(adjoint(t), adjoint(N), alg)
19+
return N
20+
end
21+
22+
# 2-arg functions
23+
for (left_f!, right_f!) in zip((:qr_full!, :qr_compact!, :left_polar!, :left_orth!),
24+
(:lq_full!, :lq_compact!, :right_polar!, :right_orth!))
25+
@eval function initialize_output(::typeof($left_f!), t::AdjointTensorMap,
26+
alg::AbstractAlgorithm)
27+
return reverse(adjoint.(initialize_output($right_f!, adjoint(t), alg)))
28+
end
29+
@eval function initialize_output(::typeof($right_f!), t::AdjointTensorMap,
30+
alg::AbstractAlgorithm)
31+
return reverse(adjoint.(initialize_output($left_f!, adjoint(t), alg)))
32+
end
33+
34+
@eval function $left_f!(t::AdjointTensorMap,
35+
F::Tuple{AdjointTensorMap,AdjointTensorMap},
36+
alg::AbstractAlgorithm)
37+
$right_f!(adjoint(t), reverse(adjoint.(F)), alg)
38+
return F
39+
end
40+
@eval function $right_f!(t::AdjointTensorMap,
41+
F::Tuple{AdjointTensorMap,AdjointTensorMap},
42+
alg::AbstractAlgorithm)
43+
$left_f!(adjoint(t), reverse(adjoint.(F)), alg)
44+
return F
45+
end
46+
end
47+
48+
# 3-arg functions
49+
for f! in (:svd_full!, :svd_compact!, :svd_trunc!)
50+
@eval function initialize_output(::typeof($f!), t::AdjointTensorMap,
51+
alg::AbstractAlgorithm)
52+
return reverse(adjoint.(initialize_output($f!, adjoint(t), alg)))
53+
end
54+
_TS = f! === :svd_full! ? :AdjointTensorMap : DiagonalTensorMap
55+
@eval function $f!(t::AdjointTensorMap,
56+
F::Tuple{AdjointTensorMap,$_TS,AdjointTensorMap},
57+
alg::AbstractAlgorithm)
58+
$f!(adjoint(t), reverse(adjoint.(F)), alg)
59+
return F
60+
end
61+
end
62+
63+
function leftorth!(t::AdjointTensorMap; alg::OFA=QRpos())
64+
InnerProductStyle(t) === EuclideanInnerProduct() ||
65+
throw_invalid_innerproduct(:leftorth!)
66+
return map(adjoint, reverse(rightorth!(adjoint(t); alg=alg')))
67+
end
68+
69+
function rightorth!(t::AdjointTensorMap; alg::OFA=LQpos())
70+
InnerProductStyle(t) === EuclideanInnerProduct() ||
71+
throw_invalid_innerproduct(:rightorth!)
72+
return map(adjoint, reverse(leftorth!(adjoint(t); alg=alg')))
73+
end
74+
75+
function leftnull!(t::AdjointTensorMap; alg::OFA=QR(), kwargs...)
76+
InnerProductStyle(t) === EuclideanInnerProduct() ||
77+
throw_invalid_innerproduct(:leftnull!)
78+
return adjoint(rightnull!(adjoint(t); alg=alg', kwargs...))
79+
end
80+
81+
function rightnull!(t::AdjointTensorMap; alg::OFA=LQ(), kwargs...)
82+
InnerProductStyle(t) === EuclideanInnerProduct() ||
83+
throw_invalid_innerproduct(:rightnull!)
84+
return adjoint(leftnull!(adjoint(t); alg=alg', kwargs...))
85+
end
86+
87+
function tsvd!(t::AdjointTensorMap; trunc=NoTruncation(), p::Real=2, alg=SDD())
88+
u, s, vt, err = tsvd!(adjoint(t); trunc=trunc, p=p, alg=alg)
89+
return adjoint(vt), adjoint(s), adjoint(u), err
90+
end

src/tensors/factorizations/factorizations.jl

Lines changed: 1 addition & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ include("implementations.jl")
4040
include("matrixalgebrakit.jl")
4141
include("truncation.jl")
4242
include("deprecations.jl")
43+
include("adjoint.jl")
4344

4445
TensorKit.one!(A::AbstractMatrix) = MatrixAlgebraKit.one!(A)
4546

@@ -54,37 +55,6 @@ end
5455
#------------------------------------------------------------------------------------------
5556
const RealOrComplexFloat = Union{AbstractFloat,Complex{<:AbstractFloat}}
5657

57-
# AdjointTensorMap
58-
# ----------------
59-
function leftorth!(t::AdjointTensorMap; alg::OFA=QRpos())
60-
InnerProductStyle(t) === EuclideanInnerProduct() ||
61-
throw_invalid_innerproduct(:leftorth!)
62-
return map(adjoint, reverse(rightorth!(adjoint(t); alg=alg')))
63-
end
64-
65-
function rightorth!(t::AdjointTensorMap; alg::OFA=LQpos())
66-
InnerProductStyle(t) === EuclideanInnerProduct() ||
67-
throw_invalid_innerproduct(:rightorth!)
68-
return map(adjoint, reverse(leftorth!(adjoint(t); alg=alg')))
69-
end
70-
71-
function leftnull!(t::AdjointTensorMap; alg::OFA=QR(), kwargs...)
72-
InnerProductStyle(t) === EuclideanInnerProduct() ||
73-
throw_invalid_innerproduct(:leftnull!)
74-
return adjoint(rightnull!(adjoint(t); alg=alg', kwargs...))
75-
end
76-
77-
function rightnull!(t::AdjointTensorMap; alg::OFA=LQ(), kwargs...)
78-
InnerProductStyle(t) === EuclideanInnerProduct() ||
79-
throw_invalid_innerproduct(:rightnull!)
80-
return adjoint(leftnull!(adjoint(t); alg=alg', kwargs...))
81-
end
82-
83-
function tsvd!(t::AdjointTensorMap; trunc=NoTruncation(), p::Real=2, alg=SDD())
84-
u, s, vt, err = tsvd!(adjoint(t); trunc=trunc, p=p, alg=alg)
85-
return adjoint(vt), adjoint(s), adjoint(u), err
86-
end
87-
8858
# DiagonalTensorMap
8959
# -----------------
9060
function leftorth!(d::DiagonalTensorMap; alg=QR(), kwargs...)

src/tensors/factorizations/matrixalgebrakit.jl

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -556,15 +556,8 @@ function initialize_output(::typeof(right_null!), t::AbstractTensorMap)
556556
return N
557557
end
558558

559-
for f! in (:left_null_svd!, :right_null_svd!)
560-
@eval function $f!(t::AbstractTensorMap, N, alg, ::Nothing=nothing)
561-
foreachblock(t, N) do _, (b, n)
562-
n′ = $f!(b, n, alg)
563-
# deal with the case where the output is not the same as the input
564-
n === n′ || copyto!(n, n′)
565-
return nothing
566-
end
567-
568-
return N
559+
for (f!, f_svd!) in zip((:left_null!, :right_null!), (:left_null_svd!, :right_null_svd!))
560+
@eval function $f_svd!(t::AbstractTensorMap, N, alg, ::Nothing=nothing)
561+
return $f!(t, N, alg)
569562
end
570563
end

0 commit comments

Comments
 (0)