Skip to content

Commit 4e7d189

Browse files
committed
Merge remote-tracking branch 'upstream/main'
2 parents b648353 + fa51083 commit 4e7d189

35 files changed

+110
-357
lines changed

Project.toml

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ITensorNetworks"
22
uuid = "2919e153-833c-4bdc-8836-1ea460a35fc7"
33
authors = ["Matthew Fishman <[email protected]>, Joseph Tindall <[email protected]> and contributors"]
4-
version = "0.12.2"
4+
version = "0.13.2"
55

66
[deps]
77
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
@@ -13,7 +13,6 @@ Dictionaries = "85a47980-9c8c-11e8-2b9f-f7ca1fa99fb4"
1313
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
1414
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
1515
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
16-
ITensorMPS = "0d1a4710-d33b-49a5-8f18-73bdf49b47e2"
1716
ITensors = "9136182c-28ba-11e9-034c-db9fb085ebd5"
1817
IsApprox = "28f27b66-4bd8-47e7-9110-e2746eb8bed7"
1918
IterTools = "c8e1da08-722c-5040-9ed9-7db0dc04731e"
@@ -26,7 +25,6 @@ Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
2625
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
2726
SerializedElementArrays = "d3ce8812-9567-47e9-a7b5-65a6d70a3065"
2827
SimpleTraits = "699a6c99-e7fa-54fc-8d76-47d257e15c1d"
29-
SparseArrayKit = "a9a3c162-d163-4c15-8926-b8794fbefed2"
3028
SplitApplyCombine = "03a91e81-4c3e-53e1-a0a4-9c0c8f19dd66"
3129
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
3230
StructWalk = "31cdf514-beb7-4750-89db-dda9d2eb8d3d"
@@ -63,24 +61,22 @@ DocStringExtensions = "0.9"
6361
EinExprs = "0.6.4"
6462
Graphs = "1.8"
6563
GraphsFlows = "0.1.1"
66-
ITensorMPS = "0.3"
67-
ITensors = "0.7, 0.8"
64+
ITensors = "0.7, 0.8, 0.9"
6865
IsApprox = "0.1, 1, 2"
6966
IterTools = "1.4.0"
70-
KrylovKit = "0.6, 0.7, 0.8"
67+
KrylovKit = "0.6, 0.7, 0.8, 0.9"
7168
MacroTools = "0.5"
7269
NDTensors = "0.3, 0.4"
7370
NamedGraphs = "0.6.0"
7471
OMEinsumContractionOrders = "0.8.3, 0.9"
7572
Observers = "0.2.4"
7673
SerializedElementArrays = "0.1"
7774
SimpleTraits = "0.9"
78-
SparseArrayKit = "0.3, 0.4"
7975
SplitApplyCombine = "1.2"
8076
StaticArrays = "1.5.12"
8177
StructWalk = "0.2"
8278
Suppressor = "0.2"
83-
TensorOperations = "5.1.4"
79+
TensorOperations = "5.2.0"
8480
TimerOutputs = "0.5.22"
8581
TupleTools = "1.4"
8682
julia = "1.10"

docs/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,5 +5,5 @@ Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306"
55

66
[compat]
77
Documenter = "1.10.0"
8-
ITensorNetworks = "0.12.0"
8+
ITensorNetworks = "0.13.0"
99
Literate = "2.20.1"

examples/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,4 @@
22
ITensorNetworks = "2919e153-833c-4bdc-8836-1ea460a35fc7"
33

44
[compat]
5-
ITensorNetworks = "0.12.0"
5+
ITensorNetworks = "0.13.2"

ext/ITensorNetworksTensorOperationsExt/ITensorNetworksTensorOperationsExt.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@ using TensorOperations: TensorOperations, optimaltree
77

88
function ITensorNetworks.contraction_sequence(::Algorithm"optimal", tn::ITensorList)
99
network = collect.(inds.(tn))
10-
inds_to_dims = Dict(i => dim(i) for i in unique(reduce(vcat, network)))
10+
#Converting dims to Float64 to minimize overflow issues
11+
inds_to_dims = Dict(i => Float64(dim(i)) for i in unique(reduce(vcat, network)))
1112
seq, _ = optimaltree(network, inds_to_dims)
1213
return seq
1314
end

src/ITensorNetworks.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@ include("contract_approx/partition.jl")
2323
include("contract_approx/binary_tree_partition.jl")
2424
include("contract.jl")
2525
include("specialitensornetworks.jl")
26-
include("boundarymps.jl")
2726
include("partitioneditensornetwork.jl")
2827
include("edge_sequences.jl")
2928
include("caches/abstractbeliefpropagationcache.jl")

