Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/Defaults.jl
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ const svd_rrule_tol = ctmrg_tol
const svd_rrule_min_krylovdim = 48
const svd_rrule_verbosity = -1
const svd_rrule_alg = :full # ∈ {:full, :gmres, :bicgstab, :arnoldi}
const svd_rrule_broadening = 1.0e-13
const svd_rrule_broadening = 1.0e-10
const krylovdim_factor = 1.4

# Projectors
Expand Down
12 changes: 7 additions & 5 deletions src/utility/svd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -346,9 +346,11 @@ function ChainRulesCore.rrule(
U, S, V⁺ = info.U_full, info.S_full, info.V_full # untruncated SVD decomposition

smallest_sval = minimum(((_, b),) -> minimum(diag(b)), blocks(S̃))
tol_scale = norm(S, Inf)
pullback_tol = clamp(
smallest_sval, eps(scalartype(S̃))^(3 / 4), eps(scalartype(S̃))^(1 / 2)
1.0e-2 * smallest_sval, tol_scale * eps(scalartype(S̃))^(3 / 4), tol_scale * eps(scalartype(S̃))^(1 / 2)
)
broadening = tol_scale * alg.rrule_alg.broadening

function svd_trunc!_full_pullback(ΔUSVi)
ΔU, ΔS, ΔV⁺, = unthunk.(ΔUSVi)
Expand All @@ -367,7 +369,7 @@ function ChainRulesCore.rrule(
ΔSdc,
ΔV⁺c;
tol = pullback_tol,
broadening = alg.rrule_alg.broadening,
broadening = broadening,
verbosity = alg.rrule_alg.verbosity,
)
end
Expand All @@ -392,7 +394,7 @@ function ChainRulesCore.rrule(
# update rrule_alg tolerance to be compatible with smallest singular value
rrule_alg = alg.rrule_alg
smallest_sval = minimum(((_, b),) -> minimum(diag(b)), blocks(S))
proper_tol = clamp(rrule_alg.tol, eps(scalartype(S))^(3 / 4), 1.0e-2 * smallest_sval)
proper_tol = clamp(rrule_alg.tol, _default_pullback_gaugetol(S), 1.0e-2 * smallest_sval)
rrule_alg = @set rrule_alg.tol = proper_tol

function svd_trunc!_itersvd_pullback(ΔUSVi)
Expand Down Expand Up @@ -470,12 +472,12 @@ end
# Lorentzian broadening for divergent term in SVD rrule, see
# https://journals.aps.org/prresearch/abstract/10.1103/PhysRevResearch.7.013237
function _lorentz_broaden(x, ε = eps(real(scalartype(x)))^(3 / 4))
return x / (x^2 + ε)
return x / (x^2 + ε^2)
end

function _default_pullback_gaugetol(x)
n = norm(x, Inf)
return eps(eltype(n))^(3 / 4) * max(n, one(n))
return eps(eltype(n))^(3 / 4) * n
end

# SVD_pullback: pullback implementation for general (possibly truncated) SVD
Expand Down
4 changes: 2 additions & 2 deletions test/utility/svd_wrapper.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ end
r_degen = u * s * v

no_broadening_no_cutoff_alg = @set full_alg.rrule_alg.broadening = 1.0e-30
small_broadening_alg = @set full_alg.rrule_alg.broadening = 1.0e-13
small_broadening_alg = @set full_alg.rrule_alg.broadening = 1.0e-10

l_only_cutoff, g_only_cutoff = withgradient(
A -> lossfun(A, full_alg, R, trunc), r_degen
Expand Down Expand Up @@ -101,7 +101,7 @@ end
symm_r_degen = u * s * v

no_broadening_no_cutoff_alg = @set full_alg.rrule_alg.broadening = 1.0e-30
small_broadening_alg = @set full_alg.rrule_alg.broadening = 1.0e-13
small_broadening_alg = @set full_alg.rrule_alg.broadening = 1.0e-10

l_only_cutoff, g_only_cutoff = withgradient(
A -> lossfun(A, full_alg, symm_R, symm_trspace), symm_r_degen
Expand Down