Skip to content

Commit 456e1a8

Browse files
authored
[ITensorMPS] Allow customizing apply (#1440)
1 parent 12844c4 commit 456e1a8

File tree

11 files changed

+47
-25
lines changed

11 files changed

+47
-25
lines changed

src/lib/ITensorMPS/ext/ITensorMPSChainRulesCoreExt/abstractmps.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
using Adapt: adapt
2-
using ChainRulesCore: ChainRulesCore, HasReverseMode, NoTangent, RuleConfig
2+
using ChainRulesCore: ChainRulesCore, HasReverseMode, NoTangent, RuleConfig, rrule_via_ad
33
using ITensors:
4-
ITensor, apply, dag, hassameinds, inds, inner, itensor, mapprime, replaceprime, swapprime
5-
using ITensors.ITensorMPS: MPO, MPS, siteinds
4+
ITensors, ITensor, dag, hassameinds, inds, itensor, mapprime, replaceprime, swapprime
5+
using ITensors.ITensorMPS: ITensorMPS, MPO, MPS, apply, inner, siteinds
66
using NDTensors: datatype
77

88
function ChainRulesCore.rrule(
@@ -186,7 +186,7 @@ function ChainRulesCore.rrule(
186186
end
187187
y = typeof(x)(y_data)
188188
if !set_limits
189-
y = ITensors.set_ortho_lims(y, ortho_lims(x))
189+
y = ITensorMPS.set_ortho_lims(y, ITensorMPS.ortho_lims(x))
190190
end
191191
return y, map_pullback
192192
end

src/lib/ITensorMPS/ext/ITensorMPSChainRulesCoreExt/mpo.jl

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,35 @@
11
using ChainRulesCore: ChainRulesCore, NoTangent
2-
using ITensors: contract, hassameinds, inner, mapprime
2+
using ITensors: Algorithm, contract, hassameinds, inner, mapprime
33
using ITensors.ITensorMPS: MPO, MPS, firstsiteinds, siteinds
44
using LinearAlgebra: tr
55

6-
function ChainRulesCore.rrule(::typeof(contract), x1::MPO, x2::MPO; kwargs...)
7-
y = contract(x1, x2; kwargs...)
6+
function ChainRulesCore.rrule(
7+
::typeof(contract), alg::Algorithm, x1::MPO, x2::MPO; kwargs...
8+
)
9+
y = contract(alg, x1, x2; kwargs...)
810
function contract_pullback(ȳ)
9-
x̄1 = contract(ȳ, dag(x2); kwargs...)
10-
x̄2 = contract(dag(x1), ȳ; kwargs...)
11-
return (NoTangent(), x̄1, x̄2)
11+
x̄1 = contract(alg, ȳ, dag(x2); kwargs...)
12+
x̄2 = contract(alg, dag(x1), ȳ; kwargs...)
13+
return (NoTangent(), NoTangent(), x̄1, x̄2)
1214
end
1315
return y, contract_pullback
1416
end
1517

16-
function ChainRulesCore.rrule(::typeof(contract), x1::MPO, x2::MPS; kwargs...)
17-
y = contract(x1, x2; kwargs...)
18+
function ChainRulesCore.rrule(
19+
::typeof(contract), alg::Algorithm, x1::MPO, x2::MPS; kwargs...
20+
)
21+
y = contract(alg, x1, x2; kwargs...)
1822
function contract_pullback(ȳ)
1923
x̄1 = _contract(MPO, ȳ, dag(x2); kwargs...)
20-
x̄2 = contract(dag(x1), ȳ; kwargs...)
21-
return (NoTangent(), x̄1, x̄2)
24+
x̄2 = contract(alg, dag(x1), ȳ; kwargs...)
25+
return (NoTangent(), NoTangent(), x̄1, x̄2)
2226
end
2327
return y, contract_pullback
2428
end
2529

26-
function ChainRulesCore.rrule(::typeof(*), x1::MPO, x2::MPO; kwargs...)
27-
return ChainRulesCore.rrule(contract, x1, x2; kwargs...)
28-
end
30+
## function ChainRulesCore.rrule(::typeof(*), x1::MPO, x2::MPO; alg, kwargs...)
31+
## return ChainRulesCore.rrule(contract, alg, x1, x2; kwargs...)
32+
## end
2933

3034
function ChainRulesCore.rrule(::typeof(+), x1::MPO, x2::MPO; kwargs...)
3135
y = +(x1, x2; kwargs...)

src/lib/ITensorMPS/src/mpo.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -596,8 +596,12 @@ Equivalent to `replaceprime(contract(A, x; kwargs...), 2 => 1)`.
596596
597597
See also [`contract`](@ref) for details about the arguments available.
598598
"""
599-
function apply(A::MPO, ψ::MPS; kwargs...)
600-
= contract(A, ψ; kwargs...)
599+
function apply(A::MPO, ψ::MPS; alg=Algorithm"densitymatrix"(), kwargs...)
600+
return apply(Algorithm(alg), A, ψ; kwargs...)
601+
end
602+
603+
function apply(alg::Algorithm, A::MPO, ψ::MPS; kwargs...)
604+
= contract(alg, A, ψ; kwargs...)
601605
return replaceprime(Aψ, 1 => 0)
602606
end
603607

src/lib/ITensorMPS/test/Ops/runtests.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
@eval module $(gensym())
12
using ITensors
23
using Test
34

@@ -14,3 +15,4 @@ ITensors.disable_threaded_blocksparse()
1415
@time include(filename)
1516
end
1617
end
18+
end

src/lib/ITensorMPS/test/Ops/test_ops_mpo.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
@eval module $(gensym())
12
using Test
23
using ITensors
34
using ITensors.Ops
@@ -95,3 +96,4 @@ end
9596
@test norm(replaceprime(H' * H, 2 => 1) - H²) 0 atol = 1e-14
9697
@test norm(H(H) - H²) 0 atol = 1e-14
9798
end
99+
end

src/lib/ITensorMPS/test/Ops/test_trotter.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
@eval module $(gensym())
12
using Test
23
using ITensors
34
using ITensors.Ops
@@ -36,3 +37,4 @@ end
3637
end
3738
end
3839
end
40+
end

src/lib/ITensorMPS/test/base/test_autompo.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
@eval module $(gensym())
12
using ITensors, Test, Random, JLD2
23
using NDTensors: scalartype
34

@@ -1249,3 +1250,4 @@ end
12491250
@test_nowarn H = MPO(os, sites)
12501251
end
12511252
end
1253+
end

src/lib/ITensorMPS/test/base/test_fermions.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -260,11 +260,11 @@ import ITensors: Out, In
260260
# Reference state |110⟩
261261
ψ110 = MPS(s, n -> n == 1 || n == 2 ? "1" : "0")
262262

263-
function ITensors.op(::OpName"CdagC", ::SiteType, s1::Index, s2::Index)
263+
function ITensors.op(::OpName"CdagC3", ::SiteType, s1::Index, s2::Index)
264264
return op("Cdag", s1) * op("C", s2)
265265
end
266266

267-
os = [("CdagC", 1, 3)]
267+
os = [("CdagC3", 1, 3)]
268268
Os = ops(os, s)
269269

270270
# Results in -|110⟩

src/lib/ITensorMPS/test/base/test_mpo.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
@eval module $(gensym())
12
using Combinatorics
23
using ITensors
34
using NDTensors: scalartype
@@ -853,3 +854,4 @@ end
853854
@test maxlinkdim(H) maxlinkdim(H₁) + maxlinkdim(H₂)
854855
end
855856
end
857+
end

src/lib/ITensorMPS/test/base/test_mps.jl

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
@eval module $(gensym())
12
using Combinatorics
23
using ITensors
34
using ITensors: ITensorMPS
@@ -1765,11 +1766,11 @@ end
17651766
# Reference state |110⟩
17661767
ψ110 = MPS(s, n -> n == 1 || n == 2 ? "1" : "0")
17671768

1768-
function ITensors.op(::OpName"CdagC", ::SiteType, s1::Index, s2::Index)
1769+
function ITensors.op(::OpName"CdagC1", ::SiteType, s1::Index, s2::Index)
17691770
return op("Cdag", s1) * op("C", s2)
17701771
end
17711772

1772-
os = [("CdagC", 1, 3)]
1773+
os = [("CdagC1", 1, 3)]
17731774
Os = ops(os, s)
17741775

17751776
# Results in -|110⟩
@@ -1810,7 +1811,7 @@ end
18101811
cutoff!(sweeps, 1E-12)
18111812
energy, ψ0 = dmrg(H, ψ0, sweeps; outputlevel=0)
18121813

1813-
function ITensors.op(::OpName"CdagC", ::SiteType, s1::Index, s2::Index)
1814+
function ITensors.op(::OpName"CdagC2", ::SiteType, s1::Index, s2::Index)
18141815
return op("Cdag", s1) * op("C", s2)
18151816
end
18161817

@@ -1821,7 +1822,7 @@ end
18211822
end
18221823

18231824
for i in 1:(N - 1), j in (i + 1):N
1824-
G1 = op("CdagC", s, i, j)
1825+
G1 = op("CdagC2", s, i, j)
18251826

18261827
@disable_warn_order begin
18271828
G2 = op("Cdag", s, i)
@@ -2006,3 +2007,4 @@ end
20062007
@test norm(M - Mt) 0 atol = 1e-12
20072008
end
20082009
end
2010+
end

0 commit comments

Comments
 (0)