Skip to content

Commit 0623799

Browse files
authored
Merge pull request #11 from Gertian/master
Expose algorithm kwargs for Grassmann methods
2 parents 8a36139 + 67f0c7f commit 0623799

File tree

2 files changed

+14
-10
lines changed

2 files changed

+14
-10
lines changed

src/TensorKitManifolds.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,9 @@ checkbase(x, y, z, args...) = checkbase(checkbase(x, y), z, args...)
2828
# the machine epsilon for the elements of an object X, name inspired from eltype
2929
scalareps(X) = eps(real(scalartype(X)))
3030

31+
# default SVD algorithm used in the algorithms
32+
default_svd_alg(::AbstractTensorMap) = TensorKit.SVD()
33+
3134
function isisometry(W::AbstractTensorMap; tol=10 * scalareps(W))
3235
WdW = W' * W
3336
s = zero(float(real(scalartype(W))))
@@ -61,7 +64,7 @@ end
6164

6265
struct PolarNewton <: TensorKit.OrthogonalFactorizationAlgorithm
6366
end
64-
function projectisometric!(W::AbstractTensorMap; alg=Polar())
67+
function projectisometric!(W::AbstractTensorMap; alg=default_svd_alg(W))
6568
if alg isa TensorKit.Polar || alg isa TensorKit.SDD
6669
foreach(blocks(W)) do (c, b)
6770
return _polarsdd!(b)
@@ -98,7 +101,7 @@ projecthermitian(W::AbstractTensorMap) = projecthermitian!(copy(W))
98101
projectantihermitian(W::AbstractTensorMap) = projectantihermitian!(copy(W))
99102

100103
function projectisometric(W::AbstractTensorMap;
101-
alg::TensorKit.OrthogonalFactorizationAlgorithm=Polar())
104+
alg=default_svd_alg(W))
102105
return projectisometric!(copy(W); alg=alg)
103106
end
104107
function projectcomplement(X::AbstractTensorMap, W::AbstractTensorMap,

src/grassmann.jl

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@ module Grassmann
77
using TensorKit
88
using TensorKit: similarstoragetype, SectorDict
99
using ..TensorKitManifolds: projecthermitian!, projectantihermitian!,
10-
projectisometric!, projectcomplement!, PolarNewton
10+
projectisometric!, projectcomplement!, PolarNewton,
11+
default_svd_alg
1112
import ..TensorKitManifolds: base, checkbase, inner, retract, transport, transport!
1213

1314
# special type to store tangent vectors using Z
@@ -56,7 +57,7 @@ function Base.getproperty(Δ::GrassmannTangent, sym::Symbol)
5657
elseif sym (:U, :S, :V)
5758
v = Base.getfield(Δ, sym)
5859
v !== nothing && return v
59-
U, S, V, = tsvd(Δ.Z)
60+
U, S, V, = tsvd(Δ.Z; alg=default_svd_alg(Δ.Z))
6061
Base.setfield!(Δ, :U, U)
6162
Base.setfield!(Δ, :S, S)
6263
Base.setfield!(Δ, :V, V)
@@ -191,15 +192,15 @@ for the isometries `U`, `V`, and `Y`, and the diagonal matrix `S`, and returning
191192
`Z = U * S * V` and `Y`.
192193
"""
193194
function invretract(Wold::AbstractTensorMap, Wnew::AbstractTensorMap; alg=nothing)
194-
space(Wold) == space(Wnew) || throw(SectorMismatch())
195+
space(Wold) == space(Wnew) || throw(SpaceMismatch())
195196
WodWn = Wold' * Wnew # V' * cos(S) * V * Y
196197
Wneworth = Wnew - Wold * WodWn
197-
Vd, cS, VY = tsvd!(WodWn)
198+
Vd, cS, VY = tsvd!(WodWn; alg=default_svd_alg(WodWn))
198199
Scmplx = acos(cS)
199200
# acos always returns a complex TensorMap. We cast back to real if possible.
200201
S = scalartype(WodWn) <: Real && isreal(sectortype(Scmplx)) ? real(Scmplx) : Scmplx
201202
UsS = Wneworth * VY' # U * sin(S) # should be in polar decomposition form
202-
U = projectisometric!(UsS; alg=Polar())
203+
U = projectisometric!(UsS)
203204
Y = Vd * VY
204205
V = Vd'
205206
Z = Grassmann.GrassmannTangent(Wold, U * S * V)
@@ -213,9 +214,9 @@ Return the unitary Y such that V*Y and W are "in the same Grassmann gauge" (tech
213214
from fibre bundles: in the same section), such that they can be related by a Grassmann
214215
retraction.
215216
"""
216-
function relativegauge(W::AbstractTensorMap, V::AbstractTensorMap)
217-
space(W) == space(V) || throw(SectorMismatch())
218-
return projectisometric!(V' * W; alg=Polar())
217+
function relativegauge(W::AbstractTensorMap, V::AbstractTensorMap; alg=nothing)
218+
space(W) == space(V) || throw(SpaceMismatch())
219+
return projectisometric!(V' * W)
219220
end
220221
221222
function transport!(Θ::GrassmannTangent, W::AbstractTensorMap, Δ::GrassmannTangent, α, W′;

0 commit comments

Comments
 (0)