|
1 | 1 | using ChainRulesCore: ChainRulesCore, NoTangent |
2 | | -using ITensors: contract, hassameinds, inner, mapprime |
| 2 | +using ITensors: Algorithm, contract, hassameinds, inner, mapprime |
3 | 3 | using ITensors.ITensorMPS: MPO, MPS, firstsiteinds, siteinds |
4 | 4 | using LinearAlgebra: tr |
5 | 5 |
|
6 | | -function ChainRulesCore.rrule(::typeof(contract), x1::MPO, x2::MPO; kwargs...) |
7 | | - y = contract(x1, x2; kwargs...) |
| 6 | +function ChainRulesCore.rrule( |
| 7 | + ::typeof(contract), alg::Algorithm, x1::MPO, x2::MPO; kwargs... |
| 8 | +) |
| 9 | + y = contract(alg, x1, x2; kwargs...) |
8 | 10 | function contract_pullback(ȳ) |
9 | | - x̄1 = contract(ȳ, dag(x2); kwargs...) |
10 | | - x̄2 = contract(dag(x1), ȳ; kwargs...) |
11 | | - return (NoTangent(), x̄1, x̄2) |
| 11 | + x̄1 = contract(alg, ȳ, dag(x2); kwargs...) |
| 12 | + x̄2 = contract(alg, dag(x1), ȳ; kwargs...) |
| 13 | + return (NoTangent(), NoTangent(), x̄1, x̄2) |
12 | 14 | end |
13 | 15 | return y, contract_pullback |
14 | 16 | end |
15 | 17 |
|
16 | | -function ChainRulesCore.rrule(::typeof(contract), x1::MPO, x2::MPS; kwargs...) |
17 | | - y = contract(x1, x2; kwargs...) |
| 18 | +function ChainRulesCore.rrule( |
| 19 | + ::typeof(contract), alg::Algorithm, x1::MPO, x2::MPS; kwargs... |
| 20 | +) |
| 21 | + y = contract(alg, x1, x2; kwargs...) |
18 | 22 | function contract_pullback(ȳ) |
19 | 23 | x̄1 = _contract(MPO, ȳ, dag(x2); kwargs...) |
20 | | - x̄2 = contract(dag(x1), ȳ; kwargs...) |
21 | | - return (NoTangent(), x̄1, x̄2) |
| 24 | + x̄2 = contract(alg, dag(x1), ȳ; kwargs...) |
| 25 | + return (NoTangent(), NoTangent(), x̄1, x̄2) |
22 | 26 | end |
23 | 27 | return y, contract_pullback |
24 | 28 | end |
25 | 29 |
|
26 | | -function ChainRulesCore.rrule(::typeof(*), x1::MPO, x2::MPO; kwargs...) |
27 | | - return ChainRulesCore.rrule(contract, x1, x2; kwargs...) |
28 | | -end |
| 30 | +## function ChainRulesCore.rrule(::typeof(*), x1::MPO, x2::MPO; alg, kwargs...) |
| 31 | +## return ChainRulesCore.rrule(contract, alg, x1, x2; kwargs...) |
| 32 | +## end |
29 | 33 |
|
30 | 34 | function ChainRulesCore.rrule(::typeof(+), x1::MPO, x2::MPO; kwargs...) |
31 | 35 | y = +(x1, x2; kwargs...) |
|
0 commit comments