Skip to content

Commit 80aa0c0

Browse files
committed
switch from === to == for checking base point.
1 parent b402c2c commit 80aa0c0

File tree

3 files changed

+12
-12
lines changed

3 files changed

+12
-12
lines changed

src/grassmann.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ end
5656

5757
Base.getindex::GrassmannTangent) = Δ.Z
5858
base::GrassmannTangent) = Δ.W
59-
checkbase(Δ₁::GrassmannTangent, Δ₂::GrassmannTangent) = Δ₁.W === Δ₂.W ? Δ₁.W :
59+
checkbase(Δ₁::GrassmannTangent, Δ₂::GrassmannTangent) = Δ₁.W == Δ₂.W ? Δ₁.W :
6060
throw(ArgumentError("tangent vectors with different base points"))
6161

6262
function Base.getproperty::GrassmannTangent, sym::Symbol)
@@ -142,7 +142,7 @@ function inner(W::AbstractTensorMap, Δ₁::GrassmannTangent, Δ₂::GrassmannTa
142142
end
143143

144144
function retract(W::AbstractTensorMap, Δ::GrassmannTangent, α)
145-
W === base(Δ) || throw(ArgumentError("not a valid tangent vector at base point"))
145+
W == base(Δ) || throw(ArgumentError("not a valid tangent vector at base point"))
146146
U, S, V = Δ.U, Δ.S, Δ.V
147147
WVd = W*V'
148148
cS = cos*S)
@@ -158,7 +158,7 @@ function retract(W::AbstractTensorMap, Δ::GrassmannTangent, α)
158158
end
159159

160160
function transport!::GrassmannTangent, W::AbstractTensorMap, Δ::GrassmannTangent, α, W′)
161-
W === checkbase(Δ,Θ) || throw(ArgumentError("not a valid tangent vector at base point"))
161+
W == checkbase(Δ,Θ) || throw(ArgumentError("not a valid tangent vector at base point"))
162162
U, S, V = Δ.U, Δ.S, Δ.V
163163
WVd = W*V'
164164
cS = cos*S)

src/stiefel.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ end
6565

6666
Base.getindex::StiefelTangent) = Δ.W * Δ.A + Δ.Z
6767
base::StiefelTangent) = Δ.W
68-
checkbase(Δ₁::StiefelTangent, Δ₂::StiefelTangent) = Δ₁.W === Δ₂.W ? Δ₁.W :
68+
checkbase(Δ₁::StiefelTangent, Δ₂::StiefelTangent) = Δ₁.W == Δ₂.W ? Δ₁.W :
6969
throw(ArgumentError("tangent vectors with different base points"))
7070

7171
function Base.getproperty::StiefelTangent, sym::Symbol)
@@ -228,7 +228,7 @@ project_canonical(X, W) = project_canonical!(copy(X), W)
228228
# geodesic retraction for canonical metric using exponential
229229
# can be computed efficiently: O(np^2) + O(p^3)
230230
function retract_exp(W::AbstractTensorMap, Δ::StiefelTangent, α::Real)
231-
W === base(Δ) || throw(ArgumentError("not a valid tangent vector at base point"))
231+
W == base(Δ) || throw(ArgumentError("not a valid tangent vector at base point"))
232232
A, Z, U, S, V, A2 = Δ.A, Δ.Z, Δ.U, Δ.S, Δ.V, Δ.A2
233233
UU = catdomain(W*V', U)
234234
VV = catcodomain(V, zero(V))
@@ -248,7 +248,7 @@ end
248248
# can be computed efficiently: O(np^2) + O(p^3)
249249
function transport_exp!::StiefelTangent, W::AbstractTensorMap,
250250
Δ::StiefelTangent, α::Real, W′)
251-
W === checkbase(Δ,Θ) || throw(ArgumentError("not a valid tangent vector at base point"))
251+
W == checkbase(Δ,Θ) || throw(ArgumentError("not a valid tangent vector at base point"))
252252
U, S, V, A2 = Δ.U, Δ.S, Δ.V, Δ.A2
253253
UU = catdomain(W*V', U)
254254
P = catcodomain(zero(S), one(S))
@@ -264,7 +264,7 @@ transport_exp(Θ::StiefelTangent, W::AbstractTensorMap, Δ::StiefelTangent, α::
264264
# Cayley retraction, slightly more efficient than above?
265265
# can be computed efficiently: O(np^2) + O(p^3)
266266
function retract_cayley(W::AbstractTensorMap, Δ::StiefelTangent, α::Real)
267-
W === base(Δ) || throw(ArgumentError("not a valid tangent vector at base point"))
267+
W == base(Δ) || throw(ArgumentError("not a valid tangent vector at base point"))
268268
A, Z = Δ.A, Δ.Z
269269
ZdZ = Z'*Z
270270
X = axpy!^2/4, ZdZ, axpy!(-α/2, A, one(A)))
@@ -281,7 +281,7 @@ end
281281
# can be computed efficiently: O(np^2) + O(p^3)
282282
function transport_cayley!::StiefelTangent, W::AbstractTensorMap, Δ::StiefelTangent,
283283
α::Real, W′)
284-
W === checkbase(Δ,Θ) || throw(ArgumentError("not a valid tangent vector at base point"))
284+
W == checkbase(Δ,Θ) || throw(ArgumentError("not a valid tangent vector at base point"))
285285
A, Z = Δ.A, Δ.Z
286286
X = axpy!^2/4, Z'*Z, axpy!(-α/2, A, one(A)))
287287
A′ = Θ.A

