|
1 | 1 | # AdjointTensorMap |
2 | 2 | # ---------------- |
| 3 | +# map algorithms to their adjoint counterpart |
| 4 | +# TODO: this probably belongs in MatrixAlgebraKit |
| 5 | +_adjoint(alg::LAPACK_HouseholderQR) = LAPACK_HouseholderLQ(; alg.positive, alg.blocksize) |
| 6 | +_adjoint(alg::LAPACK_HouseholderLQ) = LAPACK_HouseholderQR(; alg.positive, alg.blocksize) |
| 7 | +_adjoint(alg::LAPACK_HouseholderQL) = LAPACK_HouseholderRQ(; alg.positive, alg.blocksize) |
| 8 | +_adjoint(alg::LAPACK_HouseholderRQ) = LAPACK_HouseholderQL(; alg.positive, alg.blocksize) |
| 9 | +_adjoint(alg::PolarViaSVD) = PolarViaSVD(_adjoint(alg.svdalg)) |
| 10 | +_adjoint(alg::AbstractAlgorithm) = alg |
| 11 | + |
3 | 12 | # 1-arg functions |
4 | 13 | function initialize_output(::typeof(left_null!), t::AdjointTensorMap, |
5 | 14 | alg::AbstractAlgorithm) |
6 | | - return adjoint(initialize_output(right_null!, adjoint(t), alg)) |
| 15 | + return adjoint(initialize_output(right_null!, adjoint(t), _adjoint(alg))) |
7 | 16 | end |
8 | 17 | function initialize_output(::typeof(right_null!), t::AdjointTensorMap, |
9 | 18 | alg::AbstractAlgorithm) |
10 | | - return adjoint(initialize_output(left_null!, adjoint(t), alg)) |
| 19 | + return adjoint(initialize_output(left_null!, adjoint(t), _adjoint(alg))) |
11 | 20 | end |
12 | 21 |
|
13 | 22 | function left_null!(t::AdjointTensorMap, N::AdjointTensorMap, alg::AbstractAlgorithm) |
14 | | - right_null!(adjoint(t), adjoint(N), alg) |
| 23 | + right_null!(adjoint(t), adjoint(N), _adjoint(alg)) |
15 | 24 | return N |
16 | 25 | end |
17 | 26 | function right_null!(t::AdjointTensorMap, N::AdjointTensorMap, alg::AbstractAlgorithm) |
18 | | - left_null!(adjoint(t), adjoint(N), alg) |
| 27 | + left_null!(adjoint(t), adjoint(N), _adjoint(alg)) |
19 | 28 | return N |
20 | 29 | end |
21 | 30 |
|
|
29 | 38 | # 2-arg functions |
30 | 39 | for (left_f!, right_f!) in zip((:qr_full!, :qr_compact!, :left_polar!, :left_orth!), |
31 | 40 | (:lq_full!, :lq_compact!, :right_polar!, :right_orth!)) |
| 41 | + @eval function copy_input(::typeof($left_f!), t::AdjointTensorMap) |
| 42 | + return adjoint(copy_input($right_f!, adjoint(t))) |
| 43 | + end |
| 44 | + @eval function copy_input(::typeof($right_f!), t::AdjointTensorMap) |
| 45 | + return adjoint(copy_input($left_f!, adjoint(t))) |
| 46 | + end |
| 47 | + |
32 | 48 | @eval function initialize_output(::typeof($left_f!), t::AdjointTensorMap, |
33 | 49 | alg::AbstractAlgorithm) |
34 | | - return reverse(adjoint.(initialize_output($right_f!, adjoint(t), alg))) |
| 50 | + return reverse(adjoint.(initialize_output($right_f!, adjoint(t), _adjoint(alg)))) |
35 | 51 | end |
36 | 52 | @eval function initialize_output(::typeof($right_f!), t::AdjointTensorMap, |
37 | 53 | alg::AbstractAlgorithm) |
38 | | - return reverse(adjoint.(initialize_output($left_f!, adjoint(t), alg))) |
| 54 | + return reverse(adjoint.(initialize_output($left_f!, adjoint(t), _adjoint(alg)))) |
39 | 55 | end |
40 | 56 |
|
41 | 57 | @eval function $left_f!(t::AdjointTensorMap, |
42 | 58 | F::Tuple{AdjointTensorMap,AdjointTensorMap}, |
43 | 59 | alg::AbstractAlgorithm) |
44 | | - $right_f!(adjoint(t), reverse(adjoint.(F)), alg) |
| 60 | + $right_f!(adjoint(t), reverse(adjoint.(F)), _adjoint(alg)) |
45 | 61 | return F |
46 | 62 | end |
47 | 63 | @eval function $right_f!(t::AdjointTensorMap, |
48 | 64 | F::Tuple{AdjointTensorMap,AdjointTensorMap}, |
49 | 65 | alg::AbstractAlgorithm) |
50 | | - $left_f!(adjoint(t), reverse(adjoint.(F)), alg) |
| 66 | + $left_f!(adjoint(t), reverse(adjoint.(F)), _adjoint(alg)) |
51 | 67 | return F |
52 | 68 | end |
53 | 69 | end |
54 | 70 |
|
55 | 71 | # 3-arg functions |
56 | 72 | for f! in (:svd_full!, :svd_compact!, :svd_trunc!) |
| 73 | + @eval function copy_input(::typeof($f!), t::AdjointTensorMap) |
| 74 | + return adjoint(copy_input($f!, adjoint(t))) |
| 75 | + end |
| 76 | + |
57 | 77 | @eval function initialize_output(::typeof($f!), t::AdjointTensorMap, |
58 | 78 | alg::AbstractAlgorithm) |
59 | | - return reverse(adjoint.(initialize_output($f!, adjoint(t), alg))) |
| 79 | + return reverse(adjoint.(initialize_output($f!, adjoint(t), _adjoint(alg)))) |
60 | 80 | end |
61 | 81 | _TS = f! === :svd_full! ? :AdjointTensorMap : DiagonalTensorMap |
62 | 82 | @eval function $f!(t::AdjointTensorMap, |
63 | 83 | F::Tuple{AdjointTensorMap,$_TS,AdjointTensorMap}, |
64 | 84 | alg::AbstractAlgorithm) |
65 | | - $f!(adjoint(t), reverse(adjoint.(F)), alg) |
| 85 | + $f!(adjoint(t), reverse(adjoint.(F)), _adjoint(alg)) |
66 | 86 | return F |
67 | 87 | end |
68 | 88 | end |
|
0 commit comments