Skip to content

Commit a5811da

Browse files
authored
Add rrule for transpose (#319)
* add rrule for transpose * add test rrule transpose
1 parent fc2413b commit a5811da

File tree

2 files changed

+11
-0
lines changed

2 files changed

+11
-0
lines changed

ext/TensorKitChainRulesCoreExt/linalg.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,16 @@ function ChainRulesCore.rrule(
6969
return permute(tsrc, p; copy = true), permute_pullback
7070
end
7171

72+
function ChainRulesCore.rrule(
73+
::typeof(transpose), tsrc::AbstractTensorMap, p::Index2Tuple; copy::Bool = false
74+
)
75+
function transpose_pullback(Δtdst)
76+
invp = TensorKit._canonicalize(TupleTools.invperm(linearize(p)), tsrc)
77+
return NoTangent(), transpose(unthunk(Δtdst), invp; copy = true), NoTangent()
78+
end
79+
return transpose(tsrc, p; copy = true), transpose_pullback
80+
end
81+
7282
function ChainRulesCore.rrule(::typeof(tr), A::AbstractTensorMap)
7383
tr_pullback(Δtr) = NoTangent(), Δtr * id(domain(A))
7484
return tr(A), tr_pullback

test/autodiff/ad.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -293,6 +293,7 @@ for V in spacelist
293293
C = randn(T, domain(A), codomain(A))
294294
test_rrule(*, A, C)
295295

296+
test_rrule(transpose, A, ((2, 5, 4), (1, 3)))
296297
symmetricbraiding && test_rrule(permute, A, ((1, 3, 2), (5, 4)))
297298
test_rrule(twist, A, 1)
298299
test_rrule(twist, A, [1, 3])

0 commit comments

Comments
 (0)