We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
twist
1 parent f09e40c commit 0ae738cCopy full SHA for 0ae738c
ext/TensorKitChainRulesCoreExt/linalg.jl
@@ -77,6 +77,12 @@ function ChainRulesCore.rrule(::typeof(adjoint), A::AbstractTensorMap)
77
return adjoint(A), adjoint_pullback
78
end
79
80
+function ChainRulesCore.rrule(::typeof(twist), A::AbstractTensorMap, is; inv::Bool=false)
81
+ tA = twist(A, is; inv)
82
+ twist_pullback(ΔA) = NoTangent(), twist(unthunk(ΔA), is; inv=!inv), NoTangent()
83
+ return tA, twist_pullback
84
+end
85
+
86
function ChainRulesCore.rrule(::typeof(dot), a::AbstractTensorMap, b::AbstractTensorMap)
87
dot_pullback(Δd) = NoTangent(), @thunk(b * Δd'), @thunk(a * Δd)
88
return dot(a, b), dot_pullback
0 commit comments