Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions ext/TensorKitChainRulesCoreExt/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,12 @@
return adjoint(A), adjoint_pullback
end

function ChainRulesCore.rrule(::typeof(twist), A::AbstractTensorMap, is; inv::Bool=false)
tA = twist(A, is; inv)
twist_pullback(ΔA) = NoTangent(), twist(unthunk(ΔA), is; inv=!inv), NoTangent()
return tA, twist_pullback

Check warning on line 83 in ext/TensorKitChainRulesCoreExt/linalg.jl

View check run for this annotation

Codecov / codecov/patch

ext/TensorKitChainRulesCoreExt/linalg.jl#L80-L83

Added lines #L80 - L83 were not covered by tests
end

function ChainRulesCore.rrule(::typeof(dot), a::AbstractTensorMap, b::AbstractTensorMap)
dot_pullback(Δd) = NoTangent(), @thunk(b * Δd'), @thunk(a * Δd)
return dot(a, b), dot_pullback
Expand Down
2 changes: 2 additions & 0 deletions test/ad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,8 @@ Vlist = ((ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'),
test_rrule(*, A, C)

test_rrule(permute, A, ((1, 3, 2), (5, 4)))
test_rrule(twist, A, 1)
test_rrule(twist, A, [1, 3])

D = randn(T, V[1] ⊗ V[2] ← V[3])
E = randn(T, V[4] ← V[5])
Expand Down
Loading