Skip to content

Commit d79bd90

Browse files
committed
Fix uninitialized cotangents
1 parent df810a9 commit d79bd90

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

ext/TensorKitChainRulesCoreExt/factorizations.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ function ChainRulesCore.rrule(::typeof(TensorKit.tsvd!), t::AbstractTensorMap;
1515

1616
function tsvd!_pullback(ΔUSVᴴ′)
1717
ΔUSVᴴ = unthunk.(ΔUSVᴴ′)
18-
Δt = similar(t)
18+
Δt = zerovector(t)
1919
foreachblock(Δt) do c, (b,)
2020
USVᴴc = block.(USVᴴ, Ref(c))
2121
ΔUSVᴴc = block.(ΔUSVᴴ, Ref(c))
@@ -48,7 +48,7 @@ function ChainRulesCore.rrule(::typeof(TensorKit.eig!), t::AbstractTensorMap; kw
4848

4949
function eig!_pullback(ΔDV′)
5050
ΔDV = unthunk.(ΔDV′)
51-
Δt = similar(t)
51+
Δt = zerovector(t)
5252
foreachblock(Δt) do c, (b,)
5353
DVc = block.(DV, Ref(c))
5454
ΔDVc = block.(ΔDV, Ref(c))
@@ -67,7 +67,7 @@ function ChainRulesCore.rrule(::typeof(TensorKit.eigh!), t::AbstractTensorMap; k
6767

6868
function eigh!_pullback(ΔDV′)
6969
ΔDV = unthunk.(ΔDV′)
70-
Δt = similar(t)
70+
Δt = zerovector(t)
7171
foreachblock(Δt) do c, (b,)
7272
DVc = block.(DV, Ref(c))
7373
ΔDVc = block.(ΔDV, Ref(c))

0 commit comments

Comments
 (0)