Skip to content

Commit 9ef15ab

Browse files
committed
handle kwargs in twist rrule
1 parent c878317 commit 9ef15ab

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

ext/TensorKitChainRulesCoreExt/linalg.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -79,9 +79,9 @@ function ChainRulesCore.rrule(::typeof(adjoint), A::AbstractTensorMap)
7979
return adjoint(A), adjoint_pullback
8080
end
8181

82-
function ChainRulesCore.rrule(::typeof(twist), A::AbstractTensorMap, is; inv::Bool = false)
83-
tA = twist(A, is; inv)
84-
twist_pullback(ΔA) = NoTangent(), twist(unthunk(ΔA), is; inv = !inv), NoTangent()
82+
function ChainRulesCore.rrule(::typeof(twist), A::AbstractTensorMap, is; inv::Bool = false, kwargs...)
83+
tA = twist(A, is; inv, kwargs...)
84+
twist_pullback(ΔA) = NoTangent(), twist(unthunk(ΔA), is; inv = !inv, kwargs...), NoTangent()
8585
return tA, twist_pullback
8686
end
8787

0 commit comments

Comments
 (0)