Skip to content

Commit 5fd86c0

Browse files
committed
proper fix (hopefullly)
1 parent faaa89b commit 5fd86c0

File tree

2 files changed

+11
-5
lines changed

2 files changed

+11
-5
lines changed

ext/TensorKitChainRulesCoreExt/tensoroperations.jl

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,8 @@ function ChainRulesCore.rrule(::typeof(TensorOperations.tensorcontract!),
6161
pΔC = _repartition(ipAB, TO.numout(pA))
6262

6363
dC = @thunk projectC(scale(ΔC, conj(β)))
64-
dA = @thunk let ipA = _repartition(invperm(linearize(pA)), A)
64+
dA = @thunk let
65+
ipA = _repartition(invperm(linearize(pA)), A)
6566
conjΔC = conjA
6667
conjB′ = conjA ? conjB : !conjB
6768
TA = promote_contract(scalartype(ΔC), scalartype(B), scalartype(α))
@@ -77,7 +78,8 @@ function ChainRulesCore.rrule(::typeof(TensorOperations.tensorcontract!),
7778
conjA ? α : conj(α), Zero(), ba...)
7879
return projectA(_dA)
7980
end
80-
dB = @thunk let ipB = _repartition(invperm(linearize(pB)), B)
81+
dB = @thunk let
82+
ipB = _repartition(invperm(linearize(pB)), B)
8183
conjΔC = conjB
8284
conjA′ = conjB ? conjA : !conjA
8385
TB = promote_contract(scalartype(ΔC), scalartype(A), scalartype(α))
@@ -123,7 +125,9 @@ function ChainRulesCore.rrule(::typeof(TensorOperations.tensortrace!),
123125
function pullback(ΔC′)
124126
ΔC = unthunk(ΔC′)
125127
dC = @thunk projectC(scale(ΔC, conj(β)))
126-
dA = @thunk let ip = invperm((linearize(p)..., q[1]..., q[2]...)), pdA = _repartition(ip, A)
128+
dA = @thunk let
129+
ip = invperm((linearize(p)..., q[1]..., q[2]...))
130+
pdA = _repartition(ip, A)
127131
E = one!(TO.tensoralloc_add(scalartype(A), A, q, conjA))
128132
twist!(E, filter(x -> !isdual(space(E, x)), codomainind(E)))
129133
pE = ((), trivtuple(TO.numind(q)))

ext/TensorKitChainRulesCoreExt/utility.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,15 @@
22
# -------
33
trivtuple(N) = ntuple(identity, N)
44

5-
function _repartition(p::IndexTuple, N₁::Int)
5+
Base.@constprop :aggressive function _repartition(p::IndexTuple, N₁::Int)
66
length(p) >= N₁ ||
77
throw(ArgumentError("cannot repartition $(typeof(p)) to $N₁, $(length(p) - N₁)"))
88
return TupleTools.getindices(p, trivtuple(N₁)),
99
TupleTools.getindices(p, trivtuple(length(p) - N₁) .+ N₁)
1010
end
11-
_repartition(p::Index2Tuple, N₁::Int) = _repartition(linearize(p), N₁)
11+
Base.@constprop :aggressive function _repartition(p::Index2Tuple, N₁::Int)
12+
return _repartition(linearize(p), N₁)
13+
end
1214
function _repartition(p::Union{IndexTuple,Index2Tuple}, ::Index2Tuple{N₁}) where {N₁}
1315
return _repartition(p, N₁)
1416
end

0 commit comments

Comments
 (0)