Skip to content

Commit da4a54a

Browse files
authored
[ITensorMPS] Code reorganization (#1406)
1 parent b79f033 commit da4a54a

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

68 files changed

+268
-212
lines changed

Project.toml

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,19 +29,22 @@ Strided = "5e0ebb24-38b0-5f93-81fe-25c709ecae67"
2929
TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
3030
TupleTools = "9d95972d-f1c8-5527-a6e0-b4b365fa01f6"
3131
Zeros = "bd1ec220-6eb4-527a-9b49-e79c3db6233b"
32-
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
3332

3433
[weakdeps]
34+
# ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
3535
HDF5 = "f67ccb44-e63f-5c2f-98bd-6dc0ccc4ba2f"
3636
Observers = "338f10d5-c7f1-4033-a7d1-f9dec39bcaa0"
3737
PackageCompiler = "9b87118b-4619-50d2-8e1e-99f35a4d4d9d"
3838
VectorInterface = "409d34a3-91d5-4945-b6ec-7529ddf182d8"
39+
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
3940

4041
[extensions]
42+
# ITensorsChainRulesCoreExt = "ChainRulesCore"
4143
ITensorsHDF5Ext = "HDF5"
4244
ITensorsObserversExt = "Observers"
4345
ITensorsPackageCompilerExt = "PackageCompiler"
4446
ITensorsVectorInterfaceExt = "VectorInterface"
47+
ITensorsZygoteRulesExt = "ZygoteRules"
4548

4649
[compat]
4750
Adapt = "3.5, 4"
@@ -76,7 +79,9 @@ ZygoteRules = "0.2.2"
7679
julia = "1.6"
7780

7881
[extras]
82+
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
7983
HDF5 = "f67ccb44-e63f-5c2f-98bd-6dc0ccc4ba2f"
8084
Observers = "338f10d5-c7f1-4033-a7d1-f9dec39bcaa0"
8185
PackageCompiler = "9b87118b-4619-50d2-8e1e-99f35a4d4d9d"
8286
VectorInterface = "409d34a3-91d5-4945-b6ec-7529ddf182d8"
87+
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
module ITensorsChainRulesCoreExt
2+
using ChainRulesCore
3+
import ChainRulesCore: rrule
4+
using ITensors
5+
using ITensors: Indices
6+
using ITensors.Adapt
7+
using ITensors.NDTensors
8+
using ITensors.NDTensors: datatype
9+
using ITensors.Ops
10+
include("utils.jl")
11+
include("projection.jl")
12+
include("NDTensors/tensor.jl")
13+
include("NDTensors/dense.jl")
14+
include("indexset.jl")
15+
include("itensor.jl")
16+
include("LazyApply/LazyApply.jl")
17+
include("non_differentiable.jl")
18+
include("itensormps.jl")
19+
include("smallstrings.jl")
20+
end
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
for fname in (
2+
:prime,
3+
:setprime,
4+
:noprime,
5+
:replaceprime,
6+
:swapprime,
7+
:addtags,
8+
:removetags,
9+
:replacetags,
10+
:settags,
11+
:swaptags,
12+
:replaceind,
13+
:replaceinds,
14+
:swapind,
15+
:swapinds,
16+
)
17+
@eval begin
18+
function rrule(f::typeof($fname), x::ITensor, a...; kwargs...)
19+
y = f(x, a...; kwargs...)
20+
function f_pullback(ȳ)
21+
= replaceinds(unthunk(ȳ), inds(y) => inds(x))
22+
ā = map_notangent(a)
23+
return (NoTangent(), x̄, ā...)
24+
end
25+
return y, f_pullback
26+
end
27+
end
28+
end
29+
30+
rrule(::typeof(adjoint), x::ITensor) = rrule(prime, x)
31+
32+
@non_differentiable permute(::Indices, ::Indices)
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
using ITensors: ITensors
2+
include(
3+
joinpath(
4+
pkgdir(ITensors),
5+
"src",
6+
"lib",
7+
"ITensorMPS",
8+
"ext",
9+
"ITensorMPSChainRulesCoreExt",
10+
"ITensorMPSChainRulesCoreExt.jl",
11+
),
12+
)
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
using ChainRulesCore: @non_differentiable
2+
using ITensors:
3+
ITensors, Index, addtags, commoninds, dag, delta, inds, noncommoninds, onehot, uniqueinds
4+
using ITensors.TagSets: TagSet
5+
6+
@non_differentiable map_notangent(::Any)
7+
@non_differentiable Index(::Any...)
8+
@non_differentiable delta(::Any...)
9+
@non_differentiable dag(::Index)
10+
@non_differentiable inds(::Any...)
11+
@non_differentiable commoninds(::Any...)
12+
@non_differentiable noncommoninds(::Any...)
13+
@non_differentiable uniqueinds(::Any...)
14+
@non_differentiable addtags(::TagSet, ::Any)
15+
@non_differentiable ITensors.filter_inds_set_function(::Function, ::Function, ::Any...)
16+
@non_differentiable ITensors.filter_inds_set_function(::Function, ::Any...)
17+
@non_differentiable ITensors.indpairs(::Any...)
18+
@non_differentiable onehot(::Any...)
19+
@non_differentiable Base.convert(::Type{TagSet}, str::String)

0 commit comments

Comments
 (0)