Skip to content

Commit 89d383d

Browse files
committed
Getting lazy to work
2 parents 37d9a76 + adb429b commit 89d383d

File tree

7 files changed

+479
-119
lines changed

7 files changed

+479
-119
lines changed

Project.toml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
name = "ITensorNetworksNext"
22
uuid = "302f2e75-49f0-4526-aef7-d8ba550cb06c"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.1.7"
4+
version = "0.1.9"
55

66
[deps]
7+
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
78
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
89
BackendSelection = "680c2d7c-f67a-4cc9-ae9c-da132b1447a5"
910
DataGraphs = "b5a273c3-7e6c-41f6-98bd-8d7f1525a36a"
@@ -23,9 +24,11 @@ TensorOperations = "6aa20fa7-93e2-5fca-9bc0-fbd0db3c71a2"
2324
[extensions]
2425
ITensorNetworksNextTensorOperationsExt = "TensorOperations"
2526
TermInterface = "8ea1fca8-c5ef-4a55-8b96-4e9afe9c9a3c"
27+
TypeParameterAccessors = "7e5a90cf-f82e-492e-a09b-e3e26432c138"
2628
WrappedUnions = "325db55a-9c6c-5b90-b1a2-ec87e7a38c44"
2729

2830
[compat]
31+
AbstractTrees = "0.4.5"
2932
Adapt = "4.3"
3033
BackendSelection = "0.1.6"
3134
DataGraphs = "0.2.7"
@@ -39,6 +42,7 @@ NamedGraphs = "0.6.9, 0.7"
3942
SimpleTraits = "0.9.5"
4043
SplitApplyCombine = "1.2.3"
4144
TermInterface = "2"
45+
TypeParameterAccessors = "0.4.4"
4246
WrappedUnions = "0.3"
4347
TensorOperations = "5.3.1"
4448
julia = "1.10"
Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,17 @@
11
module ITensorNetworksNextTensorOperationsExt
22

33
using BackendSelection: @Algorithm_str, Algorithm
4-
using ITensorBase: inds, dim
4+
using ITensorBase: inds
55
using ITensorNetworksNext: ITensorNetworksNext
6+
using ITensorNetworksNext.LazyNamedDimsArrays: nested_array_to_lazy_multiply
67
using TensorOperations: TensorOperations, optimaltree
78

89
function ITensorNetworksNext.contraction_sequence(::Algorithm"optimal", tn::Vector{<:AbstractArray})
910
network = collect.(inds.(tn))
1011
#Converting dims to Float64 to minimize overflow issues
11-
inds_to_dims = Dict(i => Float64(dim(i)) for i in unique(reduce(vcat, network)))
12+
inds_to_dims = Dict(i => Float64(length(i)) for i in unique(reduce(vcat, network)))
1213
seq, _ = optimaltree(network, inds_to_dims)
13-
return seq
14+
return nested_array_to_lazy_multiply(seq)
1415
end
1516

1617
end

src/contractnetwork.jl

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using BackendSelection: @Algorithm_str, Algorithm
2+
using ITensorNetworksNext.LazyNamedDimsArrays: LazyNamedDimsArray, nested_array_to_lazy_multiply, substitute_lazy, materialize
23

34
default_contract_alg = nothing
45

@@ -10,26 +11,20 @@ function set_default_kwargs(alg::Algorithm"exact")
1011
end
1112

1213
function contraction_sequence(::Algorithm"leftassociative", tn::Vector{<:AbstractArray})
13-
return Any[i for i in 1:length(tn)]
14+
return nested_array_to_lazy_multiply(collect.(1:length(tn)))
1415
end
1516

1617
function contraction_sequence(tn::Vector{<:AbstractArray}; alg = default_sequence_alg)
1718
return contraction_sequence(Algorithm(alg), tn)
1819
end
1920

20-
# Internal recursive worker
21-
function recursive_contractnetwork(tn::Union{AbstractVector, AbstractNamedDimsArray})
22-
tn isa AbstractVector && return prod(recursive_contractnetwork, tn)
23-
return tn
24-
end
25-
26-
# Recursive worker for ordering the tensors according to the sequence
27-
rearrange(tn::Vector{<:AbstractArray}, i::Integer) = tn[i]
28-
rearrange(tn::Vector{<:AbstractArray}, v::AbstractVector) = [rearrange(tn, s) for s in v]
29-
3021
function contractnetwork(alg::Algorithm"exact", tn::Vector{<:AbstractArray})
3122
contract_sequence = isa(alg.sequence, String) ? contraction_sequence(tn; alg = alg.sequence) : sequence
32-
return recursive_contractnetwork(rearrange(tn, contract_sequence))
23+
@show contract_sequence
24+
@show materialize(contract_sequence)
25+
contract_sequence = substitute_lazy(contract_sequence, Dict(i => lazy(tn[i]) for i in 1:length(tn)))
26+
@show contract_sequence
27+
#return materialize(substitute_lazy(contract_sequence, Dict(i => tn[i] for i in 1:length(tn))))
3328
end
3429

3530
function contractnetwork(alg::Algorithm"exact", tn::AbstractTensorNetwork)

0 commit comments

Comments
 (0)