Skip to content

Commit f623bc3

Browse files
committed
improve adjoint support
1 parent 7d20fc3 commit f623bc3

File tree

2 files changed

+21
-4
lines changed

2 files changed

+21
-4
lines changed

src/factorizations/adjoint.jl

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,21 @@ _adjoint(alg::MAK.LAPACK_HouseholderRQ) = MAK.LAPACK_HouseholderQL(; alg.kwargs.
99
_adjoint(alg::MAK.PolarViaSVD) = MAK.PolarViaSVD(_adjoint(alg.svdalg))
1010
_adjoint(alg::AbstractAlgorithm) = alg
1111

12+
for f in
13+
[
14+
:svd_compact, :svd_full, :svd_trunc, :svd_vals, :qr_compact, :qr_full, :qr_null,
15+
:lq_compact, :lq_full, :lq_null, :eig_full, :eig_trunc, :eig_vals, :eigh_full,
16+
:eigh_trunc, :eigh_vals, :left_polar, :right_polar,
17+
:project_hermitian, :project_antihermitian, :project_isometric,
18+
]
19+
f! = Symbol(f, :!)
20+
# just return the algorithm for the parent type since we are mapping this with
21+
# `_adjoint` afterwards anyways.
22+
# TODO: properly handle these cases
23+
@eval MAK.default_algorithm(::typeof($f!), ::Type{T}; kwargs...) where {T <: AdjointTensorMap} =
24+
MAK.default_algorithm($f!, TensorKit.parenttype(T); kwargs...)
25+
end
26+
1227
# 1-arg functions
1328
function MAK.initialize_output(::typeof(left_null!), t::AdjointTensorMap, alg::AbstractAlgorithm)
1429
return adjoint(MAK.initialize_output(right_null!, adjoint(t), _adjoint(alg)))
@@ -38,8 +53,8 @@ end
3853

3954
# 2-arg functions
4055
for (left_f!, right_f!) in zip(
41-
(:qr_full!, :qr_compact!, :left_polar!, :left_orth!),
42-
(:lq_full!, :lq_compact!, :right_polar!, :right_orth!)
56+
(:qr_full!, :qr_compact!, :left_polar!),
57+
(:lq_full!, :lq_compact!, :right_polar!)
4358
)
4459
@eval function MAK.copy_input(::typeof($left_f!), t::AdjointTensorMap)
4560
return adjoint(MAK.copy_input($right_f!, adjoint(t)))
@@ -70,8 +85,8 @@ for (left_f!, right_f!) in zip(
7085
end
7186

7287
for (left_f, right_f) in zip(
73-
(:qr_full, :qr_compact, :left_polar, :left_orth),
74-
(:lq_full, :lq_compact, :right_polar, :right_orth)
88+
(:qr_full, :qr_compact, :left_polar),
89+
(:lq_full, :lq_compact, :right_polar)
7590
)
7691
@eval function MAK.$left_f(t::AdjointTensorMap; kwargs...)
7792
return reverse(adjoint.($right_f(adjoint(t); kwargs...)))

src/tensors/adjoint.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ struct AdjointTensorMap{T, S, N₁, N₂, TT <: AbstractTensorMap{T, S, N₂, N
1111
parent::TT
1212
end
1313
Base.parent(t::AdjointTensorMap) = t.parent
14+
parenttype(t::AdjointTensorMap) = parenttype(typeof(t))
15+
parenttype(::Type{AdjointTensorMap{T, S, N₁, N₂, TT}}) where {T, S, N₁, N₂, TT} = TT
1416

1517
# Constructor: construct from taking adjoint of a tensor
1618
Base.adjoint(t::AdjointTensorMap) = parent(t)

0 commit comments

Comments
 (0)