src/abstractitensornetwork.jl

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,6 @@ using ITensors:
3838
settags,
3939
sim,
4040
swaptags
41-
using ITensorMPS: ITensorMPS, add, linkdim, linkinds, siteinds
4241
using .ITensorsExtensions: ITensorsExtensions, indtype, promote_indtype
4342
using LinearAlgebra: LinearAlgebra, factorize
4443
using MacroTools: @capture
@@ -255,7 +254,7 @@ indsnetwork(tn::AbstractITensorNetwork) = IndsNetwork(tn)
255254

256255
# TODO: Output a `VertexDataGraph`? Unfortunately
257256
# `IndsNetwork` doesn't allow iterating over vertex data.
258-
function ITensorMPS.siteinds(tn::AbstractITensorNetwork)
257+
function siteinds(tn::AbstractITensorNetwork)
259258
is = IndsNetwork(underlying_graph(tn))
260259
for v in vertices(tn)
261260
is[v] = uniqueinds(tn, v)
@@ -268,7 +267,7 @@ function flatten_siteinds(tn::AbstractITensorNetwork)
268267
return identity.(flatten(map(v -> siteinds(tn, v), vertices(tn))))
269268
end
270269

271-
function ITensorMPS.linkinds(tn::AbstractITensorNetwork)
270+
function linkinds(tn::AbstractITensorNetwork)
272271
is = IndsNetwork(underlying_graph(tn))
273272
for e in edges(tn)
274273
is[e] = commoninds(tn, e)
@@ -302,7 +301,11 @@ function ITensors.uniqueinds(tn::AbstractITensorNetwork, edge::Pair)
302301
return uniqueinds(tn, edgetype(tn)(edge))
303302
end
304303

305-
function ITensors.siteinds(tn::AbstractITensorNetwork, vertex)
304+
function siteinds(tn::AbstractITensorNetwork, vertex)
305+
return uniqueinds(tn, vertex)
306+
end
307+
# Fix ambiguity error with IndsNetwork constructor.
308+
function siteinds(tn::AbstractITensorNetwork, vertex::Int)
306309
return uniqueinds(tn, vertex)
307310
end
308311

@@ -311,7 +314,7 @@ function ITensors.commoninds(tn::AbstractITensorNetwork, edge)
311314
return commoninds(tn[src(e)], tn[dst(e)])
312315
end
313316

314-
function ITensorMPS.linkinds(tn::AbstractITensorNetwork, edge)
317+
function linkinds(tn::AbstractITensorNetwork, edge)
315318
return commoninds(tn, edge)
316319
end
317320

@@ -807,24 +810,24 @@ end
807810
# Link dimensions
808811
#
809812

810-
function ITensorMPS.maxlinkdim(tn::AbstractITensorNetwork)
813+
function maxlinkdim(tn::AbstractITensorNetwork)
811814
md = 1
812815
for e in edges(tn)
813816
md = max(md, linkdim(tn, e))
814817
end
815818
return md
816819
end
817820

818-
function ITensorMPS.linkdim(tn::AbstractITensorNetwork, edge::Pair)
821+
function linkdim(tn::AbstractITensorNetwork, edge::Pair)
819822
return linkdim(tn, edgetype(tn)(edge))
820823
end
821824

822-
function ITensorMPS.linkdim(tn::AbstractITensorNetwork{V}, edge::AbstractEdge{V}) where {V}
825+
function linkdim(tn::AbstractITensorNetwork{V}, edge::AbstractEdge{V}) where {V}
823826
ls = linkinds(tn, edge)
824827
return prod([isnothing(l) ? 1 : dim(l) for l in ls])
825828
end
826829

827-
function ITensorMPS.linkdims(tn::AbstractITensorNetwork{V}) where {V}
830+
function linkdims(tn::AbstractITensorNetwork{V}) where {V}
828831
ld = DataGraph{V}(
829832
copy(underlying_graph(tn)); vertex_data_eltype=Nothing, edge_data_eltype=Int
830833
)
@@ -882,7 +885,7 @@ is_multi_edge(tn::AbstractITensorNetwork, e) = length(linkinds(tn, e)) > 1
882885
is_multi_edge(tn::AbstractITensorNetwork) = Base.Fix1(is_multi_edge, tn)
883886

884887
"""Add two itensornetworks together by growing the bond dimension. The network structures need to be have the same vertex names, same site index on each vertex """
885-
function ITensorMPS.add(tn1::AbstractITensorNetwork, tn2::AbstractITensorNetwork)
888+
function add(tn1::AbstractITensorNetwork, tn2::AbstractITensorNetwork)
886889
@assert issetequal(vertices(tn1), vertices(tn2))
887890

