Skip to content
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "ITensorNetworks"
uuid = "2919e153-833c-4bdc-8836-1ea460a35fc7"
authors = ["Matthew Fishman <[email protected]>, Joseph Tindall <[email protected]> and contributors"]
version = "0.13.2"
version = "0.13.3"

[deps]
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
Expand Down
1 change: 1 addition & 0 deletions src/ITensorNetworks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ include("caches/abstractbeliefpropagationcache.jl")
include("caches/beliefpropagationcache.jl")
include("formnetworks/abstractformnetwork.jl")
include("formnetworks/bilinearformnetwork.jl")
include("formnetworks/linearformnetwork.jl")
include("formnetworks/quadraticformnetwork.jl")
include("contraction_tree_to_graph.jl")
include("gauging.jl")
Expand Down
9 changes: 2 additions & 7 deletions src/abstractitensornetwork.jl
Original file line number Diff line number Diff line change
Expand Up @@ -751,7 +751,7 @@ function split_index(
end

function inner_network(x::AbstractITensorNetwork, y::AbstractITensorNetwork; kwargs...)
return BilinearFormNetwork(x, y; kwargs...)
return LinearFormNetwork(x, y; kwargs...)
end

function inner_network(
Expand All @@ -760,12 +760,7 @@ function inner_network(
return BilinearFormNetwork(A, x, y; kwargs...)
end

# TODO: We should make this use the QuadraticFormNetwork constructor here.
# Parts of the code (tests relying on norm_sqr being two layer and the gauging code
# which relies on specific message tensors) currently would break in that case so we need to resolve
function norm_sqr_network(ψ::AbstractITensorNetwork)
return disjoint_union("bra" => dag(prime(ψ; sites=[])), "ket" => ψ)
end
norm_sqr_network(ψ::AbstractITensorNetwork) = inner_network(ψ, ψ)

#
# Printing
Expand Down
6 changes: 3 additions & 3 deletions src/expect.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,10 @@ function expect(
(cache!)=nothing,
update_cache=isnothing(cache!),
cache_update_kwargs=default_cache_update_kwargs(alg),
cache_construction_kwargs=default_cache_construction_kwargs(alg, inner_network(ψ, ψ)),
cache_construction_kwargs=default_cache_construction_kwargs(alg, QuadraticFormNetwork(ψ)),
kwargs...,
)
ψIψ = inner_network(ψ, ψ)
ψIψ = QuadraticFormNetwork(ψ)
if isnothing(cache!)
cache! = Ref(cache(alg, ψIψ; cache_construction_kwargs...))
end
Expand All @@ -42,7 +42,7 @@ function expect(
end

function expect(alg::Algorithm"exact", ψ::AbstractITensorNetwork, ops; kwargs...)
ψIψ = inner_network(ψ, ψ)
ψIψ = QuadraticFormNetwork(ψ)
return map(op -> expect(ψIψ, op; alg, kwargs...), ops)
end

Expand Down
7 changes: 7 additions & 0 deletions src/formnetworks/abstractformnetwork.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,13 @@ function SimilarType.similar_type(f::AbstractFormNetwork)
return typeof(tensornetwork(f))
end

# TODO: Use `NamedGraphs.GraphsExtensions.parent_graph_type`.
function data_graph_type(G::Type{<:AbstractFormNetwork})
return data_graph_type(fieldtype(G, :tensornetwork))
end
# TODO: Use `NamedGraphs.GraphsExtensions.parent_graph`.
data_graph(f::AbstractFormNetwork) = data_graph(tensornetwork(f))

function operator_vertices(f::AbstractFormNetwork)
return filter(v -> last(v) == operator_vertex_suffix(f), vertices(f))
end
Expand Down
4 changes: 0 additions & 4 deletions src/formnetworks/bilinearformnetwork.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,6 @@ bra_vertex_suffix(blf::BilinearFormNetwork) = blf.bra_vertex_suffix
ket_vertex_suffix(blf::BilinearFormNetwork) = blf.ket_vertex_suffix
# TODO: Use `NamedGraphs.GraphsExtensions.parent_graph`.
tensornetwork(blf::BilinearFormNetwork) = blf.tensornetwork
# TODO: Use `NamedGraphs.GraphsExtensions.parent_graph_type`.
data_graph_type(::Type{<:BilinearFormNetwork}) = data_graph_type(tensornetwork(blf))
# TODO: Use `NamedGraphs.GraphsExtensions.parent_graph`.
data_graph(blf::BilinearFormNetwork) = data_graph(tensornetwork(blf))

function Base.copy(blf::BilinearFormNetwork)
return BilinearFormNetwork(
Expand Down
53 changes: 53 additions & 0 deletions src/formnetworks/linearformnetwork.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
using ITensors: ITensor, prime

default_dual_link_index_map = prime

struct LinearFormNetwork{
V,TensorNetwork<:AbstractITensorNetwork{V},BraVertexSuffix,KetVertexSuffix
} <: AbstractFormNetwork{V}
tensornetwork::TensorNetwork
bra_vertex_suffix::BraVertexSuffix
ket_vertex_suffix::KetVertexSuffix
end

function LinearFormNetwork(
bra::AbstractITensorNetwork,
ket::AbstractITensorNetwork;
bra_vertex_suffix=default_bra_vertex_suffix(),
ket_vertex_suffix=default_ket_vertex_suffix(),
dual_link_index_map=default_dual_link_index_map,
)
bra_mapped = dual_link_index_map(bra; sites=[])
tn = disjoint_union(bra_vertex_suffix => dag(bra_mapped), ket_vertex_suffix => ket)
return LinearFormNetwork(tn, bra_vertex_suffix, ket_vertex_suffix)
end

function LinearFormNetwork(blf::BilinearFormNetwork)
bra, ket, operator = subgraph(blf, bra_vertices(blf)),
subgraph(blf, ket_vertices(blf)),
subgraph(blf, operator_vertices(blf))
bra_suffix, ket_suffix = bra_vertex_suffix(blf), ket_vertex_suffix(blf)
operator = rename_vertices(v -> bra_vertex_map(blf)(v), operator)
tn = union(bra, ket, operator)
return LinearFormNetwork(tn, bra_suffix, ket_suffix)
end

bra_vertex_suffix(lf::LinearFormNetwork) = lf.bra_vertex_suffix
ket_vertex_suffix(lf::LinearFormNetwork) = lf.ket_vertex_suffix
# TODO: Use `NamedGraphs.GraphsExtensions.parent_graph`.
tensornetwork(lf::LinearFormNetwork) = lf.tensornetwork

function Base.copy(lf::LinearFormNetwork)
return LinearFormNetwork(
copy(tensornetwork(lf)), bra_vertex_suffix(lf), ket_vertex_suffix(lf)
)
end

function update(lf::LinearFormNetwork, original_ket_state_vertex, ket_state::ITensor)
lf = copy(lf)
# TODO: Maybe add a check that it really does preserve the graph.
setindex_preserve_graph!(
tensornetwork(lf), ket_state, ket_vertex(blf, original_ket_state_vertex)
)
return lf
end
8 changes: 8 additions & 0 deletions test/test_forms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ using NamedGraphs.NamedGraphGenerators: named_grid
using ITensorNetworks:
BeliefPropagationCache,
BilinearFormNetwork,
LinearFormNetwork,
QuadraticFormNetwork,
bra_network,
bra_vertex,
Expand Down Expand Up @@ -35,6 +36,10 @@ using Test: @test, @testset
ψbra = random_tensornetwork(rng, s; link_space=χ)
A = random_tensornetwork(rng, s_operator; link_space=D)

lf = LinearFormNetwork(ψbra, ψket)
@test nv(lf) == nv(ψket) + nv(ψbra)
@test isempty(flatten_siteinds(lf))

blf = BilinearFormNetwork(A, ψbra, ψket)
@test nv(blf) == nv(ψket) + nv(ψbra) + nv(A)
@test isempty(flatten_siteinds(blf))
Expand All @@ -43,6 +48,9 @@ using Test: @test, @testset
@test underlying_graph(operator_network(blf)) == underlying_graph(A)
@test underlying_graph(bra_network(blf)) == underlying_graph(ψbra)

lf = LinearFormNetwork(blf)
@test underlying_graph(ket_network(lf)) == underlying_graph(ψket)

qf = QuadraticFormNetwork(ψket)
@test nv(qf) == 3 * nv(ψket)
@test isempty(flatten_siteinds(qf))
Expand Down
5 changes: 3 additions & 2 deletions test/test_itensornetwork.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ using ITensors:
itensor,
onehot,
order,
prime,
random_itensor,
scalartype,
sim,
Expand All @@ -55,7 +56,7 @@ using ITensorNetworks:
ttn
using LinearAlgebra: factorize
using NamedGraphs: NamedEdge
using NamedGraphs.GraphsExtensions: incident_edges
using NamedGraphs.GraphsExtensions: disjoint_union, incident_edges
using NamedGraphs.NamedGraphGenerators: named_comb_tree, named_grid
using NDTensors: NDTensors, dim
using Random: randn!
Expand Down Expand Up @@ -140,7 +141,7 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
g = named_grid(dims)
s = siteinds("S=1/2", g)
ψ = ITensorNetwork(v -> "↑", s)
tn = norm_sqr_network(ψ)
tn = disjoint_union("bra" => ψ, "ket" => prime(dag(ψ); sites=[]))
tn_2 = contract(tn, ((1, 2), "ket") => ((1, 2), "bra"))
@test !has_vertex(tn_2, ((1, 2), "ket"))
@test tn_2[((1, 2), "bra")] ≈ tn[((1, 2), "ket")] * tn[((1, 2), "bra")]
Expand Down
Loading