@@ -14,8 +14,11 @@ function ChainRulesCore.rrule(::typeof(TensorOperations.tensoradd!),
1414 dC = @thunk projectC (scale (ΔC, conj (β)))
1515 dA = @thunk let
1616 ipA = invperm (linearize (pA))
17- _dA = zerovector (A, promote_add (ΔC, α))
18- _dA = tensoradd! (_dA, ΔC, (ipA, ()), conjA, conjA ? α : conj (α), Zero (), ba... )
17+ pdA = _repartition (ipA, A)
18+ TA = promote_add (ΔC, α)
19+ # TODO : allocator
20+ _dA = tensoralloc_add (TA, ΔC, pdA, conjA, Val (false ))
21+ _dA = tensoradd! (_dA, ΔC, pdA, conjA, conjA ? α : conj (α), Zero (), ba... )
1922 return projectA (_dA)
2023 end
2124 dα = @thunk let
@@ -55,34 +58,37 @@ function ChainRulesCore.rrule(::typeof(TensorOperations.tensorcontract!),
5558 function pullback (ΔC′)
5659 ΔC = unthunk (ΔC′)
5760 ipAB = invperm (linearize (pAB))
58- pΔC = (TupleTools. getindices (ipAB, trivtuple (TO. numout (pA))),
59- TupleTools. getindices (ipAB, TO. numout (pA) .+ trivtuple (TO. numin (pB))))
61+ pΔC = _repartition (ipAB, TO. numout (pA))
6062
6163 dC = @thunk projectC (scale (ΔC, conj (β)))
6264 dA = @thunk let
63- ipA = (invperm (linearize (pA)), () )
65+ ipA = _repartition (invperm (linearize (pA)), A )
6466 conjΔC = conjA
6567 conjB′ = conjA ? conjB : ! conjB
66- _dA = zerovector (A,
67- promote_contract ( scalartype (ΔC), scalartype (B), scalartype (α)))
68+ TA = promote_contract ( scalartype (ΔC), scalartype (B), scalartype (α))
69+ # TODO : allocator
6870 tB = twist (B,
6971 TupleTools. vcat (filter (x -> ! isdual (space (B, x)), pB[1 ]),
7072 filter (x -> isdual (space (B, x)), pB[2 ])))
73+ _dA = tensoralloc_contract (TA, ΔC, pΔC, conjΔC, tB, reverse (pB), conjB′, ipA,
74+ Val (false ))
7175 _dA = tensorcontract! (_dA,
7276 ΔC, pΔC, conjΔC,
7377 tB, reverse (pB), conjB′, ipA,
7478 conjA ? α : conj (α), Zero (), ba... )
7579 return projectA (_dA)
7680 end
7781 dB = @thunk let
78- ipB = (invperm (linearize (pB)), () )
82+ ipB = _repartition (invperm (linearize (pB)), B )
7983 conjΔC = conjB
8084 conjA′ = conjB ? conjA : ! conjA
81- _dB = zerovector (B,
82- promote_contract ( scalartype (ΔC), scalartype (A), scalartype (α)))
85+ TB = promote_contract ( scalartype (ΔC), scalartype (A), scalartype (α))
86+ # TODO : allocator
8387 tA = twist (A,
8488 TupleTools. vcat (filter (x -> isdual (space (A, x)), pA[1 ]),
8589 filter (x -> ! isdual (space (A, x)), pA[2 ])))
90+ _dB = tensoralloc_contract (TB, tA, reverse (pA), conjA′, ΔC, pΔC, conjΔC, ipB,
91+ Val (false ))
8692 _dB = tensorcontract! (_dB,
8793 tA, reverse (pA), conjA′,
8894 ΔC, pΔC, conjΔC, ipB,
@@ -121,12 +127,15 @@ function ChainRulesCore.rrule(::typeof(TensorOperations.tensortrace!),
121127 dC = @thunk projectC (scale (ΔC, conj (β)))
122128 dA = @thunk let
123129 ip = invperm ((linearize (p)... , q[1 ]. .. , q[2 ]. .. ))
130+ pdA = _repartition (ip, A)
124131 E = one! (TO. tensoralloc_add (scalartype (A), A, q, conjA))
125132 twist! (E, filter (x -> ! isdual (space (E, x)), codomainind (E)))
126- _dA = zerovector (A, promote_scale (ΔC, α))
127- _dA = tensorproduct! (_dA, ΔC,
128- (trivtuple (TO. numind (p)), ()), conjA, E,
129- ((), trivtuple (TO. numind (q))), conjA, (ip, ()),
133+ pE = ((), trivtuple (TO. numind (q)))
134+ pΔC = (trivtuple (TO. numind (p)), ())
135+ TA = promote_scale (ΔC, α)
136+ # TODO : allocator
137+ _dA = tensoralloc_contract (TA, ΔC, pΔC, conjA, E, pE, conjA, pdA, Val (false ))
138+ _dA = tensorproduct! (_dA, ΔC, pΔC, conjA, E, pE, conjA, pdA,
130139 conjA ? α : conj (α), Zero (), ba... )
131140 return projectA (_dA)
132141 end
0 commit comments