diff --git a/Project.toml b/Project.toml index 0676244..514e79c 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,31 @@ name = "ITensorNetworksNext" uuid = "302f2e75-49f0-4526-aef7-d8ba550cb06c" authors = ["ITensor developers and contributors"] -version = "0.1.0" +version = "0.1.1" + +[deps] +Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" +BackendSelection = "680c2d7c-f67a-4cc9-ae9c-da132b1447a5" +DataGraphs = "b5a273c3-7e6c-41f6-98bd-8d7f1525a36a" +Dictionaries = "85a47980-9c8c-11e8-2b9f-f7ca1fa99fb4" +Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" +NamedDimsArrays = "60cbd0c0-df58-4cb7-918c-6f5607b73fde" +NamedGraphs = "678767b0-92e7-4007-89e4-4527a8725b19" +SimpleTraits = "699a6c99-e7fa-54fc-8d76-47d257e15c1d" +SplitApplyCombine = "03a91e81-4c3e-53e1-a0a4-9c0c8f19dd66" [compat] +Adapt = "4.3.0" +BackendSelection = "0.1.6" +DataGraphs = "0.2.7" +Dictionaries = "0.4.5" +Graphs = "1.13.1" +LinearAlgebra = "1.10" +MacroTools = "0.5.16" +NamedDimsArrays = "0.7.13" +NamedGraphs = "0.6.9" +SimpleTraits = "0.9.5" +SplitApplyCombine = "1.2.3" julia = "1.10" diff --git a/src/ITensorNetworksNext.jl b/src/ITensorNetworksNext.jl index 0d76d64..89daa37 100644 --- a/src/ITensorNetworksNext.jl +++ b/src/ITensorNetworksNext.jl @@ -1,5 +1,6 @@ module ITensorNetworksNext -# Write your package code here. +include("abstracttensornetwork.jl") +include("tensornetwork.jl") end diff --git a/src/abstracttensornetwork.jl b/src/abstracttensornetwork.jl new file mode 100644 index 0000000..e666e93 --- /dev/null +++ b/src/abstracttensornetwork.jl @@ -0,0 +1,279 @@ +using Adapt: Adapt, adapt, adapt_structure +using BackendSelection: @Algorithm_str, Algorithm +using DataGraphs: + DataGraphs, + AbstractDataGraph, + edge_data, + underlying_graph, + underlying_graph_type, + vertex_data +using Dictionaries: Dictionary +using Graphs: + Graphs, + AbstractEdge, + AbstractGraph, + Graph, + add_edge!, + add_vertex!, + bfs_tree, + center, + dst, + edges, + edgetype, + ne, + neighbors, + nv, + rem_edge!, + src, + vertices +using LinearAlgebra: LinearAlgebra, factorize +using MacroTools: @capture +using NamedDimsArrays: dimnames +using NamedGraphs: NamedGraphs, NamedGraph, not_implemented, steiner_tree +using NamedGraphs.GraphsExtensions: + ⊔, directed_graph, incident_edges, rem_edges!, rename_vertices, vertextype +using SplitApplyCombine: flatten + +abstract type AbstractTensorNetwork{V,VD} <: AbstractDataGraph{V,VD,Nothing} end + +function Graphs.rem_edge!(tn::AbstractTensorNetwork, e) + rem_edge!(underlying_graph(tn), e) + return tn +end + +# TODO: Define a generic fallback for `AbstractDataGraph`? +DataGraphs.edge_data_eltype(::Type{<:AbstractTensorNetwork}) = error("No edge data") + +# Graphs.jl overloads +function Graphs.weights(graph::AbstractTensorNetwork) + V = vertextype(graph) + es = Tuple.(edges(graph)) + ws = Dictionary{Tuple{V,V},Float64}(es, undef) + for e in edges(graph) + w = log2(dim(commoninds(graph, e))) + ws[(src(e), dst(e))] = w + end + return ws +end + +# Copy +Base.copy(tn::AbstractTensorNetwork) = error("Not implemented") + +# Iteration +Base.iterate(tn::AbstractTensorNetwork, args...) = iterate(vertex_data(tn), args...) + +# TODO: This contrasts with the `DataGraphs.AbstractDataGraph` definition, +# where it is defined as the `vertextype`. Does that cause problems or should it be changed? +Base.eltype(tn::AbstractTensorNetwork) = eltype(vertex_data(tn)) + +# Overload if needed +Graphs.is_directed(::Type{<:AbstractTensorNetwork}) = false + +# Derived interface, may need to be overloaded +function DataGraphs.underlying_graph_type(G::Type{<:AbstractTensorNetwork}) + return underlying_graph_type(data_graph_type(G)) +end + +# AbstractDataGraphs overloads +function DataGraphs.vertex_data(graph::AbstractTensorNetwork, args...) + return error("Not implemented") +end +function DataGraphs.edge_data(graph::AbstractTensorNetwork, args...) + return error("Not implemented") +end + +DataGraphs.underlying_graph(tn::AbstractTensorNetwork) = error("Not implemented") +function NamedGraphs.vertex_positions(tn::AbstractTensorNetwork) + return NamedGraphs.vertex_positions(underlying_graph(tn)) +end +function NamedGraphs.ordered_vertices(tn::AbstractTensorNetwork) + return NamedGraphs.ordered_vertices(underlying_graph(tn)) +end + +function Adapt.adapt_structure(to, tn::AbstractTensorNetwork) + # TODO: Define and use: + # + # @preserve_graph map_vertex_data(adapt(to), tn) + # + # or just: + # + # @preserve_graph map(adapt(to), tn) + return map_vertex_data_preserve_graph(adapt(to), tn) +end + +function linkinds(tn::AbstractTensorNetwork, edge::Pair) + return linkinds(tn, edgetype(tn)(edge)) +end +function linkinds(tn::AbstractTensorNetwork, edge::AbstractEdge) + return nameddimsindices(tn[src(edge)]) ∩ nameddimsindices(tn[dst(edge)]) +end +function linkaxes(tn::AbstractTensorNetwork, edge::Pair) + return linkaxes(tn, edgetype(tn)(edge)) +end +function linkaxes(tn::AbstractTensorNetwork, edge::AbstractEdge) + return axes(tn[src(edge)]) ∩ axes(tn[dst(edge)]) +end +function linknames(tn::AbstractTensorNetwork, edge::Pair) + return linknames(tn, edgetype(tn)(edge)) +end +function linknames(tn::AbstractTensorNetwork, edge::AbstractEdge) + return dimnames(tn[src(edge)]) ∩ dimnames(tn[dst(edge)]) +end + +function siteinds(tn::AbstractTensorNetwork, v) + s = nameddimsindices(tn[v]) + for v′ in neighbors(tn, v) + s = setdiff(s, nameddimsindices(tn[v′])) + end + return s +end +function siteaxes(tn::AbstractTensorNetwork, edge::AbstractEdge) + s = axes(tn[src(edge)]) ∩ axes(tn[dst(edge)]) + for v′ in neighbors(tn, v) + s = setdiff(s, axes(tn[v′])) + end + return s +end +function sitenames(tn::AbstractTensorNetwork, edge::AbstractEdge) + s = dimnames(tn[src(edge)]) ∩ dimnames(tn[dst(edge)]) + for v′ in neighbors(tn, v) + s = setdiff(s, dimnames(tn[v′])) + end + return s +end + +function setindex_preserve_graph!(tn::AbstractTensorNetwork, value, vertex) + vertex_data(tn)[vertex] = value + return tn +end + +# TODO: Move to `BaseExtensions` module. +function is_setindex!_expr(expr::Expr) + return is_assignment_expr(expr) && is_getindex_expr(first(expr.args)) +end +is_setindex!_expr(x) = false +is_getindex_expr(expr::Expr) = (expr.head === :ref) +is_getindex_expr(x) = false +is_assignment_expr(expr::Expr) = (expr.head === :(=)) +is_assignment_expr(expr) = false + +# TODO: Define this in terms of a function mapping +# preserve_graph_function(::typeof(setindex!)) = setindex!_preserve_graph +# preserve_graph_function(::typeof(map_vertex_data)) = map_vertex_data_preserve_graph +# Also allow annotating codeblocks like `@views`. +macro preserve_graph(expr) + if !is_setindex!_expr(expr) + error( + "preserve_graph must be used with setindex! syntax (as @preserve_graph a[i,j,...] = value)", + ) + end + @capture(expr, array_[indices__] = value_) + return :(setindex_preserve_graph!($(esc(array)), $(esc(value)), $(esc.(indices)...))) +end + +# Update the graph of the TensorNetwork `tn` to include +# edges that should exist based on the tensor connectivity. +function add_missing_edges!(tn::AbstractTensorNetwork) + foreach(v -> add_missing_edges!(tn, v), vertices(tn)) + return tn +end + +# Update the graph of the TensorNetwork `tn` to include +# edges that should be incident to the vertex `v` +# based on the tensor connectivity. +function add_missing_edges!(tn::AbstractTensorNetwork, v) + for v′ in vertices(tn) + if v ≠ v′ + e = v => v′ + if !isempty(linkinds(tn, e)) + add_edge!(tn, e) + end + end + end + return tn +end + +# Fix the edges of the TensorNetwork `tn` to match +# the tensor connectivity. +function fix_edges!(tn::AbstractTensorNetwork) + foreach(v -> fix_edges!(tn, v), vertices(tn)) + return tn +end +# Fix the edges of the TensorNetwork `tn` to match +# the tensor connectivity at vertex `v`. +function fix_edges!(tn::AbstractTensorNetwork, v) + rem_incident_edges!(tn, v) + rem_edges!(tn, incident_edges(tn, v)) + add_missing_edges!(tn, v) + return tn +end + +# Customization point. +using NamedDimsArrays: AbstractNamedUnitRange, namedunitrange, nametype, randname +function trivial_unitrange(type::Type{<:AbstractUnitRange}) + return Base.oneto(one(eltype(type))) +end +function rand_trivial_namedunitrange( + ::Type{<:AbstractNamedUnitRange{<:Any,R,N}} +) where {R,N} + return namedunitrange(trivial_unitrange(R), randname(N)) +end + +dag(x) = x + +using NamedDimsArrays: nameddimsindices +function insert_trivial_link!(tn, e) + add_edge!(tn, e) + l = rand_trivial_namedunitrange(eltype(nameddimsindices(tn[src(e)]))) + x = similar(tn[src(e)], (l,)) + x[1] = 1 + @preserve_graph tn[src(e)] = tn[src(e)] * x + @preserve_graph tn[dst(e)] = tn[dst(e)] * dag(x) + return tn +end + +function Base.setindex!(tn::AbstractTensorNetwork, value, v) + @preserve_graph tn[v] = value + fix_edges!(tn, v) + return tn +end +using NamedGraphs.OrdinalIndexing: OrdinalSuffixedInteger +# Fix ambiguity error. +function Base.setindex!(graph::AbstractTensorNetwork, value, vertex::OrdinalSuffixedInteger) + graph[vertices(graph)[vertex]] = value + return graph +end +# Fix ambiguity error. +function Base.setindex!(tn::AbstractTensorNetwork, value, edge::AbstractEdge) + return error("No edge data.") +end +# Fix ambiguity error. +function Base.setindex!(tn::AbstractTensorNetwork, value, edge::Pair) + return error("No edge data.") +end +using NamedGraphs.OrdinalIndexing: OrdinalSuffixedInteger +# Fix ambiguity error. +function Base.setindex!( + tn::AbstractTensorNetwork, + value, + edge::Pair{<:OrdinalSuffixedInteger,<:OrdinalSuffixedInteger}, +) + return error("No edge data.") +end + +function Base.show(io::IO, mime::MIME"text/plain", graph::AbstractTensorNetwork) + println(io, "$(typeof(graph)) with $(nv(graph)) vertices:") + show(io, mime, vertices(graph)) + println(io, "\n") + println(io, "and $(ne(graph)) edge(s):") + for e in edges(graph) + show(io, mime, e) + println(io) + end + println(io) + println(io, "with vertex data:") + show(io, mime, axes.(vertex_data(graph))) + return nothing +end + +Base.show(io::IO, graph::AbstractTensorNetwork) = show(io, MIME"text/plain"(), graph) diff --git a/src/tensornetwork.jl b/src/tensornetwork.jl new file mode 100644 index 0000000..3fd794b --- /dev/null +++ b/src/tensornetwork.jl @@ -0,0 +1,75 @@ +using DataGraphs: DataGraphs, AbstractDataGraph, DataGraph +using Dictionaries: AbstractDictionary, Indices, dictionary +using Graphs: AbstractSimpleGraph +using NamedDimsArrays: AbstractNamedDimsArray, dimnames, nameddimsarray +using NamedGraphs: NamedGraphs, NamedEdge, NamedGraph, vertextype +using NamedGraphs.GraphsExtensions: arranged_edges, vertextype + +function _TensorNetwork end + +struct TensorNetwork{V,VD,UG<:AbstractGraph{V},Tensors<:AbstractDictionary{V,VD}} <: + AbstractTensorNetwork{V,VD} + underlying_graph::UG + tensors::Tensors + global @inline function _TensorNetwork( + underlying_graph::UG, tensors::Tensors + ) where {V,VD,UG<:AbstractGraph{V},Tensors<:AbstractDictionary{V,VD}} + # This assumes the tensor connectivity matches the graph structure. + return new{V,VD,UG,Tensors}(underlying_graph, tensors) + end +end + +DataGraphs.underlying_graph(tn::TensorNetwork) = getfield(tn, :underlying_graph) +DataGraphs.vertex_data(tn::TensorNetwork) = getfield(tn, :tensors) +function DataGraphs.underlying_graph_type(type::Type{<:TensorNetwork}) + return fieldtype(type, :underlying_graph) +end + +# Determine the graph structure from the tensors. +function TensorNetwork(t::AbstractDictionary) + g = NamedGraph(eachindex(t)) + for v1 in vertices(g) + for v2 in vertices(g) + if v1 ≠ v2 + if !isdisjoint(dimnames(t[v1]), dimnames(t[v2])) + add_edge!(g, v1 => v2) + end + end + end + end + return _TensorNetwork(g, t) +end +function TensorNetwork(tensors::AbstractDict) + return TensorNetwork(Dictionary(tensors)) +end + +function TensorNetwork(graph::AbstractGraph, tensors::AbstractDictionary) + tn = TensorNetwork(tensors) + arranged_edges(tn) ⊆ arranged_edges(graph) || + error("The edges in the tensors do not match the graph structure.") + for e in setdiff(arranged_edges(graph), arranged_edges(tn)) + insert_trivial_link!(tn, e) + end + return tn +end +function TensorNetwork(graph::AbstractGraph, tensors::AbstractDict) + return TensorNetwork(graph, Dictionary(tensors)) +end +function TensorNetwork(f, graph::AbstractGraph) + return TensorNetwork(graph, Dict(v => f(v) for v in vertices(graph))) +end + +function Base.copy(tn::TensorNetwork) + TensorNetwork(copy(underlying_graph(tn)), copy(vertex_data(tn))) +end +TensorNetwork(tn::TensorNetwork) = copy(tn) +TensorNetwork{V}(tn::TensorNetwork{V}) where {V} = copy(tn) +function TensorNetwork{V}(tn::TensorNetwork) where {V} + g′ = convert_vertextype(V, underlying_graph(tn)) + d = vertex_data(tn) + d′ = dictionary(V(k) => d[k] for k in eachindex(d)) + return TensorNetwork(g′, d′) +end + +NamedGraphs.convert_vertextype(::Type{V}, tn::TensorNetwork{V}) where {V} = tn +NamedGraphs.convert_vertextype(V::Type, tn::TensorNetwork) = TensorNetwork{V}(tn) diff --git a/test/Project.toml b/test/Project.toml index a7a25ff..94f32e3 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,13 +1,23 @@ [deps] Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" +Dictionaries = "85a47980-9c8c-11e8-2b9f-f7ca1fa99fb4" +Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6" +ITensorBase = "4795dd04-0d67-49bb-8f44-b89c448a1dc7" ITensorNetworksNext = "302f2e75-49f0-4526-aef7-d8ba550cb06c" +NamedDimsArrays = "60cbd0c0-df58-4cb7-918c-6f5607b73fde" +NamedGraphs = "678767b0-92e7-4007-89e4-4527a8725b19" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [compat] -Aqua = "0.8" -ITensorNetworksNext = "0.1" +Aqua = "0.8.14" +Dictionaries = "0.4.5" +Graphs = "1.13.1" +ITensorBase = "0.2.12" +ITensorNetworksNext = "0.1.1" +NamedDimsArrays = "0.7.14" +NamedGraphs = "0.6.8" SafeTestsets = "0.1" -Suppressor = "0.2" +Suppressor = "0.2.8" Test = "1.10" diff --git a/test/test_basics.jl b/test/test_basics.jl index 8351a1a..b00cd4f 100644 --- a/test/test_basics.jl +++ b/test/test_basics.jl @@ -1,6 +1,66 @@ -using ITensorNetworksNext: ITensorNetworksNext +using Dictionaries: Indices +using Graphs: dst, edges, has_edge, ne, nv, src, vertices +# TODO: Move `arranged_edges` to `NamedGraphs.GraphsExtensions`. +using ITensorNetworksNext: TensorNetwork, arranged_edges, linkaxes, linkinds, siteinds +using ITensorBase: Index +using NamedDimsArrays: dimnames +using NamedGraphs.GraphsExtensions: incident_edges +using NamedGraphs.NamedGraphGenerators: named_grid using Test: @test, @testset @testset "ITensorNetworksNext" begin - # Tests go here. + @testset "Construct TensorNetwork product state" begin + dims = (3, 3) + g = named_grid(dims) + s = Dict(v => Index(2) for v in vertices(g)) + tn = TensorNetwork(g) do v + return randn(s[v]) + end + @test nv(tn) == 9 + @test ne(tn) == ne(g) + @test issetequal(vertices(tn), vertices(g)) + @test issetequal(arranged_edges(tn), arranged_edges(g)) + for v in vertices(tn) + @test siteinds(tn, v) == [s[v]] + end + for v1 in vertices(tn) + for v2 in vertices(tn) + v1 == v2 && continue + haslink = !isempty(linkinds(tn, v1 => v2)) + @test haslink == has_edge(tn, v1 => v2) + end + end + for e in edges(tn) + @test isone(length(linkaxes(tn, e))) + end + end + @testset "Construct TensorNetwork partition function" begin + dims = (3, 3) + g = named_grid(dims) + l = Dict(e => Index(2) for e in edges(g)) + tn = TensorNetwork(g) do v + is = map(incident_edges(g, v)) do e + # TODO: Use `dual` on reverse edges. + return haskey(l, e) ? l[e] : l[reverse(e)] + end + return randn(Tuple(is)) + end + @test nv(tn) == 9 + @test ne(tn) == ne(g) + @test issetequal(vertices(tn), vertices(g)) + @test issetequal(arranged_edges(tn), arranged_edges(g)) + for v in vertices(tn) + @test isempty(siteinds(tn, v)) + end + for v1 in vertices(tn) + for v2 in vertices(tn) + v1 == v2 && continue + haslink = !isempty(linkinds(tn, v1 => v2)) + @test haslink == has_edge(tn, v1 => v2) + end + end + for e in edges(tn) + @test isone(length(linkaxes(tn, e))) + end + end end