888891
tn1 = combine_linkinds(tn1; edges=filter(is_multi_edge(tn1), edges(tn1)))

src/apply.jl

Lines changed: 22 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,6 @@ using ITensors:
2222
replaceinds,
2323
unioninds,
2424
uniqueinds
25-
using ITensors.ContractionSequenceOptimization: optimal_contraction_sequence
26-
using ITensorMPS: siteinds
2725
using KrylovKit: linsolve
2826
using LinearAlgebra: eigen, norm, svd
2927
using NamedGraphs: NamedEdge, has_edge
@@ -407,7 +405,8 @@ function fidelity(
407405
],
408406
envs,
409407
)
410-
term1 = ITensors.contract(term1_tns; sequence=optimal_contraction_sequence(term1_tns))
408+
sequence = contraction_sequence(term1_tns; alg="optimal")
409+
term1 = ITensors.contract(term1_tns; sequence)
411410

412411
term2_tns = vcat(
413412
[
@@ -418,9 +417,11 @@ function fidelity(
418417
],
419418
envs,
420419
)
421-
term2 = ITensors.contract(term2_tns; sequence=optimal_contraction_sequence(term2_tns))
420+
sequence = contraction_sequence(term2_tns; alg="optimal")
421+
term2 = ITensors.contract(term2_tns; sequence)
422422
term3_tns = vcat([p_prev, q_prev, prime(dag(p_cur)), prime(dag(q_cur)), gate], envs)
423-
term3 = ITensors.contract(term3_tns; sequence=optimal_contraction_sequence(term3_tns))
423+
sequence = contraction_sequence(term3_tns; alg="optimal")
424+
term3 = ITensors.contract(term3_tns; sequence)
424425

425426
f = term3[] / sqrt(term1[] * term2[])
426427
return f * conj(f)
@@ -447,61 +448,32 @@ function optimise_p_q(
447448
qs_ind = setdiff(inds(q_cur), collect(Iterators.flatten(inds.(vcat(envs, p_cur)))))
448449
ps_ind = setdiff(inds(p_cur), collect(Iterators.flatten(inds.(vcat(envs, q_cur)))))
449450

450-
opt_b_seq = optimal_contraction_sequence(vcat(ITensor[p, q, o, dag(prime(q_cur))], envs))
451-
opt_b_tilde_seq = optimal_contraction_sequence(
452-
vcat(ITensor[p, q, o, dag(prime(p_cur))], envs)
453-
)
454-
opt_M_seq = optimal_contraction_sequence(
455-
vcat(ITensor[q_cur, replaceinds(prime(dag(q_cur)), prime(qs_ind), qs_ind), p_cur], envs)
456-
)
457-
opt_M_tilde_seq = optimal_contraction_sequence(
458-
vcat(ITensor[p_cur, replaceinds(prime(dag(p_cur)), prime(ps_ind), ps_ind), q_cur], envs)
459-
)
460-
461-
function b(
462-
p::ITensor,
463-
q::ITensor,
464-
o::ITensor,
465-
envs::Vector{ITensor},
466-
r::ITensor;
467-
opt_sequence=nothing,
468-
)
469-
return noprime(
470-
ITensors.contract(vcat(ITensor[p, q, o, dag(prime(r))], envs); sequence=opt_sequence)
471-
)
451+
function b(p::ITensor, q::ITensor, o::ITensor, envs::Vector{ITensor}, r::ITensor)
452+
ts = vcat(ITensor[p, q, o, dag(prime(r))], envs)
453+
sequence = contraction_sequence(ts; alg="optimal")
454+
return noprime(ITensors.contract(ts; sequence))
472455
end
473456

474-
function M_p(
475-
envs::Vector{ITensor},
476-
p_q_tensor::ITensor,
477-
s_ind,
478-
apply_tensor::ITensor;
479-
opt_sequence=nothing,
480-
)
481-
return noprime(
482-
ITensors.contract(
483-
vcat(
484-
ITensor[
485-
p_q_tensor,
486-
replaceinds(prime(dag(p_q_tensor)), prime(s_ind), s_ind),
487-
apply_tensor,
488-
],
489-
envs,
490-
);
491-
sequence=opt_sequence,
492-
),
457+
function M_p(envs::Vector{ITensor}, p_q_tensor::ITensor, s_ind, apply_tensor::ITensor)
458+
ts = vcat(
459+
ITensor[
460+
p_q_tensor, replaceinds(prime(dag(p_q_tensor)), prime(s_ind), s_ind), apply_tensor
461+
],
462+
envs,
493463
)
464+
sequence = contraction_sequence(ts; alg="optimal")
465+
return noprime(ITensors.contract(ts; sequence))
494466
end
495467
for i in 1:nfullupdatesweeps
496-
b_vec = b(p, q, o, envs, q_cur; opt_sequence=opt_b_seq)
497-
M_p_partial = partial(M_p, envs, q_cur, qs_ind; opt_sequence=opt_M_seq)
468+
b_vec = b(p, q, o, envs, q_cur)
469+
M_p_partial = partial(M_p, envs, q_cur, qs_ind)
498470

499471
p_cur, info = linsolve(
500472
M_p_partial, b_vec, p_cur; isposdef=envisposdef, ishermitian=false
501473
)
502474

503-
b_tilde_vec = b(p, q, o, envs, p_cur; opt_sequence=opt_b_tilde_seq)
504-
M_p_tilde_partial = partial(M_p, envs, p_cur, ps_ind; opt_sequence=opt_M_tilde_seq)
475+
b_tilde_vec = b(p, q, o, envs, p_cur)
476+
M_p_tilde_partial = partial(M_p, envs, p_cur, ps_ind)
505477

506478
q_cur, info = linsolve(
507479
M_p_tilde_partial, b_tilde_vec, q_cur; isposdef=envisposdef, ishermitian=false

src/boundarymps.jl

Lines changed: 0 additions & 16 deletions
This file was deleted.

src/caches/abstractbeliefpropagationcache.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ using Graphs: IsDirected
22
using SplitApplyCombine: group
33
using LinearAlgebra: diag, dot
44
using ITensors: dir
5-
using ITensorMPS: ITensorMPS
65
using NamedGraphs.PartitionedGraphs:
76
PartitionedGraphs,
87
PartitionedGraph,
@@ -17,7 +16,7 @@ using NDTensors: NDTensors
1716
abstract type AbstractBeliefPropagationCache end
1817

1918
function default_message_update(contract_list::Vector{ITensor}; normalize=true, kwargs...)
20-
sequence = optimal_contraction_sequence(contract_list)
19+
sequence = contraction_sequence(contract_list; alg="optimal")
2120
updated_messages = contract(contract_list; sequence, kwargs...)
2221
message_norm = norm(updated_messages)
2322
if normalize && !iszero(message_norm)
@@ -140,7 +139,7 @@ for f in [
140139
:(PartitionedGraphs.partitionvertices),
141140
:(PartitionedGraphs.vertices),
142141
:(PartitionedGraphs.boundary_partitionedges),
143-
:(ITensorMPS.linkinds),
142+
:(linkinds),
144143
]
145144
@eval begin
146145
function $f(bpc::AbstractBeliefPropagationCache, args...; kwargs...)

src/caches/beliefpropagationcache.jl

Lines changed: 8 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ using Graphs: IsDirected
22
using SplitApplyCombine: group
33
using LinearAlgebra: diag, dot
44
using ITensors: dir
5-
using ITensorMPS: ITensorMPS
65
using NamedGraphs.PartitionedGraphs:
76
PartitionedGraphs,
87
PartitionedGraph,
@@ -94,22 +93,16 @@ function environment(bpc::BeliefPropagationCache, verts::Vector; kwargs...)
9493
return vcat(messages, central_tensors)
9594
end
9695

97-
function region_scalar(
98-
bp_cache::BeliefPropagationCache,
99-
pv::PartitionVertex;
100-
contract_kwargs=(; sequence="automatic"),
101-
)
96+
function region_scalar(bp_cache::BeliefPropagationCache, pv::PartitionVertex)
10297
incoming_mts = incoming_messages(bp_cache, [pv])
10398
local_state = factors(bp_cache, pv)
104-
return contract(vcat(incoming_mts, local_state); contract_kwargs...)[]
99+
ts = vcat(incoming_mts, local_state)
100+
sequence = contraction_sequence(ts; alg="optimal")
101+
return contract(ts; sequence)[]
105102
end
106103

107-
function region_scalar(
108-
bp_cache::BeliefPropagationCache,
109-
pe::PartitionEdge;
110-
contract_kwargs=(; sequence="automatic"),
111-
)
112-
return contract(
113-
vcat(message(bp_cache, pe), message(bp_cache, reverse(pe))); contract_kwargs...
114-
)[]
104+
function region_scalar(bp_cache::BeliefPropagationCache, pe::PartitionEdge)
105+
ts = vcat(message(bp_cache, pe), message(bp_cache, reverse(pe)))
106+
sequence = contraction_sequence(ts; alg="optimal")
107+
return contract(ts; sequence)[]
115108
end

0 commit comments

Comments
 (0)