Skip to content

Commit ffffaa8

Browse files
committed
Update tests and fixes
Small fixes
1 parent 6feda9f commit ffffaa8

File tree

4 files changed

+17
-18
lines changed

4 files changed

+17
-18
lines changed

ext/TensorKitChainRulesCoreExt/TensorKitChainRulesCoreExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ import TensorOperations as TO
1212
using TensorOperations: promote_contract, tensoralloc_add, tensoralloc_contract
1313
using VectorInterface: promote_scale, promote_add
1414

15-
using MatrixAlgebraKit: TruncationStrategy,
15+
using MatrixAlgebraKit: MatrixAlgebraKit, TruncationStrategy,
1616
svd_compact_pullback!, eig_full_pullback!, eigh_full_pullback!
1717

1818
include("utility.jl")

ext/TensorKitChainRulesCoreExt/factorizations.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ function ChainRulesCore.rrule(::typeof(TensorKit.tsvd!), t::AbstractTensorMap;
55
kwargs...)
66
# TODO: I think we can use tsvd! here without issues because we don't actually require
77
# the data of `t` anymore.
8-
USVᴴ = tsvd(t; trunc=TensorKit.notrunc(), alg)
8+
USVᴴ = tsvd(t; trunc=TensorKit.notrunc(), kwargs...)
99

