Skip to content

Commit f3be6c5

Browse files
committed
Updates
1 parent 3ddc3c2 commit f3be6c5

File tree

4 files changed

+14
-19
lines changed

4 files changed

+14
-19
lines changed

Project.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@ TypeParameterAccessors = "7e5a90cf-f82e-492e-a09b-e3e26432c138"
2121
WrappedUnions = "325db55a-9c6c-5b90-b1a2-ec87e7a38c44"
2222

2323
[weakdeps]
24-
ITensorBase = "4795dd04-0d67-49bb-8f44-b89c448a1dc7"
2524
TensorOperations = "6aa20fa7-93e2-5fca-9bc0-fbd0db3c71a2"
2625

2726
[extensions]

ext/ITensorNetworksNextTensorOperationsExt/ITensorNetworksNextTensorOperationsExt.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,15 @@ module ITensorNetworksNextTensorOperationsExt
22

33
using BackendSelection: @Algorithm_str, Algorithm
44
using NamedDimsArrays: inds
5-
using ITensorNetworksNext: ITensorNetworksNext
6-
using ITensorNetworksNext.LazyNamedDimsArrays: nested_array_to_lazy_multiply
5+
using ITensorNetworksNext: ITensorNetworksNext, contraction_sequence_to_expr
76
using TensorOperations: TensorOperations, optimaltree
87

98
function ITensorNetworksNext.contraction_sequence(::Algorithm"optimal", tn::Vector{<:AbstractArray})
109
network = collect.(inds.(tn))
1110
#Converting dims to Float64 to minimize overflow issues
1211
inds_to_dims = Dict(i => Float64(length(i)) for i in unique(reduce(vcat, network)))
1312
seq, _ = optimaltree(network, inds_to_dims)
14-
return nested_array_to_lazy_multiply(seq)
13+
return contraction_sequence_to_expr(seq)
1514
end
1615

1716
end

src/contractnetwork.jl

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,24 @@
11
using BackendSelection: @Algorithm_str, Algorithm
2-
using ITensorNetworksNext.LazyNamedDimsArrays: nested_array_to_lazy_multiply, substitute_lazy, materialize, lazy,
2+
using ITensorNetworksNext.LazyNamedDimsArrays: substitute, materialize, lazy,
33
symnameddims
44

5-
default_contract_alg = nothing
6-
75
#Algorithmic defaults
86
default_sequence(::Algorithm"exact") = "leftassociative"
97
function set_default_kwargs(alg::Algorithm"exact")
108
sequence = get(alg, :sequence, default_sequence(alg))
119
return Algorithm("exact"; sequence)
1210
end
1311

12+
function contraction_sequence_to_expr(seq)
13+
if seq isa AbstractVector
14+
return prod(contraction_sequence_to_expr, seq)
15+
else
16+
return symnameddims(seq)
17+
end
18+
end
19+
1420
function contraction_sequence(::Algorithm"leftassociative", tn::Vector{<:AbstractArray})
15-
return nested_array_to_lazy_multiply(collect.(1:length(tn)))
21+
return contraction_sequence_to_expr(collect.(1:length(tn)))
1622
end
1723

1824
function contraction_sequence(tn::Vector{<:AbstractArray}; alg = default_sequence_alg)
@@ -21,15 +27,14 @@ end
2127

2228
function contractnetwork(alg::Algorithm"exact", tn::Vector{<:AbstractArray})
2329
contract_sequence = isa(alg.sequence, String) ? contraction_sequence(tn; alg = alg.sequence) : sequence
24-
contract_sequence = substitute_lazy(contract_sequence, Dict(symnameddims(i) => lazy(tn[i]) for i in 1:length(tn)))
30+
contract_sequence = substitute(contract_sequence, Dict(symnameddims(i) => lazy(tn[i]) for i in 1:length(tn)))
2531
return materialize(contract_sequence)
2632
end
2733

2834
function contractnetwork(alg::Algorithm"exact", tn::AbstractTensorNetwork)
2935
return contractnetwork(alg, [tn[v] for v in vertices(tn)])
3036
end
3137

32-
function contractnetwork(tn::Union{AbstractTensorNetwork, Vector{<:AbstractArray}}; alg = default_contract_alg, kwargs...)
33-
alg == nothing && error("Must specify an algorithm to contract the network with")
38+
function contractnetwork(tn; alg, kwargs...)
3439
return contractnetwork(set_default_kwargs(Algorithm(alg; kwargs...)), tn)
3540
end

src/lazynameddimsarrays.jl

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -417,12 +417,4 @@ Base.:*(a::SymbolicNamedDimsArray, b::SymbolicNamedDimsArray) = lazy(a) * lazy(b
417417
Base.:*(a::SymbolicNamedDimsArray, b::LazyNamedDimsArray) = lazy(a) * b
418418
Base.:*(a::LazyNamedDimsArray, b::SymbolicNamedDimsArray) = a * lazy(b)
419419

420-
function contraction_sequence_to_expr(seq)
421-
if seq isa AbstractVector
422-
return prod(contraction_sequence_to_expr, seq)
423-
else
424-
return symnameddims(seq)
425-
end
426-
end
427-
428420
end

0 commit comments

Comments
 (0)