Skip to content

Commit e5ee802

Browse files
committed
fix AD tests
1 parent 19cce9d commit e5ee802

File tree

3 files changed

+24
-15
lines changed

3 files changed

+24
-15
lines changed

ext/TensorKitChainRulesCoreExt/factorizations.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ function ChainRulesCore.rrule(::typeof(LinearAlgebra.eigvals!), t::AbstractTenso
6565
end
6666

6767
function ChainRulesCore.rrule(::typeof(leftorth!), t::AbstractTensorMap; alg=QRpos())
68-
alg isa TensorKit.QR || alg isa TensorKit.QRpos ||
68+
alg isa MatrixAlgebraKit.LAPACK_HouseholderQR ||
6969
error("only `alg=QR()` and `alg=QRpos()` are supported")
7070
QR = leftorth(t; alg)
7171
function leftorth!_pullback(ΔQR′)
@@ -85,7 +85,7 @@ function ChainRulesCore.rrule(::typeof(leftorth!), t::AbstractTensorMap; alg=QRp
8585
end
8686

8787
function ChainRulesCore.rrule(::typeof(rightorth!), t::AbstractTensorMap; alg=LQpos())
88-
alg isa TensorKit.LQ || alg isa TensorKit.LQpos ||
88+
alg isa MatrixAlgebraKit.LAPACK_HouseholderLQ ||
8989
error("only `alg=LQ()` and `alg=LQpos()` are supported")
9090
LQ = rightorth(t; alg)
9191
function rightorth!_pullback(ΔLQ′)

src/TensorKit.jl

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -82,9 +82,6 @@ export left_orth, right_orth, left_null, right_null,
8282
eigh_vals!, eigh_vals, eig_vals!, eig_vals,
8383
isposdef, isposdef!, ishermitian, isisometry, isunitary, sylvester, rank, cond
8484

85-
# deprecate:
86-
export eig, eig!, eigh, eigh!, eigen, eigen!, tsvd, tsvd!, leftorth, leftorth!, rightorth,
87-
rightorth!, leftnull, leftnull!, rightnull, rightnull!
8885
export braid, braid!, permute, permute!, transpose, transpose!, twist, twist!, repartition,
8986
repartition!
9087
export catdomain, catcodomain, absorb, absorb!

src/tensors/factorizations/matrixalgebrakit.jl

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -611,11 +611,17 @@ function left_orth!(t::AbstractTensorMap;
611611
trunc == notrunc() || kind === :svd ||
612612
throw(ArgumentError("truncation not supported for left_orth with kind = $kind"))
613613

614-
kind === :qr && return qr_compact!(t; alg_qr...)
615-
kind === :polar && return left_orth_polar!(t; alg_polar...)
616-
kind === :svd && return left_orth_svd!(t; trunc, alg_svd...)
617-
618-
throw(ArgumentError(lazy"`left_orth!` received unknown value `kind = $kind`"))
614+
return if kind === :qr
615+
alg_qr isa NamedTuple ? qr_compact!(t; alg_qr...) : qr_compact!(t; alg=alg_qr)
616+
elseif kind === :polar
617+
alg_polar isa NamedTuple ? left_orth_polar!(t; alg_polar...) :
618+
left_orth_polar!(t; alg=alg_polar)
619+
elseif kind === :svd
620+
alg_svd isa NamedTuple ? left_orth_svd!(t; trunc, alg_svd...) :
621+
left_orth_svd!(t; trunc, alg=alg_svd)
622+
else
623+
throw(ArgumentError(lazy"`left_orth!` received unknown value `kind = $kind`"))
624+
end
619625
end
620626
function right_orth!(t::AbstractTensorMap;
621627
trunc::TruncationStrategy=notrunc(),
@@ -624,11 +630,17 @@ function right_orth!(t::AbstractTensorMap;
624630
trunc == notrunc() || kind === :svd ||
625631
throw(ArgumentError("truncation not supported for right_orth with kind = $kind"))
626632

627-
kind === :lq && return lq_compact!(t; alg_lq...)
628-
kind === :polar && return right_orth_polar!(t; alg_polar...)
629-
kind === :svd && return right_orth_svd!(t; trunc, alg_svd...)
630-
631-
throw(ArgumentError(lazy"`right_orth!` received unknown value `kind = $kind`"))
633+
return if kind === :lq
634+
alg_lq isa NamedTuple ? lq_compact!(t; alg_lq...) : lq_compact!(t; alg=alg_lq)
635+
elseif kind === :polar
636+
alg_polar isa NamedTuple ? right_orth_polar!(t; alg_polar...) :
637+
right_orth_polar!(t; alg=alg_polar)
638+
elseif kind === :svd
639+
alg_svd isa NamedTuple ? right_orth_svd!(t; trunc, alg_svd...) :
640+
right_orth_svd!(t; trunc, alg=alg_svd)
641+
else
642+
throw(ArgumentError(lazy"`right_orth!` received unknown value `kind = $kind`"))
643+
end
632644
end
633645

634646
function left_orth_polar!(t::AbstractTensorMap; alg=nothing, kwargs...)

0 commit comments

Comments
 (0)