Skip to content

Commit 0ae738c

Browse files
committed
Add rrule for twist
1 parent f09e40c commit 0ae738c

File tree

1 file changed

+6
-0
lines changed

1 file changed

+6
-0
lines changed

ext/TensorKitChainRulesCoreExt/linalg.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,12 @@ function ChainRulesCore.rrule(::typeof(adjoint), A::AbstractTensorMap)
7777
return adjoint(A), adjoint_pullback
7878
end
7979

80+
function ChainRulesCore.rrule(::typeof(twist), A::AbstractTensorMap, is; inv::Bool=false)
81+
tA = twist(A, is; inv)
82+
twist_pullback(ΔA) = NoTangent(), twist(unthunk(ΔA), is; inv=!inv), NoTangent()
83+
return tA, twist_pullback
84+
end
85+
8086
function ChainRulesCore.rrule(::typeof(dot), a::AbstractTensorMap, b::AbstractTensorMap)
8187
dot_pullback(Δd) = NoTangent(), @thunk(b * Δd'), @thunk(a * Δd)
8288
return dot(a, b), dot_pullback

0 commit comments

Comments
 (0)