src/unitary.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ end
2020
Base.copy::UnitaryTangent) = UnitaryTangent.W, copy.A))
2121
Base.getindex::UnitaryTangent) = Δ.W * Δ.A
2222
base::UnitaryTangent) = Δ.W
23-
checkbase(Δ₁::UnitaryTangent, Δ₂::UnitaryTangent) = Δ₁.W === Δ₂.W ? Δ₁.W :
23+
checkbase(Δ₁::UnitaryTangent, Δ₂::UnitaryTangent) = Δ₁.W == Δ₂.W ? Δ₁.W :
2424
throw(ArgumentError("tangent vectors with different base points"))
2525

2626
function Base.getproperty::UnitaryTangent, sym::Symbol)
@@ -84,7 +84,7 @@ project(X, W) = project!(copy(X), W)
8484

8585
# geodesic retraction, coincides with Stiefel retraction (which is not geodesic for p < n)
8686
function retract(W::AbstractTensorMap, Δ::UnitaryTangent, α)
87-
W === base(Δ) || throw(ArgumentError("not a valid tangent vector at base point"))
87+
W == base(Δ) || throw(ArgumentError("not a valid tangent vector at base point"))
8888
E = exp*Δ.A)
8989
# W′, = leftorth!(W*E; alg = QRpos()) # additional QRpos for stability
9090
W′ = W*E # no additional QRpos as this changes space
@@ -114,7 +114,7 @@ end
114114
function transport_parallel!::UnitaryTangent,
115115
W::AbstractTensorMap,
116116
Δ::UnitaryTangent, α, W′)
117-
W === checkbase(Δ,Θ) || throw(ArgumentError("not a valid tangent vector at base point"))
117+
W == checkbase(Δ,Θ) || throw(ArgumentError("not a valid tangent vector at base point"))
118118
E = exp((α/2)*Δ.A)
119119
A′ = projectantihermitian!(E'*Θ.A*E) # exra projection for stability
120120
return UnitaryTangent(W′, A′)
@@ -125,7 +125,7 @@ transport_parallel(Θ::UnitaryTangent, W::AbstractTensorMap, Δ::UnitaryTangent,
125125
function transport_stiefel!::UnitaryTangent,
126126
W::AbstractTensorMap,
127127
Δ::UnitaryTangent, α, W′)
128-
W === checkbase(Δ,Θ) || throw(ArgumentError("not a valid tangent vector at base point"))
128+
W == checkbase(Δ,Θ) || throw(ArgumentError("not a valid tangent vector at base point"))
129129
A′ = Θ.A
130130
return UnitaryTangent(W′, A′)
131131
end

0 commit comments

Comments
 (0)