Skip to content

Commit f9ed03e

Browse files
committed
exchange sqrt and invsqrt in hope of fixing without thinking
1 parent 3e45ff2 commit f9ed03e

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

ext/TensorKitChainRulesCoreExt/constructors.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ function ChainRulesCore.rrule(::Type{TensorMap{T}}, data::DenseVector,
2121
function TensorMap_pullback(Δt_)
2222
Δt = copy(unthunk(Δt_))
2323
for (c, b) in blocks(Δt)
24-
scale!(b, TensorKit.sqrtdim(c))
24+
scale!(b, TensorKit.invsqrtdim(c))
2525
end
2626
∂data = P(Δt.data)
2727
return NoTangent(), ∂data, NoTangent()
@@ -37,7 +37,7 @@ function ChainRulesCore.rrule(::Type{<:DiagonalTensorMap}, data::DenseVector, ar
3737
# unclear if we're allowed to modify/take ownership of the input
3838
Δt = copy(unthunk(Δt_))
3939
for (c, b) in blocks(Δt)
40-
scale!(b, TensorKit.sqrtdim(c))
40+
scale!(b, TensorKit.invsqrtdim(c))
4141
end
4242
∂data = P(Δt.data)
4343
return NoTangent(), ∂data, NoTangent()
@@ -51,7 +51,7 @@ function ChainRulesCore.rrule(::typeof(Base.getproperty), t::TensorMap, prop::Sy
5151
# unclear if we're allowed to modify/take ownership of the input
5252
t′ = typeof(t)(copy(unthunk(Δdata)), t.space)
5353
for (c, b) in blocks(t′)
54-
scale!(b, TensorKit.invsqrtdim(c))
54+
scale!(b, TensorKit.sqrtdim(c))
5555
end
5656
return NoTangent(), t′, NoTangent()
5757
end
@@ -70,7 +70,7 @@ function ChainRulesCore.rrule(::typeof(Base.getproperty), t::DiagonalTensorMap,
7070
# unclear if we're allowed to modify/take ownership of the input
7171
t′ = typeof(t)(copy(unthunk(Δdata)), t.domain)
7272
for (c, b) in blocks(t′)
73-
scale!(b, TensorKit.invsqrtdim(c))
73+
scale!(b, TensorKit.sqrtdim(c))
7474
end
7575
return NoTangent(), t′, NoTangent()
7676
end

0 commit comments

Comments
 (0)