1010
if trunc != TensorKit.notrunc() && !isempty(blocksectors(t))
1111
USVᴴ′ = MatrixAlgebraKit.truncate!(svd_trunc!, USVᴴ, trunc)
@@ -16,7 +16,7 @@ function ChainRulesCore.rrule(::typeof(TensorKit.tsvd!), t::AbstractTensorMap;
1616
function tsvd!_pullback(ΔUSVᴴ′)
1717
ΔUSVᴴ = unthunk.(ΔUSVᴴ′)
1818
Δt = similar(t)
19-
foreachblock(Δt) do (c, b)
19+
foreachblock(Δt) do c, (b,)
2020
USVᴴc = block.(USVᴴ, Ref(c))
2121
ΔUSVᴴc = block.(ΔUSVᴴ, Ref(c))
2222
svd_compact_pullback!(b, USVᴴc, ΔUSVᴴc)
@@ -49,7 +49,7 @@ function ChainRulesCore.rrule(::typeof(TensorKit.eig!), t::AbstractTensorMap; kw
4949
function eig!_pullback(ΔDV′)
5050
ΔDV = unthunk.(ΔDV′)
5151
Δt = similar(t)
52-
foreachblock(Δt) do (c, b)
52+
foreachblock(Δt) do c, (b,)
5353
DVc = block.(DV, Ref(c))
5454
ΔDVc = block.(ΔDV, Ref(c))
5555
eig_full_pullback!(b, DVc, ΔDVc)
@@ -68,7 +68,7 @@ function ChainRulesCore.rrule(::typeof(TensorKit.eigh!), t::AbstractTensorMap; k
6868
function eigh!_pullback(ΔDV′)
6969
ΔDV = unthunk.(ΔDV′)
7070
Δt = similar(t)
71-
foreachblock(Δt) do (c, b)
71+
foreachblock(Δt) do c, (b,)
7272
DVc = block.(DV, Ref(c))
7373
ΔDVc = block.(ΔDV, Ref(c))
7474
eigh_full_pullback!(b, DVc, ΔDVc)

src/tensors/factorizations/factorizations.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,6 @@ include("matrixalgebrakit.jl")
4545
include("truncation.jl")
4646
include("deprecations.jl")
4747

48-
4948
function isisometry(t::AbstractTensorMap, (p₁, p₂)::Index2Tuple)
5049
t = permute(t, (p₁, p₂); copy=false)
5150
return isisometry(t)

test/ad.jl

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -398,7 +398,7 @@ Vlist = ((ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'),
398398
test_rrule(eigh′, H; atol, output_tangent=(ΔD, ΔU))
399399
end
400400

401-
let (U, S, V, ϵ) = tsvd(A)
401+
let (U, S, V) = tsvd(A)
402402
ΔU = randn(scalartype(U), space(U))
403403
ΔS = randn(scalartype(S), space(S))
404404
ΔV = randn(scalartype(V), space(V))
@@ -408,54 +408,54 @@ Vlist = ((ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'),
408408
mul!(block(ΔU, c), block(U, c), Diagonal(imag(diag(b))), -im, 1)
409409
end
410410
end
411-
test_rrule(tsvd, A; atol, output_tangent=(ΔU, ΔS, ΔV, 0.0))
411+
test_rrule(tsvd, A; atol, output_tangent=(ΔU, ΔS, ΔV))
412412

413413
allS = mapreduce(x -> diag(x[2]), vcat, blocks(S))
414414
truncval = (maximum(allS) + minimum(allS)) / 2
415-
U, S, V, ϵ = tsvd(A; trunc=truncerr(truncval))
415+
U, S, V = tsvd(A; trunc=truncerr(truncval))
416416
ΔU = randn(scalartype(U), space(U))
417417
ΔS = randn(scalartype(S), space(S))
418418
ΔV = randn(scalartype(V), space(V))
419419
T <: Complex && remove_svdgauge_depence!(ΔU, ΔV, U, S, V)
420-
test_rrule(tsvd, A; atol, output_tangent=(ΔU, ΔS, ΔV, 0.0),
420+
test_rrule(tsvd, A; atol, output_tangent=(ΔU, ΔS, ΔV),
421421
fkwargs=(; trunc=truncerr(truncval)))
422422
end
423423

424-
let (U, S, V, ϵ) = tsvd(B)
424+
let (U, S, V) = tsvd(B)
425425
ΔU = randn(scalartype(U), space(U))
426426
ΔS = randn(scalartype(S), space(S))
427427
ΔV = randn(scalartype(V), space(V))
428428
T <: Complex && remove_svdgauge_depence!(ΔU, ΔV, U, S, V)
429-
test_rrule(tsvd, B; atol, output_tangent=(ΔU, ΔS, ΔV, 0.0))
429+
test_rrule(tsvd, B; atol, output_tangent=(ΔU, ΔS, ΔV))
430430

431431
Vtrunc = spacetype(S)(TensorKit.SectorDict(c => ceil(Int, size(b, 1) / 2)
432432
for (c, b) in blocks(S)))
433433

434-
U, S, V, ϵ = tsvd(B; trunc=truncspace(Vtrunc))
434+
U, S, V = tsvd(B; trunc=truncspace(Vtrunc))
435435
ΔU = randn(scalartype(U), space(U))
436436
ΔS = randn(scalartype(S), space(S))
437437
ΔV = randn(scalartype(V), space(V))
438438
T <: Complex && remove_svdgauge_depence!(ΔU, ΔV, U, S, V)
439-
test_rrule(tsvd, B; atol, output_tangent=(ΔU, ΔS, ΔV, 0.0),
439+
test_rrule(tsvd, B; atol, output_tangent=(ΔU, ΔS, ΔV),
440440
fkwargs=(; trunc=truncspace(Vtrunc)))
441441
end
442442

443-
let (U, S, V, ϵ) = tsvd(C)
443+
let (U, S, V) = tsvd(C)
444444
ΔU = randn(scalartype(U), space(U))
445445
ΔS = randn(scalartype(S), space(S))
446446
ΔV = randn(scalartype(V), space(V))
447447
T <: Complex && remove_svdgauge_depence!(ΔU, ΔV, U, S, V)
448-
test_rrule(tsvd, C; atol, output_tangent=(ΔU, ΔS, ΔV, 0.0))
448+
test_rrule(tsvd, C; atol, output_tangent=(ΔU, ΔS, ΔV))
449449

450450
c, = TensorKit.MatrixAlgebra._argmax(x -> sqrt(dim(x[1])) * maximum(diag(x[2])),
451451
blocks(S))
452452
trunc = truncdim(round(Int, 2 * dim(c)))
453-
U, S, V, ϵ = tsvd(C; trunc)
453+
U, S, V = tsvd(C; trunc)
454454
ΔU = randn(scalartype(U), space(U))
455455
ΔS = randn(scalartype(S), space(S))
456456
ΔV = randn(scalartype(V), space(V))
457457
T <: Complex && remove_svdgauge_depence!(ΔU, ΔV, U, S, V)
458-
test_rrule(tsvd, C; atol, output_tangent=(ΔU, ΔS, ΔV, 0.0), fkwargs=(; trunc))
458+
test_rrule(tsvd, C; atol, output_tangent=(ΔU, ΔS, ΔV), fkwargs=(; trunc))
459459
end
460460

461461
let D = LinearAlgebra.eigvals(C)

0 commit comments

Comments
 (0)