Skip to content

Commit a012324

Browse files
authored
Merge pull request #169 from Jutho/jh/tensorstructure
Small fixes
2 parents cddfaa6 + c2a44bd commit a012324

File tree

3 files changed

+38
-38
lines changed

3 files changed

+38
-38
lines changed

ext/TensorKitChainRulesCoreExt/factorizations.jl

Lines changed: 11 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -184,8 +184,7 @@ end
184184
#
185185
function svd_pullback!(ΔA::AbstractMatrix, U::AbstractMatrix, S::AbstractVector,
186186
Vd::AbstractMatrix, ΔU, ΔS, ΔVd;
187-
atol::Real=0,
188-
rtol::Real=atol > 0 ? 0 : eps(eltype(S))^(3 / 4))
187+
tol::Real=default_pullback_gaugetol(S))
189188

190189
# Basic size checks and determination
191190
m, n = size(U, 1), size(Vd, 2)
@@ -214,8 +213,7 @@ function svd_pullback!(ΔA::AbstractMatrix, U::AbstractMatrix, S::AbstractVector
214213
Vp = view(Vd, 1:p, :)'
215214
Sp = view(S, 1:p)
216215

217-
# tolerance and rank
218-
tol = atol > 0 ? atol : rtol * S[1, 1]
216+
# rank
219217
r = findlast(>=(tol), S)
220218

221219
# compute antihermitian part of projection of ΔU and ΔV onto U and V
@@ -302,16 +300,12 @@ function svd_pullback!(ΔA::AbstractMatrix, U::AbstractMatrix, S::AbstractVector
302300
end
303301

304302
function eig_pullback!(ΔA::AbstractMatrix, D::AbstractVector, V::AbstractMatrix, ΔD, ΔV;
305-
atol::Real=0,
306-
rtol::Real=atol > 0 ? 0 : eps(real(eltype(D)))^(3 / 4))
303+
tol::Real=default_pullback_gaugetol(D))
307304

308305
# Basic size checks and determination
309306
n = LinearAlgebra.checksquare(V)
310307
n == length(D) || throw(DimensionMismatch())
311308

312-
# tolerance and rank
313-
tol = atol > 0 ? atol : rtol * maximum(abs, D)
314-
315309
if !(ΔV isa AbstractZero)
316310
VdΔV = V' * ΔV
317311

@@ -345,16 +339,12 @@ function eig_pullback!(ΔA::AbstractMatrix, D::AbstractVector, V::AbstractMatrix
345339
end
346340

347341
function eigh_pullback!(ΔA::AbstractMatrix, D::AbstractVector, V::AbstractMatrix, ΔD, ΔV;
348-
atol::Real=0,
349-
rtol::Real=atol > 0 ? 0 : eps(real(eltype(D)))^(3 / 4))
342+
tol::Real=default_pullback_gaugetol(D))
350343

351344
# Basic size checks and determination
352345
n = LinearAlgebra.checksquare(V)
353346
n == length(D) || throw(DimensionMismatch())
354347

355-
# tolerance and rank
356-
tol = atol > 0 ? atol : rtol * maximum(abs, D)
357-
358348
if !(ΔV isa AbstractZero)
359349
VdΔV = V' * ΔV
360350
aVdΔV = rmul!(VdΔV - VdΔV', 1 / 2)
@@ -379,10 +369,8 @@ function eigh_pullback!(ΔA::AbstractMatrix, D::AbstractVector, V::AbstractMatri
379369
end
380370

381371
function qr_pullback!(ΔA::AbstractMatrix, Q::AbstractMatrix, R::AbstractMatrix, ΔQ, ΔR;
382-
atol::Real=0,
383-
rtol::Real=atol > 0 ? 0 : eps(real(eltype(R)))^(3 / 4))
372+
tol::Real=default_pullback_gaugetol(R))
384373
Rd = view(R, diagind(R))
385-
tol = atol > 0 ? atol : rtol * maximum(abs, Rd)
386374
p = findlast(>=(tol) abs, Rd)
387375
m, n = size(R)
388376

@@ -432,10 +420,8 @@ function qr_pullback!(ΔA::AbstractMatrix, Q::AbstractMatrix, R::AbstractMatrix,
432420
end
433421

434422
function lq_pullback!(ΔA::AbstractMatrix, L::AbstractMatrix, Q::AbstractMatrix, ΔL, ΔQ;
435-
atol::Real=0,
436-
rtol::Real=atol > 0 ? 0 : eps(real(eltype(L)))^(3 / 4))
423+
tol::Real=default_pullback_gaugetol(L))
437424
Ld = view(L, diagind(L))
438-
tol = atol > 0 ? atol : rtol * maximum(abs, Ld)
439425
p = findlast(>=(tol) abs, Ld)
440426
m, n = size(L)
441427

@@ -483,3 +469,8 @@ function lq_pullback!(ΔA::AbstractMatrix, L::AbstractMatrix, Q::AbstractMatrix,
483469
ldiv!(LowerTriangular(L11)', ΔA1)
484470
return ΔA
485471
end
472+
473+
function default_pullback_gaugetol(a)
474+
n = norm(a, Inf)
475+
return eps(eltype(n))^(3 / 4) * max(n, one(n))
476+
end

src/tensors/tensoroperations.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@ function TO.tensoralloc(::Type{TT},
1313
A = storagetype(TT)
1414
dim = fusionblockstructure(structure).totaldim
1515
data = TO.tensoralloc(A, dim, istemp, allocator)
16-
return TT(data, structure)
16+
# return TT(data, structure)
17+
return TensorMap{T}(data, structure)
1718
end
1819

1920
function TO.tensorfree!(t::TensorMap, allocator=TO.DefaultAllocator())

test/planar.jl

Lines changed: 25 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -132,9 +132,14 @@ end
132132
GL′ = force_planar(GL)
133133
GR′ = force_planar(GR)
134134

135-
@tensor y[-1 -2; -3] := GL[-1 2; 1] * x[1 3; 4] * O[2 -2; 3 5] * GR[4 5; -3]
136-
@planar y′[-1 -2; -3] := GL′[-1 2; 1] * x′[1 3; 4] * O′[2 -2; 3 5] * GR′[4 5; -3]
137-
@test force_planar(y) y′
135+
for alloc in
136+
(TensorOperations.DefaultAllocator(), TensorOperations.ManualAllocator())
137+
@tensor allocator = alloc y[-1 -2; -3] := GL[-1 2; 1] * x[1 3; 4] *
138+
O[2 -2; 3 5] * GR[4 5; -3]
139+
@planar allocator = alloc y′[-1 -2; -3] := GL′[-1 2; 1] * x′[1 3; 4] *
140+
O′[2 -2; 3 5] * GR′[4 5; -3]
141+
@test force_planar(y) y′
142+
end
138143

139144
# ∂AC2
140145
# -------
@@ -193,21 +198,24 @@ end
193198
ρ′ = force_planar(ρ)
194199
h′ = force_planar(h)
195200

196-
@tensor begin
197-
C = (((((((h[9 3 4; 5 1 2] * u[1 2; 7 12]) * conj(u[3 4; 11 13])) *
198-
(u[8 5; 15 6] * w[6 7; 19])) *
199-
(conj(u[8 9; 17 10]) * conj(w[10 11; 22]))) *
200-
((w[12 14; 20] * conj(w[13 14; 23])) * ρ[18 19 20; 21 22 23])) *
201-
w[16 15; 18]) * conj(w[16 17; 21]))
202-
end
203-
@planar begin
204-
C′ = (((((((h′[9 3 4; 5 1 2] * u′[1 2; 7 12]) * conj(u′[3 4; 11 13])) *
205-
(u′[8 5; 15 6] * w′[6 7; 19])) *
206-
(conj(u′[8 9; 17 10]) * conj(w′[10 11; 22]))) *
207-
((w′[12 14; 20] * conj(w′[13 14; 23])) * ρ′[18 19 20; 21 22 23])) *
208-
w′[16 15; 18]) * conj(w′[16 17; 21]))
201+
for alloc in
202+
(TensorOperations.DefaultAllocator(), TensorOperations.ManualAllocator())
203+
@tensor allocator = alloc begin
204+
C = (((((((h[9 3 4; 5 1 2] * u[1 2; 7 12]) * conj(u[3 4; 11 13])) *
205+
(u[8 5; 15 6] * w[6 7; 19])) *
206+
(conj(u[8 9; 17 10]) * conj(w[10 11; 22]))) *
207+
((w[12 14; 20] * conj(w[13 14; 23])) * ρ[18 19 20; 21 22 23])) *
208+
w[16 15; 18]) * conj(w[16 17; 21]))
209+
end
210+
@planar allocator = alloc begin
211+
C′ = (((((((h′[9 3 4; 5 1 2] * u′[1 2; 7 12]) * conj(u′[3 4; 11 13])) *
212+
(u′[8 5; 15 6] * w′[6 7; 19])) *
213+
(conj(u′[8 9; 17 10]) * conj(w′[10 11; 22]))) *
214+
((w′[12 14; 20] * conj(w′[13 14; 23])) * ρ′[18 19 20; 21 22 23])) *
215+
w′[16 15; 18]) * conj(w′[16 17; 21]))
216+
end
217+
@test C C′
209218
end
210-
@test C C′
211219
end
212220

213221
@testset "Issue 93" begin

0 commit comments

Comments
 (0)