Skip to content

Commit 8a97299

Browse files
committed
retain AdjointTensorMap in factorizations
1 parent e493d51 commit 8a97299

File tree

1 file changed

+30
-10
lines changed

1 file changed

+30
-10
lines changed

src/tensors/factorizations/adjoint.jl

Lines changed: 30 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,30 @@
11
# AdjointTensorMap
22
# ----------------
3+
# map algorithms to their adjoint counterpart
4+
# TODO: this probably belongs in MatrixAlgebraKit
5+
_adjoint(alg::LAPACK_HouseholderQR) = LAPACK_HouseholderLQ(; alg.positive, alg.blocksize)
6+
_adjoint(alg::LAPACK_HouseholderLQ) = LAPACK_HouseholderQR(; alg.positive, alg.blocksize)
7+
_adjoint(alg::LAPACK_HouseholderQL) = LAPACK_HouseholderRQ(; alg.positive, alg.blocksize)
8+
_adjoint(alg::LAPACK_HouseholderRQ) = LAPACK_HouseholderQL(; alg.positive, alg.blocksize)
9+
_adjoint(alg::PolarViaSVD) = PolarViaSVD(_adjoint(alg.svdalg))
10+
_adjoint(alg::AbstractAlgorithm) = alg
11+
312
# 1-arg functions
413
function initialize_output(::typeof(left_null!), t::AdjointTensorMap,
514
alg::AbstractAlgorithm)
6-
return adjoint(initialize_output(right_null!, adjoint(t), alg))
15+
return adjoint(initialize_output(right_null!, adjoint(t), _adjoint(alg)))
716
end
817
function initialize_output(::typeof(right_null!), t::AdjointTensorMap,
918
alg::AbstractAlgorithm)
10-
return adjoint(initialize_output(left_null!, adjoint(t), alg))
19+
return adjoint(initialize_output(left_null!, adjoint(t), _adjoint(alg)))
1120
end
1221

1322
function left_null!(t::AdjointTensorMap, N::AdjointTensorMap, alg::AbstractAlgorithm)
14-
right_null!(adjoint(t), adjoint(N), alg)
23+
right_null!(adjoint(t), adjoint(N), _adjoint(alg))
1524
return N
1625
end
1726
function right_null!(t::AdjointTensorMap, N::AdjointTensorMap, alg::AbstractAlgorithm)
18-
left_null!(adjoint(t), adjoint(N), alg)
27+
left_null!(adjoint(t), adjoint(N), _adjoint(alg))
1928
return N
2029
end
2130

@@ -29,40 +38,51 @@ end
2938
# 2-arg functions
3039
for (left_f!, right_f!) in zip((:qr_full!, :qr_compact!, :left_polar!, :left_orth!),
3140
(:lq_full!, :lq_compact!, :right_polar!, :right_orth!))
41+
@eval function copy_input(::typeof($left_f!), t::AdjointTensorMap)
42+
return adjoint(copy_input($right_f!, adjoint(t)))
43+
end
44+
@eval function copy_input(::typeof($right_f!), t::AdjointTensorMap)
45+
return adjoint(copy_input($left_f!, adjoint(t)))
46+
end
47+
3248
@eval function initialize_output(::typeof($left_f!), t::AdjointTensorMap,
3349
alg::AbstractAlgorithm)
34-
return reverse(adjoint.(initialize_output($right_f!, adjoint(t), alg)))
50+
return reverse(adjoint.(initialize_output($right_f!, adjoint(t), _adjoint(alg))))
3551
end
3652
@eval function initialize_output(::typeof($right_f!), t::AdjointTensorMap,
3753
alg::AbstractAlgorithm)
38-
return reverse(adjoint.(initialize_output($left_f!, adjoint(t), alg)))
54+
return reverse(adjoint.(initialize_output($left_f!, adjoint(t), _adjoint(alg))))
3955
end
4056

4157
@eval function $left_f!(t::AdjointTensorMap,
4258
F::Tuple{AdjointTensorMap,AdjointTensorMap},
4359
alg::AbstractAlgorithm)
44-
$right_f!(adjoint(t), reverse(adjoint.(F)), alg)
60+
$right_f!(adjoint(t), reverse(adjoint.(F)), _adjoint(alg))
4561
return F
4662
end
4763
@eval function $right_f!(t::AdjointTensorMap,
4864
F::Tuple{AdjointTensorMap,AdjointTensorMap},
4965
alg::AbstractAlgorithm)
50-
$left_f!(adjoint(t), reverse(adjoint.(F)), alg)
66+
$left_f!(adjoint(t), reverse(adjoint.(F)), _adjoint(alg))
5167
return F
5268
end
5369
end
5470

5571
# 3-arg functions
5672
for f! in (:svd_full!, :svd_compact!, :svd_trunc!)
73+
@eval function copy_input(::typeof($f!), t::AdjointTensorMap)
74+
return adjoint(copy_input($f!, adjoint(t)))
75+
end
76+
5777
@eval function initialize_output(::typeof($f!), t::AdjointTensorMap,
5878
alg::AbstractAlgorithm)
59-
return reverse(adjoint.(initialize_output($f!, adjoint(t), alg)))
79+
return reverse(adjoint.(initialize_output($f!, adjoint(t), _adjoint(alg))))
6080
end
6181
_TS = f! === :svd_full! ? :AdjointTensorMap : DiagonalTensorMap
6282
@eval function $f!(t::AdjointTensorMap,
6383
F::Tuple{AdjointTensorMap,$_TS,AdjointTensorMap},
6484
alg::AbstractAlgorithm)
65-
$f!(adjoint(t), reverse(adjoint.(F)), alg)
85+
$f!(adjoint(t), reverse(adjoint.(F)), _adjoint(alg))
6686
return F
6787
end
6888
end

0 commit comments

Comments
 (0)