Skip to content

Commit 0238f05

Browse files
authored
Adapt to TensorKit v0.15 (#17)
* Adapt to v0.15 * format and bugfix * remove stray `tsvd`
1 parent fdc9112 commit 0238f05

File tree

6 files changed

+57
-85
lines changed

6 files changed

+57
-85
lines changed

Project.toml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,16 @@
11
name = "TensorKitManifolds"
22
uuid = "11fa318c-39cb-4a83-b1ed-cdc7ba1e3684"
33
authors = ["Jutho Haegeman <[email protected]>", "Markus Hauru <[email protected]>"]
4-
version = "0.7.2"
4+
version = "0.7.3"
55

66
[deps]
77
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
8+
MatrixAlgebraKit = "6c742aac-3347-4629-af66-fc926824e5e4"
89
TensorKit = "07d1fe3e-3e46-537d-9eac-e9e13d0d4cec"
910

1011
[compat]
11-
TensorKit = "0.13,0.14"
12+
MatrixAlgebraKit = "0.5.0"
13+
TensorKit = "0.15"
1214
julia = "1.10"
1315

1416
[extras]

src/TensorKitManifolds.jl

Lines changed: 15 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@ export Grassmann, Stiefel, Unitary
77
export inner, retract, transport, transport!
88

99
using TensorKit
10+
using MatrixAlgebraKit: MatrixAlgebraKit, AbstractAlgorithm, Algorithm, PolarViaSVD,
11+
LAPACK_DivideAndConquer, diagview
12+
import MatrixAlgebraKit as MAK
1013

1114
# Every submodule -- Grassmann, Stiefel, and Unitary -- implements their own methods for
1215
# these. The signatures should be
@@ -28,23 +31,8 @@ checkbase(x, y, z, args...) = checkbase(checkbase(x, y), z, args...)
2831
# the machine epsilon for the elements of an object X, name inspired from eltype
2932
scalareps(X) = eps(real(scalartype(X)))
3033

31-
# default SVD algorithm used in the algorithms
32-
default_svd_alg(::AbstractTensorMap) = TensorKit.SVD()
33-
34-
function isisometry(W::AbstractTensorMap; tol=10 * scalareps(W))
35-
WdW = W' * W
36-
s = zero(float(real(scalartype(W))))
37-
for (c, b) in blocks(WdW)
38-
_subtractone!(b)
39-
s += dim(c) * length(b)
40-
end
41-
return norm(WdW) <= tol * sqrt(s)
42-
end
43-
44-
function isunitary(W::AbstractTensorMap; tol=10 * scalareps(W))
45-
return isisometry(W; tol=tol) && isisometry(W'; tol=tol)
46-
end
47-
34+
# TODO: these functions should be replaced by MAK functions
35+
projecthermitian(W::AbstractTensorMap) = projecthermitian!(copy(W))
4836
function projecthermitian!(W::AbstractTensorMap)
4937
codomain(W) == domain(W) ||
5038
throw(DomainError("Tensor with distinct domain and codomain cannot be hermitian."))
@@ -53,6 +41,8 @@ function projecthermitian!(W::AbstractTensorMap)
5341
end
5442
return W
5543
end
44+
45+
projectantihermitian(W::AbstractTensorMap) = projectantihermitian!(copy(W))
5646
function projectantihermitian!(W::AbstractTensorMap)
5747
codomain(W) == domain(W) ||
5848
throw(DomainError("Tensor with distinct domain and codomain cannot be anithermitian."))
@@ -62,27 +52,18 @@ function projectantihermitian!(W::AbstractTensorMap)
6252
return W
6353
end
6454

65-
struct PolarNewton <: TensorKit.OrthogonalFactorizationAlgorithm
66-
end
67-
function projectisometric!(W::AbstractTensorMap; alg=default_svd_alg(W))
68-
if alg isa TensorKit.Polar || alg isa TensorKit.SDD
69-
foreach(blocks(W)) do (c, b)
70-
return _polarsdd!(b)
71-
end
72-
elseif alg isa TensorKit.SVD
73-
foreach(blocks(W)) do (c, b)
74-
return _polarsvd!(b)
75-
end
76-
elseif alg isa PolarNewton
77-
foreach(blocks(W)) do (c, b)
78-
return _polarnewton!(b)
79-
end
80-
else
81-
throw(ArgumentError("unkown algorithm for projectisometric!: alg = $alg"))
55+
projectisometric(W::AbstractTensorMap; kwargs...) = projectisometric!(copy(W); kwargs...)
56+
function projectisometric!(W::AbstractTensorMap;
57+
alg::AbstractAlgorithm=MAK.select_algorithm(left_polar!, W))
58+
TensorKit.foreachblock(W) do c, (b,)
59+
return _left_polar!(b, alg)
8260
end
8361
return W
8462
end
8563

64+
function projectcomplement(X::AbstractTensorMap, W::AbstractTensorMap, kwargs...)
65+
return projectcomplement!(copy(X), W; kwargs...)
66+
end
8667
function projectcomplement!(X::AbstractTensorMap, W::AbstractTensorMap;
8768
tol=10 * scalareps(X))
8869
P = W' * X
@@ -97,18 +78,6 @@ function projectcomplement!(X::AbstractTensorMap, W::AbstractTensorMap;
9778
return X
9879
end
9980

100-
projecthermitian(W::AbstractTensorMap) = projecthermitian!(copy(W))
101-
projectantihermitian(W::AbstractTensorMap) = projectantihermitian!(copy(W))
102-
103-
function projectisometric(W::AbstractTensorMap;
104-
alg=default_svd_alg(W))
105-
return projectisometric!(copy(W); alg=alg)
106-
end
107-
function projectcomplement(X::AbstractTensorMap, W::AbstractTensorMap,
108-
tol=10 * scalareps(X))
109-
return projectcomplement!(copy(X), W; tol=tol)
110-
end
111-
11281
include("auxiliary.jl")
11382
include("grassmann.jl")
11483
include("stiefel.jl")

src/auxiliary.jl

Lines changed: 26 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -70,37 +70,43 @@ function _subtractone!(a::AbstractMatrix)
7070
view(a, diagind(a)) .= view(a, diagind(a)) .- 1
7171
return a
7272
end
73-
function _polarsdd!(A::StridedMatrix)
74-
U, S, V = svd!(A; alg=LinearAlgebra.DivideAndConquer())
75-
return mul!(A, U, V')
76-
end
77-
function _polarsvd!(A::StridedMatrix)
78-
U, S, V = svd!(A; alg=LinearAlgebra.QRIteration())
79-
return mul!(A, U, V')
73+
74+
# TODO: _left_polar! is more or less the same as MAK.left_polar! but doesn't compute the P
75+
# which is not needed here. Can we unify this?
76+
function _left_polar!(A::StridedMatrix,
77+
alg::PolarViaSVD=PolarViaSVD(LAPACK_DivideAndConquer()))
78+
U, _, Vᴴ = svd_compact!(A, alg.svdalg)
79+
return mul!(A, U, Vᴴ)
8080
end
81+
82+
# TODO: can we move this to a dedicated MAK algorithm?
83+
MatrixAlgebraKit.@algdef PolarNewton
84+
85+
_left_polar!(A::StridedMatrix, alg::PolarNewton) = _polarnewton!(A; alg.kwargs...)
8186
function _polarnewton!(A::StridedMatrix; tol=10 * scalareps(A), maxiter=5)
8287
m, n = size(A)
8388
@assert m >= n
8489
A2 = copy(A)
85-
Q, R = qr!(A2)
86-
Ri = ldiv!(UpperTriangular(R)', TensorKit.MatrixAlgebra.one!(similar(R)))
90+
Q, R = LinearAlgebra.qr!(A2)
91+
Ri = ldiv!(UpperTriangular(R)', MatrixAlgebraKit.one!(similar(R)))
8792
R, Ri = _avgdiff!(R, Ri)
8893
i = 1
8994
R2 = view(A, 1:n, 1:n)
9095
fill!(view(A, (n + 1):m, 1:n), zero(eltype(A)))
9196
copyto!(R2, R)
9297
while maximum(abs, Ri) > tol
9398
if i == maxiter # if not converged by now, fall back to sdd
94-
_polarsdd!(Ri)
99+
_left_polar!(Ri)
95100
break
96101
end
97-
Ri = ldiv!(lu!(R2)', TensorKit.MatrixAlgebra.one!(Ri))
102+
Ri = ldiv!(lu!(R2)', MatrixAlgebraKit.one!(Ri))
98103
R, Ri = _avgdiff!(R, Ri)
99104
copyto!(R2, R)
100105
i += 1
101106
end
102107
return lmul!(Q, A)
103108
end
109+
104110
# in place computation of the average and difference of two arrays
105111
function _avgdiff!(A::AbstractArray, B::AbstractArray)
106112
axes(A) == axes(B) || throw(DimensionMismatch())
@@ -124,7 +130,7 @@ end
124130
function _stiefelexp(W::StridedMatrix, A::StridedMatrix, Z::StridedMatrix, α)
125131
n, p = size(W)
126132
r = min(2 * p, n)
127-
QQ, _ = qr!([W Z])
133+
QQ, _ = LinearAlgebra.qr!([W Z])
128134
Q = similar(W, n, r - p)
129135
@inbounds for j in Base.OneTo(r - p)
130136
for i in Base.OneTo(n)
@@ -139,7 +145,7 @@ function _stiefelexp(W::StridedMatrix, A::StridedMatrix, Z::StridedMatrix, α)
139145
A2[1:p, (p + 1):end] .= (-α) .* (R')
140146
A2[(p + 1):end, (p + 1):end] .= 0
141147
U = [W Q] * exp(A2)
142-
U = _polarnewton!(U)
148+
U = _left_polar!(U, PolarNewton())
143149
W′ = U[:, 1:p]
144150
Q′ = U[:, (p + 1):end]
145151
R′ = R
@@ -152,7 +158,7 @@ function _stiefellog(Wold::StridedMatrix, Wnew::StridedMatrix;
152158
r = min(2 * p, n)
153159
P = Wold' * Wnew
154160
dW = Wnew - Wold * P
155-
QQ, _ = qr!([Wold dW])
161+
QQ, _ = LinearAlgebra.qr!([Wold dW])
156162
Q = similar(Wold, n, r - p)
157163
@inbounds for j in Base.OneTo(r - p)
158164
for i in Base.OneTo(n)
@@ -161,23 +167,17 @@ function _stiefellog(Wold::StridedMatrix, Wnew::StridedMatrix;
161167
end
162168
Q = lmul!(QQ, Q)
163169
R = Q' * dW
164-
Wext = [Wold Q]
165-
F = qr!([P; R])
166-
U = lmul!(F.Q, TensorKit.MatrixAlgebra.one!(similar(P, r, r)))
170+
F = LinearAlgebra.qr!([P; R])
171+
U = lmul!(F.Q, MatrixAlgebraKit.one!(similar(P, r, r)))
167172
U[1:p, 1:p] .= P
168173
U[(p + 1):r, 1:p] .= R
169174
X = view(U, 1:p, (p + 1):r)
170175
Y = view(U, (p + 1):r, (p + 1):r)
171176
if p < n
172-
YSVD = svd!(Y)
173-
mul!(X, X * (YSVD.V), (YSVD.U)')
174-
UsqrtS = YSVD.U
175-
@inbounds for j in 1:size(UsqrtS, 2)
176-
s = sqrt(YSVD.S[j])
177-
@simd for i in 1:size(UsqrtS, 1)
178-
UsqrtS[i, j] *= s
179-
end
180-
end
177+
USVᴴ = svd_compact!(Y)
178+
mul!(X, X * USVᴴ[3]', USVᴴ[1]')
179+
diagview(USVᴴ[2]) .= sqrt.(diagview(USVᴴ[2]))
180+
UsqrtS = rmul!(USVᴴ[1], USVᴴ[2])
181181
mul!(Y, UsqrtS, UsqrtS')
182182
end
183183
logU = _projectantihermitian!(log(U))

src/grassmann.jl

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,7 @@ module Grassmann
77
using TensorKit
88
using TensorKit: similarstoragetype, SectorDict
99
using ..TensorKitManifolds: projecthermitian!, projectantihermitian!,
10-
projectisometric!, projectcomplement!, PolarNewton,
11-
default_svd_alg
10+
projectisometric!, projectcomplement!, PolarNewton
1211
import ..TensorKitManifolds: base, checkbase, inner, retract, transport, transport!
1312

1413
# special type to store tangent vectors using Z
@@ -32,8 +31,8 @@ end
3231

3332
# output type of U, S, V in tsvd
3433
function _tsvd_types(Z::AbstractTensorMap)
35-
TUSV = Core.Compiler.return_type(tsvd, Tuple{typeof(Z)})
36-
TU, TS, TV, = TUSV.types
34+
TUSV = Core.Compiler.return_type(svd_compact, Tuple{typeof(Z)})
35+
TU, TS, TV = TUSV.types
3736
return TU, TS, TV
3837
end
3938

@@ -60,7 +59,7 @@ function Base.getproperty(Δ::GrassmannTangent, sym::Symbol)
6059
elseif sym (:U, :S, :V)
6160
v = Base.getfield(Δ, sym)
6261
v !== nothing && return v
63-
U, S, V, = tsvd.Z; alg=default_svd_alg.Z))
62+
U, S, V = svd_compact.Z)
6463
Base.setfield!(Δ, :U, U)
6564
Base.setfield!(Δ, :S, S)
6665
Base.setfield!(Δ, :V, V)
@@ -198,7 +197,7 @@ function invretract(Wold::AbstractTensorMap, Wnew::AbstractTensorMap; alg=nothin
198197
space(Wold) == space(Wnew) || throw(SpaceMismatch())
199198
WodWn = Wold' * Wnew # V' * cos(S) * V * Y
200199
Wneworth = Wnew - Wold * WodWn
201-
Vd, cS, VY = tsvd!(WodWn; alg=default_svd_alg(WodWn))
200+
Vd, cS, VY = svd_compact!(WodWn)
202201
Scmplx = acos(cS)
203202
# acos always returns a complex TensorMap. We cast back to real if possible.
204203
S = scalartype(WodWn) <: Real && isreal(sectortype(Scmplx)) ? real(Scmplx) : Scmplx

src/unitary.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ using TensorKit
77
import TensorKit: similarstoragetype, SectorDict
88
using ..TensorKitManifolds: projectantihermitian!, projectisometric!, PolarNewton
99
import ..TensorKitManifolds: base, checkbase, inner, retract, transport, transport!
10+
import MatrixAlgebraKit as MAK
1011

1112
struct UnitaryTangent{T<:AbstractTensorMap,TA<:AbstractTensorMap}
1213
W::T
@@ -82,10 +83,11 @@ end
8283
project(X, W; metric=:euclidean) = project!(copy(X), W; metric=:euclidean)
8384

8485
# geodesic retraction, coincides with Stiefel retraction (which is not geodesic for p < n)
85-
function retract(W::AbstractTensorMap, Δ::UnitaryTangent, α; alg=nothing)
86+
function retract(W::AbstractTensorMap, Δ::UnitaryTangent, α;
87+
alg=MAK.select_algorithm(left_polar!, W))
8688
W == base(Δ) || throw(ArgumentError("not a valid tangent vector at base point"))
8789
E = exp* Δ.A)
88-
W′ = projectisometric!(W * E; alg=SDD())
90+
W′ = projectisometric!(W * E; alg)
8991
A′ = Δ.A
9092
return W′, UnitaryTangent(W′, A′)
9193
end
@@ -104,7 +106,7 @@ function transport!(Θ::UnitaryTangent, W::AbstractTensorMap, Δ::UnitaryTangent
104106
end
105107
function transport::UnitaryTangent, W::AbstractTensorMap, Δ::UnitaryTangent, α::Real, W′;
106108
alg=:stiefel)
107-
return transport!(copy(Θ), W, Δ, α, W′; alg=alg)
109+
return transport!(copy(Θ), W, Δ, α, W′; alg)
108110
end
109111

110112
# transport_parallel correspondings to the torsion-free Levi-Civita connection

test/runtests.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ const α = 0.75
88

99
@testset "Grassmann with space $V" for V in spaces
1010
for T in (Float64,)
11-
W, = leftorth(randn(T, V * V * V, V * V); alg=Polar())
11+
W, = left_polar(randn(T, V * V * V, V * V))
1212
X = randn(T, space(W))
1313
Y = randn(T, space(W))
1414
Δ = @inferred Grassmann.project(X, W)
@@ -124,7 +124,7 @@ end
124124

125125
@testset "Unitary with space $V" for V in spaces
126126
for T in (Float64, ComplexF64)
127-
W, = leftorth(randn(T, V * V * V, V * V); alg=Polar())
127+
W, = left_polar(randn(T, V * V * V, V * V))
128128
X = randn(T, space(W))
129129
Y = randn(T, space(W))
130130
Δ = @inferred Unitary.project(X, W)

0 commit comments

Comments
 (0)