diff --git a/.JuliaFormatter.toml b/.JuliaFormatter.toml deleted file mode 100644 index 4c49a86..0000000 --- a/.JuliaFormatter.toml +++ /dev/null @@ -1,3 +0,0 @@ -# See https://domluna.github.io/JuliaFormatter.jl/stable/ for a list of options -style = "blue" -indent = 2 diff --git a/.github/workflows/FormatCheck.yml b/.github/workflows/FormatCheck.yml index 3f78afc..1525861 100644 --- a/.github/workflows/FormatCheck.yml +++ b/.github/workflows/FormatCheck.yml @@ -1,11 +1,14 @@ name: "Format Check" on: - push: - branches: - - 'main' - tags: '*' - pull_request: + pull_request_target: + paths: ['**/*.jl'] + types: [opened, synchronize, reopened, ready_for_review] + +permissions: + contents: read + actions: write + pull-requests: write jobs: format-check: diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 88bc8b4..3fc4743 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,5 +1,5 @@ ci: - skip: [julia-formatter] + skip: [runic] repos: - repo: https://github.com/pre-commit/pre-commit-hooks @@ -11,7 +11,7 @@ repos: - id: end-of-file-fixer exclude_types: [markdown] # incompatible with Literate.jl -- repo: "https://github.com/domluna/JuliaFormatter.jl" - rev: v2.1.6 +- repo: https://github.com/fredrikekre/runic-pre-commit + rev: v2.0.1 hooks: - - id: "julia-formatter" + - id: runic diff --git a/Project.toml b/Project.toml index ce4f361..189983f 100644 --- a/Project.toml +++ b/Project.toml @@ -1,31 +1,48 @@ name = "ITensorNetworksNext" uuid = "302f2e75-49f0-4526-aef7-d8ba550cb06c" authors = ["ITensor developers and contributors"] -version = "0.1.2" +version = "0.1.11" [deps] +AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" BackendSelection = "680c2d7c-f67a-4cc9-ae9c-da132b1447a5" DataGraphs = "b5a273c3-7e6c-41f6-98bd-8d7f1525a36a" Dictionaries = "85a47980-9c8c-11e8-2b9f-f7ca1fa99fb4" Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6" +ITensorBase = "4795dd04-0d67-49bb-8f44-b89c448a1dc7" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" NamedDimsArrays = "60cbd0c0-df58-4cb7-918c-6f5607b73fde" NamedGraphs = "678767b0-92e7-4007-89e4-4527a8725b19" SimpleTraits = "699a6c99-e7fa-54fc-8d76-47d257e15c1d" SplitApplyCombine = "03a91e81-4c3e-53e1-a0a4-9c0c8f19dd66" +TermInterface = "8ea1fca8-c5ef-4a55-8b96-4e9afe9c9a3c" +TypeParameterAccessors = "7e5a90cf-f82e-492e-a09b-e3e26432c138" +WrappedUnions = "325db55a-9c6c-5b90-b1a2-ec87e7a38c44" + +[weakdeps] +TensorOperations = "6aa20fa7-93e2-5fca-9bc0-fbd0db3c71a2" + +[extensions] +ITensorNetworksNextTensorOperationsExt = "TensorOperations" [compat] -Adapt = "4.3.0" +AbstractTrees = "0.4.5" +Adapt = "4.3" BackendSelection = "0.1.6" DataGraphs = "0.2.7" Dictionaries = "0.4.5" Graphs = "1.13.1" +ITensorBase = "0.2.14" LinearAlgebra = "1.10" MacroTools = "0.5.16" -NamedDimsArrays = "0.7.13" -NamedGraphs = "0.6.9" +NamedDimsArrays = "0.8" +NamedGraphs = "0.6.9, 0.7" SimpleTraits = "0.9.5" SplitApplyCombine = "1.2.3" +TensorOperations = "5.3.1" +TermInterface = "2" +TypeParameterAccessors = "0.4.4" +WrappedUnions = "0.3" julia = "1.10" diff --git a/docs/make.jl b/docs/make.jl index 5a50658..1b29518 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -2,23 +2,23 @@ using ITensorNetworksNext: ITensorNetworksNext using Documenter: Documenter, DocMeta, deploydocs, makedocs DocMeta.setdocmeta!( - ITensorNetworksNext, :DocTestSetup, :(using ITensorNetworksNext); recursive=true + ITensorNetworksNext, :DocTestSetup, :(using 
ITensorNetworksNext); recursive = true ) include("make_index.jl") makedocs(; - modules=[ITensorNetworksNext], - authors="ITensor developers and contributors", - sitename="ITensorNetworksNext.jl", - format=Documenter.HTML(; - canonical="https://itensor.github.io/ITensorNetworksNext.jl", - edit_link="main", - assets=["assets/favicon.ico", "assets/extras.css"], - ), - pages=["Home" => "index.md", "Reference" => "reference.md"], + modules = [ITensorNetworksNext], + authors = "ITensor developers and contributors", + sitename = "ITensorNetworksNext.jl", + format = Documenter.HTML(; + canonical = "https://itensor.github.io/ITensorNetworksNext.jl", + edit_link = "main", + assets = ["assets/favicon.ico", "assets/extras.css"], + ), + pages = ["Home" => "index.md", "Reference" => "reference.md"], ) deploydocs(; - repo="github.com/ITensor/ITensorNetworksNext.jl", devbranch="main", push_preview=true + repo = "github.com/ITensor/ITensorNetworksNext.jl", devbranch = "main", push_preview = true ) diff --git a/docs/make_index.jl b/docs/make_index.jl index 44fa493..038bc87 100644 --- a/docs/make_index.jl +++ b/docs/make_index.jl @@ -2,20 +2,20 @@ using Literate: Literate using ITensorNetworksNext: ITensorNetworksNext function ccq_logo(content) - include_ccq_logo = """ + include_ccq_logo = """ ```@raw html Flatiron Center for Computational Quantum Physics logo. Flatiron Center for Computational Quantum Physics logo. ``` """ - content = replace(content, "{CCQ_LOGO}" => include_ccq_logo) - return content + content = replace(content, "{CCQ_LOGO}" => include_ccq_logo) + return content end Literate.markdown( - joinpath(pkgdir(ITensorNetworksNext), "examples", "README.jl"), - joinpath(pkgdir(ITensorNetworksNext), "docs", "src"); - flavor=Literate.DocumenterFlavor(), - name="index", - postprocess=ccq_logo, + joinpath(pkgdir(ITensorNetworksNext), "examples", "README.jl"), + joinpath(pkgdir(ITensorNetworksNext), "docs", "src"); + flavor = Literate.DocumenterFlavor(), + name = "index", + postprocess = ccq_logo, ) diff --git a/docs/make_readme.jl b/docs/make_readme.jl index 960d376..088dc58 100644 --- a/docs/make_readme.jl +++ b/docs/make_readme.jl @@ -2,20 +2,20 @@ using Literate: Literate using ITensorNetworksNext: ITensorNetworksNext function ccq_logo(content) - include_ccq_logo = """ + include_ccq_logo = """ Flatiron Center for Computational Quantum Physics logo. 
""" - content = replace(content, "{CCQ_LOGO}" => include_ccq_logo) - return content + content = replace(content, "{CCQ_LOGO}" => include_ccq_logo) + return content end Literate.markdown( - joinpath(pkgdir(ITensorNetworksNext), "examples", "README.jl"), - joinpath(pkgdir(ITensorNetworksNext)); - flavor=Literate.CommonMarkFlavor(), - name="README", - postprocess=ccq_logo, + joinpath(pkgdir(ITensorNetworksNext), "examples", "README.jl"), + joinpath(pkgdir(ITensorNetworksNext)); + flavor = Literate.CommonMarkFlavor(), + name = "README", + postprocess = ccq_logo, ) diff --git a/examples/README.jl b/examples/README.jl index 4aaa79b..e3ee854 100644 --- a/examples/README.jl +++ b/examples/README.jl @@ -1,5 +1,5 @@ # # ITensorNetworksNext.jl -# +# # [![Stable](https://img.shields.io/badge/docs-stable-blue.svg)](https://itensor.github.io/ITensorNetworksNext.jl/stable/) # [![Dev](https://img.shields.io/badge/docs-dev-blue.svg)](https://itensor.github.io/ITensorNetworksNext.jl/dev/) # [![Build Status](https://github.com/ITensor/ITensorNetworksNext.jl/actions/workflows/Tests.yml/badge.svg?branch=main)](https://github.com/ITensor/ITensorNetworksNext.jl/actions/workflows/Tests.yml?query=branch%3Amain) diff --git a/ext/ITensorNetworksNextTensorOperationsExt/ITensorNetworksNextTensorOperationsExt.jl b/ext/ITensorNetworksNextTensorOperationsExt/ITensorNetworksNextTensorOperationsExt.jl new file mode 100644 index 0000000..f3b90bf --- /dev/null +++ b/ext/ITensorNetworksNextTensorOperationsExt/ITensorNetworksNextTensorOperationsExt.jl @@ -0,0 +1,16 @@ +module ITensorNetworksNextTensorOperationsExt + +using BackendSelection: @Algorithm_str, Algorithm +using NamedDimsArrays: inds +using ITensorNetworksNext: ITensorNetworksNext, contraction_sequence_to_expr +using TensorOperations: TensorOperations, optimaltree + +function ITensorNetworksNext.contraction_sequence(::Algorithm"optimal", tn::Vector{<:AbstractArray}) + network = collect.(inds.(tn)) + #Converting dims to Float64 to minimize overflow issues + inds_to_dims = Dict(i => Float64(length(i)) for i in unique(reduce(vcat, network))) + seq, _ = optimaltree(network, inds_to_dims) + return contraction_sequence_to_expr(seq) +end + +end diff --git a/src/ITensorNetworksNext.jl b/src/ITensorNetworksNext.jl index 89daa37..8cc4dd0 100644 --- a/src/ITensorNetworksNext.jl +++ b/src/ITensorNetworksNext.jl @@ -1,6 +1,14 @@ module ITensorNetworksNext +include("lazynameddimsarrays.jl") include("abstracttensornetwork.jl") include("tensornetwork.jl") +include("contract_network.jl") +include("abstract_problem.jl") +include("iterators.jl") + +include("beliefpropagation/abstractbeliefpropagationcache.jl") +include("beliefpropagation/beliefpropagationcache.jl") +include("beliefpropagation/beliefpropagationproblem.jl") end diff --git a/src/abstract_problem.jl b/src/abstract_problem.jl new file mode 100644 index 0000000..5a65e0a --- /dev/null +++ b/src/abstract_problem.jl @@ -0,0 +1 @@ +abstract type AbstractProblem end diff --git a/src/abstracttensornetwork.jl b/src/abstracttensornetwork.jl index e666e93..3cd2533 100644 --- a/src/abstracttensornetwork.jl +++ b/src/abstracttensornetwork.jl @@ -1,44 +1,44 @@ using Adapt: Adapt, adapt, adapt_structure using BackendSelection: @Algorithm_str, Algorithm using DataGraphs: - DataGraphs, - AbstractDataGraph, - edge_data, - underlying_graph, - underlying_graph_type, - vertex_data + DataGraphs, + AbstractDataGraph, + edge_data, + underlying_graph, + underlying_graph_type, + vertex_data using Dictionaries: Dictionary using Graphs: - Graphs, - 
AbstractEdge, - AbstractGraph, - Graph, - add_edge!, - add_vertex!, - bfs_tree, - center, - dst, - edges, - edgetype, - ne, - neighbors, - nv, - rem_edge!, - src, - vertices + Graphs, + AbstractEdge, + AbstractGraph, + Graph, + add_edge!, + add_vertex!, + bfs_tree, + center, + dst, + edges, + edgetype, + ne, + neighbors, + nv, + rem_edge!, + src, + vertices using LinearAlgebra: LinearAlgebra, factorize using MacroTools: @capture -using NamedDimsArrays: dimnames +using NamedDimsArrays: dimnames, inds using NamedGraphs: NamedGraphs, NamedGraph, not_implemented, steiner_tree using NamedGraphs.GraphsExtensions: - ⊔, directed_graph, incident_edges, rem_edges!, rename_vertices, vertextype + ⊔, directed_graph, incident_edges, rem_edges!, rename_vertices, vertextype using SplitApplyCombine: flatten -abstract type AbstractTensorNetwork{V,VD} <: AbstractDataGraph{V,VD,Nothing} end +abstract type AbstractTensorNetwork{V, VD} <: AbstractDataGraph{V, VD, Nothing} end function Graphs.rem_edge!(tn::AbstractTensorNetwork, e) - rem_edge!(underlying_graph(tn), e) - return tn + rem_edge!(underlying_graph(tn), e) + return tn end # TODO: Define a generic fallback for `AbstractDataGraph`? @@ -46,14 +46,14 @@ DataGraphs.edge_data_eltype(::Type{<:AbstractTensorNetwork}) = error("No edge da # Graphs.jl overloads function Graphs.weights(graph::AbstractTensorNetwork) - V = vertextype(graph) - es = Tuple.(edges(graph)) - ws = Dictionary{Tuple{V,V},Float64}(es, undef) - for e in edges(graph) - w = log2(dim(commoninds(graph, e))) - ws[(src(e), dst(e))] = w - end - return ws + V = vertextype(graph) + es = Tuple.(edges(graph)) + ws = Dictionary{Tuple{V, V}, Float64}(es, undef) + for e in edges(graph) + w = log2(dim(commoninds(graph, e))) + ws[(src(e), dst(e))] = w + end + return ws end # Copy @@ -71,85 +71,85 @@ Graphs.is_directed(::Type{<:AbstractTensorNetwork}) = false # Derived interface, may need to be overloaded function DataGraphs.underlying_graph_type(G::Type{<:AbstractTensorNetwork}) - return underlying_graph_type(data_graph_type(G)) + return underlying_graph_type(data_graph_type(G)) end # AbstractDataGraphs overloads function DataGraphs.vertex_data(graph::AbstractTensorNetwork, args...) - return error("Not implemented") + return error("Not implemented") end function DataGraphs.edge_data(graph::AbstractTensorNetwork, args...) 
-  return error("Not implemented")
+    return error("Not implemented")
 end
 DataGraphs.underlying_graph(tn::AbstractTensorNetwork) = error("Not implemented")

 function NamedGraphs.vertex_positions(tn::AbstractTensorNetwork)
-  return NamedGraphs.vertex_positions(underlying_graph(tn))
+    return NamedGraphs.vertex_positions(underlying_graph(tn))
 end
 function NamedGraphs.ordered_vertices(tn::AbstractTensorNetwork)
-  return NamedGraphs.ordered_vertices(underlying_graph(tn))
+    return NamedGraphs.ordered_vertices(underlying_graph(tn))
 end

 function Adapt.adapt_structure(to, tn::AbstractTensorNetwork)
-  # TODO: Define and use:
-  #
-  # @preserve_graph map_vertex_data(adapt(to), tn)
-  #
-  # or just:
-  #
-  # @preserve_graph map(adapt(to), tn)
-  return map_vertex_data_preserve_graph(adapt(to), tn)
+    # TODO: Define and use:
+    #
+    # @preserve_graph map_vertex_data(adapt(to), tn)
+    #
+    # or just:
+    #
+    # @preserve_graph map(adapt(to), tn)
+    return map_vertex_data_preserve_graph(adapt(to), tn)
 end

 function linkinds(tn::AbstractTensorNetwork, edge::Pair)
-  return linkinds(tn, edgetype(tn)(edge))
+    return linkinds(tn, edgetype(tn)(edge))
 end
 function linkinds(tn::AbstractTensorNetwork, edge::AbstractEdge)
-  return nameddimsindices(tn[src(edge)]) ∩ nameddimsindices(tn[dst(edge)])
+    return inds(tn[src(edge)]) ∩ inds(tn[dst(edge)])
 end
 function linkaxes(tn::AbstractTensorNetwork, edge::Pair)
-  return linkaxes(tn, edgetype(tn)(edge))
+    return linkaxes(tn, edgetype(tn)(edge))
 end
 function linkaxes(tn::AbstractTensorNetwork, edge::AbstractEdge)
-  return axes(tn[src(edge)]) ∩ axes(tn[dst(edge)])
+    return axes(tn[src(edge)]) ∩ axes(tn[dst(edge)])
 end
 function linknames(tn::AbstractTensorNetwork, edge::Pair)
-  return linknames(tn, edgetype(tn)(edge))
+    return linknames(tn, edgetype(tn)(edge))
 end
 function linknames(tn::AbstractTensorNetwork, edge::AbstractEdge)
-  return dimnames(tn[src(edge)]) ∩ dimnames(tn[dst(edge)])
+    return dimnames(tn[src(edge)]) ∩ dimnames(tn[dst(edge)])
 end

 function siteinds(tn::AbstractTensorNetwork, v)
-  s = nameddimsindices(tn[v])
-  for v′ in neighbors(tn, v)
-    s = setdiff(s, nameddimsindices(tn[v′]))
-  end
-  return s
+    s = inds(tn[v])
+    for v′ in neighbors(tn, v)
+        s = setdiff(s, inds(tn[v′]))
+    end
+    return s
 end
-function siteaxes(tn::AbstractTensorNetwork, edge::AbstractEdge)
-  s = axes(tn[src(edge)]) ∩ axes(tn[dst(edge)])
-  for v′ in neighbors(tn, v)
-    s = setdiff(s, axes(tn[v′]))
-  end
-  return s
+function siteaxes(tn::AbstractTensorNetwork, v)
+    s = axes(tn[v])
+    for v′ in neighbors(tn, v)
+        s = setdiff(s, axes(tn[v′]))
+    end
+    return s
 end
-function sitenames(tn::AbstractTensorNetwork, edge::AbstractEdge)
-  s = dimnames(tn[src(edge)]) ∩ dimnames(tn[dst(edge)])
-  for v′ in neighbors(tn, v)
-    s = setdiff(s, dimnames(tn[v′]))
-  end
-  return s
+function sitenames(tn::AbstractTensorNetwork, v)
+    s = dimnames(tn[v])
+    for v′ in neighbors(tn, v)
+        s = setdiff(s, dimnames(tn[v′]))
+    end
+    return s
 end

 function setindex_preserve_graph!(tn::AbstractTensorNetwork, value, vertex)
-  vertex_data(tn)[vertex] = value
-  return tn
+    vertex_data(tn)[vertex] = value
+    return tn
 end

 # TODO: Move to `BaseExtensions` module.
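+# These predicates detect `a[i...] = value` expressions so that `@preserve_graph` (defined
+# below) can rewrite them into `setindex_preserve_graph!` calls, e.g.
+# `@preserve_graph tn[v] = t` becomes `setindex_preserve_graph!(tn, t, v)`.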
 function is_setindex!_expr(expr::Expr)
-  return is_assignment_expr(expr) && is_getindex_expr(first(expr.args))
+    return is_assignment_expr(expr) && is_getindex_expr(first(expr.args))
 end
 is_setindex!_expr(x) = false
 is_getindex_expr(expr::Expr) = (expr.head === :ref)
@@ -162,118 +162,117 @@ is_assignment_expr(expr) = false
 # preserve_graph_function(::typeof(map_vertex_data)) = map_vertex_data_preserve_graph
 # Also allow annotating codeblocks like `@views`.
 macro preserve_graph(expr)
-  if !is_setindex!_expr(expr)
-    error(
-      "preserve_graph must be used with setindex! syntax (as @preserve_graph a[i,j,...] = value)",
-    )
-  end
-  @capture(expr, array_[indices__] = value_)
-  return :(setindex_preserve_graph!($(esc(array)), $(esc(value)), $(esc.(indices)...)))
+    if !is_setindex!_expr(expr)
+        error(
+            "preserve_graph must be used with setindex! syntax (as @preserve_graph a[i,j,...] = value)",
+        )
+    end
+    @capture(expr, array_[indices__] = value_)
+    return :(setindex_preserve_graph!($(esc(array)), $(esc(value)), $(esc.(indices)...)))
 end

 # Update the graph of the TensorNetwork `tn` to include
 # edges that should exist based on the tensor connectivity.
 function add_missing_edges!(tn::AbstractTensorNetwork)
-  foreach(v -> add_missing_edges!(tn, v), vertices(tn))
-  return tn
+    foreach(v -> add_missing_edges!(tn, v), vertices(tn))
+    return tn
 end

 # Update the graph of the TensorNetwork `tn` to include
 # edges that should be incident to the vertex `v`
 # based on the tensor connectivity.
 function add_missing_edges!(tn::AbstractTensorNetwork, v)
-  for v′ in vertices(tn)
-    if v ≠ v′
-      e = v => v′
-      if !isempty(linkinds(tn, e))
-        add_edge!(tn, e)
-      end
+    for v′ in vertices(tn)
+        if v ≠ v′
+            e = v => v′
+            if !isempty(linkinds(tn, e))
+                add_edge!(tn, e)
+            end
+        end
     end
-  end
-  return tn
+    return tn
 end

 # Fix the edges of the TensorNetwork `tn` to match
 # the tensor connectivity.
 function fix_edges!(tn::AbstractTensorNetwork)
-  foreach(v -> fix_edges!(tn, v), vertices(tn))
-  return tn
+    foreach(v -> fix_edges!(tn, v), vertices(tn))
+    return tn
 end

 # Fix the edges of the TensorNetwork `tn` to match
 # the tensor connectivity at vertex `v`.
 function fix_edges!(tn::AbstractTensorNetwork, v)
-  rem_incident_edges!(tn, v)
-  rem_edges!(tn, incident_edges(tn, v))
-  add_missing_edges!(tn, v)
-  return tn
+    rem_edges!(tn, incident_edges(tn, v))
+    add_missing_edges!(tn, v)
+    return tn
 end

 # Customization point.
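+# A trivial link is a dimension-1 bond with a freshly generated random name. Contracting
+# a length-1 vector `x` with `x[1] = 1` into one endpoint and `dag(x)` into the other
+# (see `insert_trivial_link!` below) adds the edge without changing the network's value.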
using NamedDimsArrays: AbstractNamedUnitRange, namedunitrange, nametype, randname function trivial_unitrange(type::Type{<:AbstractUnitRange}) - return Base.oneto(one(eltype(type))) + return Base.oneto(one(eltype(type))) end function rand_trivial_namedunitrange( - ::Type{<:AbstractNamedUnitRange{<:Any,R,N}} -) where {R,N} - return namedunitrange(trivial_unitrange(R), randname(N)) + ::Type{<:AbstractNamedUnitRange{<:Any, R, N}} + ) where {R, N} + return namedunitrange(trivial_unitrange(R), randname(N)) end dag(x) = x -using NamedDimsArrays: nameddimsindices function insert_trivial_link!(tn, e) - add_edge!(tn, e) - l = rand_trivial_namedunitrange(eltype(nameddimsindices(tn[src(e)]))) - x = similar(tn[src(e)], (l,)) - x[1] = 1 - @preserve_graph tn[src(e)] = tn[src(e)] * x - @preserve_graph tn[dst(e)] = tn[dst(e)] * dag(x) - return tn + add_edge!(tn, e) + l = rand_trivial_namedunitrange(eltype(inds(tn[src(e)]))) + x = similar(tn[src(e)], (l,)) + x[1] = 1 + @preserve_graph tn[src(e)] = tn[src(e)] * x + @preserve_graph tn[dst(e)] = tn[dst(e)] * dag(x) + return tn end function Base.setindex!(tn::AbstractTensorNetwork, value, v) - @preserve_graph tn[v] = value - fix_edges!(tn, v) - return tn + @preserve_graph tn[v] = value + fix_edges!(tn, v) + return tn end using NamedGraphs.OrdinalIndexing: OrdinalSuffixedInteger # Fix ambiguity error. function Base.setindex!(graph::AbstractTensorNetwork, value, vertex::OrdinalSuffixedInteger) - graph[vertices(graph)[vertex]] = value - return graph + graph[vertices(graph)[vertex]] = value + return graph end # Fix ambiguity error. function Base.setindex!(tn::AbstractTensorNetwork, value, edge::AbstractEdge) - return error("No edge data.") + return error("No edge data.") end # Fix ambiguity error. function Base.setindex!(tn::AbstractTensorNetwork, value, edge::Pair) - return error("No edge data.") + return error("No edge data.") end using NamedGraphs.OrdinalIndexing: OrdinalSuffixedInteger # Fix ambiguity error. 
function Base.setindex!( - tn::AbstractTensorNetwork, - value, - edge::Pair{<:OrdinalSuffixedInteger,<:OrdinalSuffixedInteger}, -) - return error("No edge data.") + tn::AbstractTensorNetwork, + value, + edge::Pair{<:OrdinalSuffixedInteger, <:OrdinalSuffixedInteger}, + ) + return error("No edge data.") end function Base.show(io::IO, mime::MIME"text/plain", graph::AbstractTensorNetwork) - println(io, "$(typeof(graph)) with $(nv(graph)) vertices:") - show(io, mime, vertices(graph)) - println(io, "\n") - println(io, "and $(ne(graph)) edge(s):") - for e in edges(graph) - show(io, mime, e) + println(io, "$(typeof(graph)) with $(nv(graph)) vertices:") + show(io, mime, vertices(graph)) + println(io, "\n") + println(io, "and $(ne(graph)) edge(s):") + for e in edges(graph) + show(io, mime, e) + println(io) + end println(io) - end - println(io) - println(io, "with vertex data:") - show(io, mime, axes.(vertex_data(graph))) - return nothing + println(io, "with vertex data:") + show(io, mime, axes.(vertex_data(graph))) + return nothing end -Base.show(io::IO, graph::AbstractTensorNetwork) = show(io, MIME"text/plain"(), graph) +Base.show(io::IO, graph::AbstractTensorNetwork) = show(io, MIME"text/plain"(), graph) \ No newline at end of file diff --git a/src/beliefpropagation/abstractbeliefpropagationcache.jl b/src/beliefpropagation/abstractbeliefpropagationcache.jl new file mode 100644 index 0000000..5eae283 --- /dev/null +++ b/src/beliefpropagation/abstractbeliefpropagationcache.jl @@ -0,0 +1,151 @@ +abstract type AbstractBeliefPropagationCache{V} <: AbstractGraph{V} end + +#Interface +factor(bp_cache::AbstractBeliefPropagationCache, vertex) = not_implemented() +setfactor!(bp_cache::AbstractBeliefPropagationCache, vertex, factor) = not_implemented() +messages(bp_cache::AbstractBeliefPropagationCache) = not_implemented() +message(bp_cache::AbstractBeliefPropagationCache, edge::AbstractEdge) = not_implemented() +function default_message(bp_cache::AbstractBeliefPropagationCache, edge::AbstractEdge) + return not_implemented() +end +default_messages(bp_cache::AbstractBeliefPropagationCache) = not_implemented() +function setmessage!(bp_cache::AbstractBeliefPropagationCache, edge::AbstractEdge, message) + return not_implemented() +end +function deletemessage!(bp_cache::AbstractBeliefPropagationCache, edge::AbstractEdge) + return not_implemented() +end +function rescale_messages( + bp_cache::AbstractBeliefPropagationCache, edges::Vector{<:AbstractEdge}; kwargs... + ) + return not_implemented() +end +function rescale_vertices( + bp_cache::AbstractBeliefPropagationCache, vertices::Vector; kwargs... + ) + return not_implemented() +end + +function vertex_scalar(bp_cache::AbstractBeliefPropagationCache, vertex; kwargs...) + return not_implemented() +end +function edge_scalar( + bp_cache::AbstractBeliefPropagationCache, edge::AbstractEdge; kwargs... + ) + return not_implemented() +end + +#Graph functionality needed +Graphs.vertices(bp_cache::AbstractBeliefPropagationCache) = not_implemented() +Graphs.edges(bp_cache::AbstractBeliefPropagationCache) = not_implemented() +function NamedGraphs.GraphsExtensions.boundary_edges( + bp_cache::AbstractBeliefPropagationCache, vertices; kwargs... 
+    )
+    return not_implemented()
+end
+
+#Functions derived from the interface
+function setmessages!(bp_cache::AbstractBeliefPropagationCache, edges, messages)
+    for (e, m) in zip(edges, messages)
+        setmessage!(bp_cache, e, m)
+    end
+    return
+end
+
+function deletemessages!(
+        bp_cache::AbstractBeliefPropagationCache, edges::Vector{<:AbstractEdge} = edges(bp_cache)
+    )
+    for e in edges
+        deletemessage!(bp_cache, e)
+    end
+    return bp_cache
+end
+
+function vertex_scalars(
+        bp_cache::AbstractBeliefPropagationCache, vertices = Graphs.vertices(bp_cache); kwargs...
+    )
+    return map(v -> region_scalar(bp_cache, v; kwargs...), vertices)
+end
+
+function edge_scalars(
+        bp_cache::AbstractBeliefPropagationCache, edges = Graphs.edges(bp_cache); kwargs...
+    )
+    return map(e -> region_scalar(bp_cache, e; kwargs...), edges)
+end
+
+function scalar_factors_quotient(bp_cache::AbstractBeliefPropagationCache)
+    return vertex_scalars(bp_cache), edge_scalars(bp_cache)
+end
+
+function incoming_messages(
+        bp_cache::AbstractBeliefPropagationCache, vertices::Vector{<:Any}; ignore_edges = []
+    )
+    b_edges = NamedGraphs.GraphsExtensions.boundary_edges(bp_cache, vertices; dir = :in)
+    b_edges = !isempty(ignore_edges) ? setdiff(b_edges, ignore_edges) : b_edges
+    return messages(bp_cache, b_edges)
+end
+
+function incoming_messages(bp_cache::AbstractBeliefPropagationCache, vertex; kwargs...)
+    return incoming_messages(bp_cache, [vertex]; kwargs...)
+end
+
+#Adapt interface for changing device
+function map_messages(f, bp_cache::AbstractBeliefPropagationCache, es = edges(bp_cache))
+    bp_cache = copy(bp_cache)
+    for e in es
+        setmessage!(bp_cache, e, f(message(bp_cache, e)))
+    end
+    return bp_cache
+end
+function map_factors(f, bp_cache::AbstractBeliefPropagationCache, vs = vertices(bp_cache))
+    bp_cache = copy(bp_cache)
+    for v in vs
+        setfactor!(bp_cache, v, f(factor(bp_cache, v)))
+    end
+    return bp_cache
+end
+function adapt_messages(to, bp_cache::AbstractBeliefPropagationCache, args...)
+    return map_messages(adapt(to), bp_cache, args...)
+end
+function adapt_factors(to, bp_cache::AbstractBeliefPropagationCache, args...)
+    return map_factors(adapt(to), bp_cache, args...)
+end
+
+function freenergy(bp_cache::AbstractBeliefPropagationCache)
+    numerator_terms, denominator_terms = scalar_factors_quotient(bp_cache)
+    if any(t -> real(t) < 0, numerator_terms)
+        numerator_terms = complex.(numerator_terms)
+    end
+    if any(t -> real(t) < 0, denominator_terms)
+        denominator_terms = complex.(denominator_terms)
+    end
+
+    any(iszero, denominator_terms) && return -Inf
+    return sum(log.(numerator_terms)) - sum(log.(denominator_terms))
+end
+
+function partitionfunction(bp_cache::AbstractBeliefPropagationCache)
+    return exp(freenergy(bp_cache))
+end
+
+function rescale_messages(bp_cache::AbstractBeliefPropagationCache, edge::AbstractEdge)
+    return rescale_messages(bp_cache, [edge])
+end
+
+function rescale_messages(bp_cache::AbstractBeliefPropagationCache)
+    return rescale_messages(bp_cache, edges(bp_cache))
+end
+
+function rescale_vertices(bpc::AbstractBeliefPropagationCache; kwargs...)
+    return rescale_vertices(bpc, collect(vertices(bpc)); kwargs...)
+end
+
+function rescale_vertex(bpc::AbstractBeliefPropagationCache, vertex; kwargs...)
+    return rescale_vertices(bpc, [vertex]; kwargs...)
+end
+
+function rescale(bpc::AbstractBeliefPropagationCache, args...; kwargs...)
+    bpc = rescale_messages(bpc)
+    bpc = rescale_vertices(bpc, args...; kwargs...)
+ return bpc +end diff --git a/src/beliefpropagation/beliefpropagationcache.jl b/src/beliefpropagation/beliefpropagationcache.jl new file mode 100644 index 0000000..cdae651 --- /dev/null +++ b/src/beliefpropagation/beliefpropagationcache.jl @@ -0,0 +1,137 @@ +using Dictionaries: Dictionary, set!, delete! +using Graphs: AbstractGraph, is_tree, connected_components +using NamedGraphs.GraphsExtensions: default_root_vertex, forest_cover, post_order_dfs_edges +using ITensorBase: ITensor, dim + +struct BeliefPropagationCache{V, N <: AbstractDataGraph{V}} <: + AbstractBeliefPropagationCache{V} + network::N + messages::Dictionary +end + +messages(bp_cache::BeliefPropagationCache) = bp_cache.messages +network(bp_cache::BeliefPropagationCache) = bp_cache.network + +BeliefPropagationCache(network) = BeliefPropagationCache(network, Dictionary()) + +function Base.copy(bp_cache::BeliefPropagationCache) + return BeliefPropagationCache(copy(network(bp_cache)), copy(messages(bp_cache))) +end + +function deletemessage!(bp_cache::BeliefPropagationCache, e::AbstractEdge) + ms = messages(bp_cache) + delete!(ms, e) + return bp_cache +end + +function setmessage!(bp_cache::BeliefPropagationCache, e::AbstractEdge, message) + ms = messages(bp_cache) + set!(ms, e, message) + return bp_cache +end + +function message(bp_cache::BeliefPropagationCache, edge::AbstractEdge; kwargs...) + ms = messages(bp_cache) + return get(() -> default_message(bp_cache, edge; kwargs...), ms, edge) +end + +function messages(bp_cache::BeliefPropagationCache, edges::Vector{<:AbstractEdge}) + return [message(bp_cache, e) for e in edges] +end + +#Forward onto the network +for f in [ + :(Graphs.vertices), + :(Graphs.edges), + :(Graphs.is_tree), + :(NamedGraphs.GraphsExtensions.boundary_edges), + :(factors), + :(default_bp_maxiter), + :(ITensorNetworksNext.setfactor!), + :(ITensorNetworksNext.linkinds), + :(ITensorNetworksNext.underlying_graph), + ] + @eval begin + function $f(bp_cache::BeliefPropagationCache, args...; kwargs...) + return $f(network(bp_cache), args...; kwargs...) + end + end +end + +function factors(tn::AbstractTensorNetwork, vertex) + return [tn[vertex]] +end + +function region_scalar(bp_cache::BeliefPropagationCache, edge::AbstractEdge) + return (message(bp_cache, edge) * message(bp_cache, reverse(edge)))[] +end + +function region_scalar(bp_cache::BeliefPropagationCache, vertex) + incoming_ms = incoming_messages(bp_cache, vertex) + state = factors(bp_cache, vertex) + return (reduce(*, incoming_ms) * reduce(*, state))[] +end + +function default_message(bp_cache::BeliefPropagationCache, edge::AbstractEdge) + return default_message(network(bp_cache), edge::AbstractEdge) +end + +function default_message(tn::AbstractTensorNetwork, edge::AbstractEdge) + t = ITensor(ones(dim.(linkinds(tn, edge))...), linkinds(tn, edge)...) + #TODO: Get datatype working on tensornetworks so we can support GPU, etc... + return t +end + +#TODO: Update message etc should go here... +function updated_message( + alg::Algorithm"contract", bp_cache::BeliefPropagationCache, edge::AbstractEdge + ) + vertex = src(edge) + incoming_ms = incoming_messages( + bp_cache, vertex; ignore_edges = typeof(edge)[reverse(edge)] + ) + state = factors(bp_cache, vertex) + #contract_list = ITensor[incoming_ms; state] + #sequence = contraction_sequence(contract_list; alg=alg.kwargs.sequence_alg) + #updated_messages = contract(contract_list; sequence) + updated_message = + !isempty(incoming_ms) ? 
reduce(*, state) * reduce(*, incoming_ms) : reduce(*, state) + if alg.normalize + message_norm = LinearAlgebra.norm(updated_message) + if !iszero(message_norm) + updated_message /= message_norm + end + end + return updated_message +end + +function default_algorithm( + ::Type{<:Algorithm"contract"}; normalize = true, sequence_alg = "optimal" + ) + return Algorithm("contract"; normalize, sequence_alg) +end +function default_algorithm( + ::Type{<:Algorithm"adapt_update"}; adapt, alg = default_algorithm(Algorithm"contract") + ) + return Algorithm("adapt_update"; adapt, alg) +end + +function update_message!( + message_update_alg::Algorithm, bpc::BeliefPropagationCache, edge::AbstractEdge + ) + return setmessage!(bpc, edge, updated_message(message_update_alg, bpc, edge)) +end + +#Edge sequence stuff +function forest_cover_edge_sequence(g::AbstractGraph; root_vertex = default_root_vertex) + forests = forest_cover(g) + edges = edgetype(g)[] + for forest in forests + trees = [forest[vs] for vs in connected_components(forest)] + for tree in trees + tree_edges = post_order_dfs_edges(tree, root_vertex(tree)) + push!(edges, vcat(tree_edges, reverse(reverse.(tree_edges)))...) + end + end + return edges +end diff --git a/src/beliefpropagation/beliefpropagationproblem.jl b/src/beliefpropagation/beliefpropagationproblem.jl new file mode 100644 index 0000000..a497363 --- /dev/null +++ b/src/beliefpropagation/beliefpropagationproblem.jl @@ -0,0 +1,85 @@ +mutable struct BeliefPropagationProblem{V, Cache <: AbstractBeliefPropagationCache{V}} <: + AbstractProblem + const cache::Cache + diff::Union{Nothing, Float64} +end + +function default_algorithm( + ::Type{<:Algorithm"bp"}, + bpc::BeliefPropagationCache; + verbose = false, + tol = nothing, + edge_sequence = forest_cover_edge_sequence(underlying_graph(bpc)), + message_update_alg = default_algorithm(Algorithm"contract"), + maxiter = is_tree(bpc) ? 1 : nothing, + ) + return Algorithm("bp"; verbose, tol, edge_sequence, message_update_alg, maxiter) +end + +function compute!(iter::RegionIterator{<:BeliefPropagationProblem}) + prob = iter.problem + + edge_group, kwargs = current_region_plan(iter) + + new_message_tensors = map(edge_group) do edge + old_message = message(prob.cache, edge) + + new_message = updated_message(kwargs.message_update_alg, prob.cache, edge) + + if !isnothing(prob.diff) + # TODO: Define `message_diff` + prob.diff += message_diff(new_message, old_message) + end + + return new_message + end + + foreach(edge_group, new_message_tensors) do edge, new_message + setmessage!(prob.cache, edge, new_message) + end + + return iter +end + +function region_plan( + prob::BeliefPropagationProblem; root_vertex = default_root_vertex, sweep_kwargs... + ) + edges = forest_cover_edge_sequence(underlying_graph(prob.cache); root_vertex) + + plan = map(edges) do e + return [e] => (; sweep_kwargs...) + end + + return plan +end + +function update(bpc::AbstractBeliefPropagationCache; kwargs...) + return update(default_algorithm(Algorithm"bp", bpc; kwargs...), bpc) +end +function update(alg::Algorithm"bp", bpc) + compute_error = !isnothing(alg.tol) + + diff = compute_error ? 0.0 : nothing + + prob = BeliefPropagationProblem(bpc, diff) + + iter = SweepIterator(prob, alg.maxiter; compute_error, getfield(alg, :kwargs)...) 
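+    # Each pass through `iter` runs one full sweep over the edge sequence; when `tol` is
+    # set, `compute!` accumulates message differences into `prob.diff` for the check below.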
+    for _ in iter
+        if compute_error && prob.diff <= alg.tol
+            break
+        end
+    end
+
+    if alg.verbose && compute_error
+        if prob.diff <= alg.tol
+            println("BP converged to desired precision after $(iter.which_sweep) iterations.")
+        else
+            println(
+                "BP failed to converge to precision $(alg.tol), got $(prob.diff) after $(iter.which_sweep) iterations.",
+            )
+        end
+    end
+
+    return bpc
+end
diff --git a/src/contract_network.jl b/src/contract_network.jl
new file mode 100644
index 0000000..67d69e0
--- /dev/null
+++ b/src/contract_network.jl
@@ -0,0 +1,47 @@
+using BackendSelection: @Algorithm_str, Algorithm
+using ITensorNetworksNext.LazyNamedDimsArrays: substitute, materialize, lazy,
+    symnameddims
+
+#Algorithmic defaults
+default_sequence_alg(::Algorithm"exact") = "leftassociative"
+default_sequence(::Algorithm"exact") = nothing
+function set_default_kwargs(alg::Algorithm"exact")
+    sequence = get(alg, :sequence, nothing)
+    sequence_alg = get(alg, :sequence_alg, default_sequence_alg(alg))
+    return Algorithm("exact"; sequence, sequence_alg)
+end
+
+function contraction_sequence_to_expr(seq)
+    if seq isa AbstractVector
+        return prod(contraction_sequence_to_expr, seq)
+    else
+        return symnameddims(seq)
+    end
+end
+
+function contraction_sequence(::Algorithm"leftassociative", tn::Vector{<:AbstractArray})
+    return prod(symnameddims, 1:length(tn))
+end
+
+function contraction_sequence(tn::Vector{<:AbstractArray}; sequence_alg = default_sequence_alg(Algorithm("exact")))
+    return contraction_sequence(Algorithm(sequence_alg), tn)
+end
+
+function contract_network(alg::Algorithm"exact", tn::Vector{<:AbstractArray})
+    if !isnothing(alg.sequence)
+        sequence = alg.sequence
+    else
+        sequence = contraction_sequence(tn; sequence_alg = alg.sequence_alg)
+    end
+
+    sequence = substitute(sequence, Dict(symnameddims(i) => lazy(tn[i]) for i in 1:length(tn)))
+    return materialize(sequence)
+end
+
+function contract_network(alg::Algorithm"exact", tn::AbstractTensorNetwork)
+    return contract_network(alg, [tn[v] for v in vertices(tn)])
+end
+
+function contract_network(tn; alg, kwargs...)
+    return contract_network(set_default_kwargs(Algorithm(alg; kwargs...)), tn)
+end
diff --git a/src/iterators.jl b/src/iterators.jl
new file mode 100644
index 0000000..1fe4844
--- /dev/null
+++ b/src/iterators.jl
@@ -0,0 +1,173 @@
+"""
+    abstract type AbstractNetworkIterator
+
+A stateful iterator built around two operations: `increment!` and `compute!`. Each
+iteration calls `increment!` before executing `compute!`; however, the initial call to
+`iterate` skips `increment!`, since the iterator is assumed to be initialized such that
+the first increment is implicit. Termination of the iterator is controlled by the
+function `islaststep`.
+"""
+abstract type AbstractNetworkIterator end
+
+# We use greater than or equals here as we increment the state at the start of the iteration
+islaststep(iterator::AbstractNetworkIterator) = state(iterator) >= length(iterator)
+
+function Base.iterate(iterator::AbstractNetworkIterator, init = true)
+    # The first "increment!" is implicit, therefore we must skip the termination check
+    # on the first iteration, i.e. an `AbstractNetworkIterator` is not well defined when
+    # `length < 1`.
+    init || (islaststep(iterator) && return nothing)
+    # We separate increment! from step! and demand that any AbstractNetworkIterator *must*
+    # define a method for increment! This way we avoid cases where nested calls to
+    # different step! methods accidentally increment multiple times.
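+    # On the first call `init` is true, so we compute without incrementing; afterwards
+    # the state is advanced before each compute.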
+    init || increment!(iterator)
+    rv = compute!(iterator)
+    return rv, false
+end
+
+function increment! end
+compute!(iterator::AbstractNetworkIterator) = iterator
+
+step!(iterator::AbstractNetworkIterator) = step!(identity, iterator)
+function step!(f, iterator::AbstractNetworkIterator)
+    compute!(iterator)
+    f(iterator)
+    increment!(iterator)
+    return iterator
+end
+
+#
+# RegionIterator
+#
+"""
+    struct RegionIterator{Problem, RegionPlan} <: AbstractNetworkIterator
+"""
+mutable struct RegionIterator{Problem, RegionPlan} <: AbstractNetworkIterator
+    problem::Problem
+    region_plan::RegionPlan
+    which_region::Int
+    const which_sweep::Int
+    function RegionIterator(problem::P, region_plan::R, sweep::Int) where {P, R}
+        if length(region_plan) == 0
+            throw(ArgumentError("Cannot construct a region iterator with 0 elements."))
+        end
+        return new{P, R}(problem, region_plan, 1, sweep)
+    end
+end
+
+function RegionIterator(problem; sweep, sweep_kwargs...)
+    plan = region_plan(problem; sweep_kwargs...)
+    return RegionIterator(problem, plan, sweep)
+end
+
+function new_region_iterator(iterator::RegionIterator; sweep_kwargs...)
+    return RegionIterator(iterator.problem; sweep_kwargs...)
+end
+
+state(region_iter::RegionIterator) = region_iter.which_region
+Base.length(region_iter::RegionIterator) = length(region_iter.region_plan)
+
+problem(region_iter::RegionIterator) = region_iter.problem
+
+function current_region_plan(region_iter::RegionIterator)
+    return region_iter.region_plan[region_iter.which_region]
+end
+
+function current_region(region_iter::RegionIterator)
+    region, _ = current_region_plan(region_iter)
+    return region
+end
+
+function region_kwargs(region_iter::RegionIterator)
+    _, kwargs = current_region_plan(region_iter)
+    return kwargs
+end
+function region_kwargs(f::Function, iter::RegionIterator)
+    return get(region_kwargs(iter), Symbol(f, :_kwargs), (;))
+end
+
+function prev_region(region_iter::RegionIterator)
+    state(region_iter) <= 1 && return nothing
+    prev, _ = region_iter.region_plan[region_iter.which_region - 1]
+    return prev
+end
+
+function next_region(region_iter::RegionIterator)
+    islaststep(region_iter) && return nothing
+    next, _ = region_iter.region_plan[region_iter.which_region + 1]
+    return next
+end
+
+#
+# Functions associated with RegionIterator
+#
+function increment!(region_iter::RegionIterator)
+    region_iter.which_region += 1
+    return region_iter
+end
+
+function compute!(iter::RegionIterator)
+    _, local_state = extract!(iter; region_kwargs(extract!, iter)...)
+    _, local_state = update!(iter, local_state; region_kwargs(update!, iter)...)
+    insert!(iter, local_state; region_kwargs(insert!, iter)...)
+
+    return iter
+end
+
+region_plan(problem; sweep_kwargs...) = euler_sweep(state(problem); sweep_kwargs...)
+
+#
+# SweepIterator
+#
+
+mutable struct SweepIterator{Problem, Iter} <: AbstractNetworkIterator
+    region_iter::RegionIterator{Problem}
+    sweep_kwargs::Iterators.Stateful{Iter}
+    which_sweep::Int
+    function SweepIterator(problem::Prob, sweep_kwargs::Iter) where {Prob, Iter}
+        stateful_sweep_kwargs = Iterators.Stateful(sweep_kwargs)
+        first_state = Iterators.peel(stateful_sweep_kwargs)
+
+        if isnothing(first_state)
+            throw(ArgumentError("Cannot construct a sweep iterator with 0 elements."))
+        end
+
+        first_kwargs, _ = first_state
+        region_iter = RegionIterator(problem; sweep = 1, first_kwargs...)
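+        # `Iterators.peel` above consumed the first set of kwargs, so the remaining
+        # stateful iterator supplies the kwargs for later sweeps via `increment!`.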
+ + return new{Prob, Iter}(region_iter, stateful_sweep_kwargs, 1) + end +end + +islaststep(sweep_iter::SweepIterator) = isnothing(peek(sweep_iter.sweep_kwargs)) + +region_iterator(sweep_iter::SweepIterator) = sweep_iter.region_iter +problem(sweep_iter::SweepIterator) = problem(region_iterator(sweep_iter)) + +state(sweep_iter::SweepIterator) = sweep_iter.which_sweep +Base.length(sweep_iter::SweepIterator) = length(sweep_iter.sweep_kwargs) +function increment!(sweep_iter::SweepIterator) + sweep_iter.which_sweep += 1 + sweep_kwargs, _ = Iterators.peel(sweep_iter.sweep_kwargs) + update_region_iterator!(sweep_iter; sweep_kwargs...) + return sweep_iter +end + +function update_region_iterator!(iterator::SweepIterator; kwargs...) + sweep = state(iterator) + iterator.region_iter = new_region_iterator(iterator.region_iter; sweep, kwargs...) + return iterator +end + +function compute!(sweep_iter::SweepIterator) + for _ in sweep_iter.region_iter + # TODO: Is it sensible to execute the default region callback function? + end + return +end + +# More basic constructor where sweep_kwargs are constant throughout sweeps +function SweepIterator(problem, nsweeps::Int; sweep_kwargs...) + # Initialize this to an empty RegionIterator + sweep_kwargs_iter = Iterators.repeated(sweep_kwargs, nsweeps) + return SweepIterator(problem, sweep_kwargs_iter) +end diff --git a/src/lazynameddimsarrays.jl b/src/lazynameddimsarrays.jl new file mode 100644 index 0000000..23e0679 --- /dev/null +++ b/src/lazynameddimsarrays.jl @@ -0,0 +1,420 @@ +module LazyNamedDimsArrays + +using AbstractTrees: AbstractTrees +using WrappedUnions: @wrapped, unwrap +using NamedDimsArrays: + NamedDimsArrays, + AbstractNamedDimsArray, + AbstractNamedDimsArrayStyle, + NamedDimsArray, + dename, + dimnames, + inds +using TermInterface: TermInterface, arguments, iscall, maketerm, operation, sorted_arguments +using TypeParameterAccessors: unspecify_type_parameters + +lazy(x) = error("Not defined.") + +generic_map(f, v) = map(f, v) +generic_map(f, v::AbstractDict) = Dict(eachindex(v) .=> map(f, values(v))) +generic_map(f, v::AbstractSet) = Set([f(x) for x in v]) + +# Defined to avoid type piracy. +# TODO: Define a proper hash function +# in NamedDimsArrays.jl, maybe one that is +# independent of the order of dimensions. +function _hash(a::NamedDimsArray, h::UInt64) + h = hash(:NamedDimsArray, h) + h = hash(dename(a), h) + for i in inds(a) + h = hash(i, h) + end + return h +end +function _hash(x, h::UInt64) + return hash(x, h) +end + +# Custom version of `AbstractTrees.printnode` to +# avoid type piracy when overloading on `AbstractNamedDimsArray`. +printnode_nameddims(io::IO, x) = AbstractTrees.printnode(io, x) +function printnode_nameddims(io::IO, a::AbstractNamedDimsArray) + show(io, collect(dimnames(a))) + return nothing +end + +# Generic lazy functionality. +function maketerm_lazy(type::Type, head, args, metadata) + if head ≡ * + return type(maketerm(Mul, head, args, metadata)) + else + return error("Only mul supported right now.") + end +end +function getindex_lazy(a::AbstractArray, I...) + u = unwrap(a) + if !iscall(u) + return u[I...] 
+ else + return error("Indexing into expression not supported.") + end +end +function arguments_lazy(a) + u = unwrap(a) + if !iscall(u) + return error("No arguments.") + elseif ismul(u) + return arguments(u) + else + return error("Variant not supported.") + end +end +function children_lazy(a) + return arguments(a) +end +function head_lazy(a) + return operation(a) +end +function iscall_lazy(a) + return iscall(unwrap(a)) +end +function isexpr_lazy(a) + return iscall(a) +end +function operation_lazy(a) + u = unwrap(a) + if !iscall(u) + return error("No operation.") + elseif ismul(u) + return operation(u) + else + return error("Variant not supported.") + end +end +function sorted_arguments_lazy(a) + u = unwrap(a) + if !iscall(u) + return error("No arguments.") + elseif ismul(u) + return sorted_arguments(u) + else + return error("Variant not supported.") + end +end +function sorted_children_lazy(a) + return sorted_arguments(a) +end +ismul_lazy(a) = ismul(unwrap(a)) +function abstracttrees_children_lazy(a) + if !iscall(a) + return () + else + return arguments(a) + end +end +function nodevalue_lazy(a) + if !iscall(a) + return unwrap(a) + else + return operation(a) + end +end +using Base.Broadcast: materialize +function materialize_lazy(a) + u = unwrap(a) + if !iscall(u) + return u + elseif ismul(u) + return mapfoldl(materialize, operation(u), arguments(u)) + else + return error("Variant not supported.") + end +end +copy_lazy(a) = materialize(a) +function equals_lazy(a1, a2) + u1, u2 = unwrap.((a1, a2)) + if !iscall(u1) && !iscall(u2) + return u1 == u2 + elseif ismul(u1) && ismul(u2) + return arguments(u1) == arguments(u2) + else + return false + end +end +function hash_lazy(a, h::UInt64) + h = hash(Symbol(unspecify_type_parameters(typeof(a))), h) + # Use `_hash`, which defines a custom hash for NamedDimsArray. + return _hash(unwrap(a), h) +end +function map_arguments_lazy(f, a) + u = unwrap(a) + if !iscall(u) + return error("No arguments to map.") + elseif ismul(u) + return lazy(map_arguments(f, u)) + else + return error("Variant not supported.") + end +end +function substitute_lazy(a, substitutions::AbstractDict) + haskey(substitutions, a) && return substitutions[a] + !iscall(a) && return a + return map_arguments(arg -> substitute(arg, substitutions), a) +end +function substitute_lazy(a, substitutions) + return substitute(a, Dict(substitutions)) +end +function printnode_lazy(io, a) + # Use `printnode_nameddims` to avoid type piracy, + # since it overloads on `AbstractNamedDimsArray`. + return printnode_nameddims(io, unwrap(a)) +end +function show_lazy(io::IO, a) + if !iscall(a) + return show(io, unwrap(a)) + else + return AbstractTrees.printnode(io, a) + end +end +function show_lazy(io::IO, mime::MIME"text/plain", a) + summary(io, a) + println(io, ":") + if !iscall(a) + show(io, mime, unwrap(a)) + return nothing + else + show(io, a) + return nothing + end +end +add_lazy(a1, a2) = error("Not implemented.") +sub_lazy(a) = error("Not implemented.") +sub_lazy(a1, a2) = error("Not implemented.") +function mul_lazy(a) + u = unwrap(a) + if !iscall(u) + return lazy(Mul([a])) + elseif ismul(u) + return a + else + return error("Variant not supported.") + end +end +# Note that this is nested by default. +mul_lazy(a1, a2) = lazy(Mul([a1, a2])) +mul_lazy(a1::Number, a2) = error("Not implemented.") +mul_lazy(a1, a2::Number) = error("Not implemented.") +mul_lazy(a1::Number, a2::Number) = a1 * a2 +div_lazy(a1, a2::Number) = error("Not implemented.") + +# NamedDimsArrays.jl interface. 
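+# For a lazy product, the external indices are the symmetric difference of the arguments'
+# indices: indices shared by a pair of factors are contracted and drop out. For example,
+# factors with indices [i, j] and [j, k] produce a result with indices [i, k].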
+function inds_lazy(a) + u = unwrap(a) + if !iscall(u) + return inds(u) + elseif ismul(u) + return mapreduce(inds, symdiff, arguments(u)) + else + return error("Variant not supported.") + end +end +function dename_lazy(a) + u = unwrap(a) + if !iscall(u) + return dename(u) + else + return error("Variant not supported.") + end +end + +# Lazy broadcasting. +struct LazyNamedDimsArrayStyle <: AbstractNamedDimsArrayStyle{Any} end +function Broadcast.broadcasted(::LazyNamedDimsArrayStyle, f, as...) + return error("Arbitrary broadcasting not supported for LazyNamedDimsArray.") +end +# Linear operations. +Broadcast.broadcasted(::LazyNamedDimsArrayStyle, ::typeof(+), a1, a2) = a1 + a2 +Broadcast.broadcasted(::LazyNamedDimsArrayStyle, ::typeof(-), a1, a2) = a1 - a2 +Broadcast.broadcasted(::LazyNamedDimsArrayStyle, ::typeof(*), c::Number, a) = c * a +Broadcast.broadcasted(::LazyNamedDimsArrayStyle, ::typeof(*), a, c::Number) = a * c +Broadcast.broadcasted(::LazyNamedDimsArrayStyle, ::typeof(*), a::Number, b::Number) = a * b +Broadcast.broadcasted(::LazyNamedDimsArrayStyle, ::typeof(/), a, c::Number) = a / c +Broadcast.broadcasted(::LazyNamedDimsArrayStyle, ::typeof(-), a) = -a + +# Generic functionality for Applied types, like `Mul`, `Add`, etc. +ismul(a) = operation(a) ≡ * +head_applied(a) = operation(a) +iscall_applied(a) = true +isexpr_applied(a) = iscall(a) +function show_applied(io::IO, a) + args = map(arg -> sprint(AbstractTrees.printnode, arg), arguments(a)) + print(io, "(", join(args, " $(operation(a)) "), ")") + return nothing +end +sorted_arguments_applied(a) = arguments(a) +children_applied(a) = arguments(a) +sorted_children_applied(a) = sorted_arguments(a) +function maketerm_applied(type, head, args, metadata) + term = type(args) + @assert head ≡ operation(term) + return term +end +map_arguments_applied(f, a) = unspecify_type_parameters(typeof(a))(map(f, arguments(a))) +function hash_applied(a, h::UInt64) + h = hash(Symbol(unspecify_type_parameters(typeof(a))), h) + for arg in arguments(a) + h = hash(arg, h) + end + return h +end + +abstract type Applied end +TermInterface.head(a::Applied) = head_applied(a) +TermInterface.iscall(a::Applied) = iscall_applied(a) +TermInterface.isexpr(a::Applied) = isexpr_applied(a) +Base.show(io::IO, a::Applied) = show_applied(io, a) +TermInterface.sorted_arguments(a::Applied) = sorted_arguments_applied(a) +TermInterface.children(a::Applied) = children_applied(a) +TermInterface.sorted_children(a::Applied) = sorted_children_applied(a) +function TermInterface.maketerm(type::Type{<:Applied}, head, args, metadata) + return maketerm_applied(type, head, args, metadata) +end +map_arguments(f, a::Applied) = map_arguments_applied(f, a) +Base.hash(a::Applied, h::UInt64) = hash_applied(a, h) + +struct Mul{A} <: Applied + arguments::Vector{A} +end +TermInterface.arguments(m::Mul) = getfield(m, :arguments) +TermInterface.operation(m::Mul) = * + +@wrapped struct LazyNamedDimsArray{ + T, A <: AbstractNamedDimsArray{T}, + } <: AbstractNamedDimsArray{T, Any} + union::Union{A, Mul{LazyNamedDimsArray{T, A}}} +end +function LazyNamedDimsArray(a::AbstractNamedDimsArray) + # Use `eltype(typeof(a))` for arrays that have different + # runtime and compile time eltypes, like `ITensor`. 
+ return LazyNamedDimsArray{eltype(typeof(a)), typeof(a)}(a) +end +function LazyNamedDimsArray(a::Mul{LazyNamedDimsArray{T, A}}) where {T, A} + return LazyNamedDimsArray{T, A}(a) +end +lazy(a::LazyNamedDimsArray) = a +lazy(a::AbstractNamedDimsArray) = LazyNamedDimsArray(a) +lazy(a::Mul{<:LazyNamedDimsArray}) = LazyNamedDimsArray(a) + +NamedDimsArrays.inds(a::LazyNamedDimsArray) = inds_lazy(a) +NamedDimsArrays.dename(a::LazyNamedDimsArray) = dename_lazy(a) + +# Broadcasting +function Base.BroadcastStyle(::Type{<:LazyNamedDimsArray}) + return LazyNamedDimsArrayStyle() +end + +# Derived functionality. +function TermInterface.maketerm(type::Type{LazyNamedDimsArray}, head, args, metadata) + return maketerm_lazy(type, head, args, metadata) +end +Base.getindex(a::LazyNamedDimsArray, I::Int...) = getindex_lazy(a, I...) +TermInterface.arguments(a::LazyNamedDimsArray) = arguments_lazy(a) +TermInterface.children(a::LazyNamedDimsArray) = children_lazy(a) +TermInterface.head(a::LazyNamedDimsArray) = head_lazy(a) +TermInterface.iscall(a::LazyNamedDimsArray) = iscall_lazy(a) +TermInterface.isexpr(a::LazyNamedDimsArray) = isexpr_lazy(a) +TermInterface.operation(a::LazyNamedDimsArray) = operation_lazy(a) +TermInterface.sorted_arguments(a::LazyNamedDimsArray) = sorted_arguments_lazy(a) +AbstractTrees.children(a::LazyNamedDimsArray) = abstracttrees_children_lazy(a) +TermInterface.sorted_children(a::LazyNamedDimsArray) = sorted_children_lazy(a) +ismul(a::LazyNamedDimsArray) = ismul_lazy(a) +AbstractTrees.nodevalue(a::LazyNamedDimsArray) = nodevalue_lazy(a) +Base.Broadcast.materialize(a::LazyNamedDimsArray) = materialize_lazy(a) +Base.copy(a::LazyNamedDimsArray) = copy_lazy(a) +Base.:(==)(a1::LazyNamedDimsArray, a2::LazyNamedDimsArray) = equals_lazy(a1, a2) +Base.hash(a::LazyNamedDimsArray, h::UInt64) = hash_lazy(a, h) +map_arguments(f, a::LazyNamedDimsArray) = map_arguments_lazy(f, a) +substitute(a::LazyNamedDimsArray, substitutions) = substitute_lazy(a, substitutions) +AbstractTrees.printnode(io::IO, a::LazyNamedDimsArray) = printnode_lazy(io, a) +printnode_nameddims(io::IO, a::LazyNamedDimsArray) = printnode_lazy(io, a) +Base.show(io::IO, a::LazyNamedDimsArray) = show_lazy(io, a) +Base.show(io::IO, mime::MIME"text/plain", a::LazyNamedDimsArray) = show_lazy(io, mime, a) +Base.:*(a::LazyNamedDimsArray) = mul_lazy(a) +Base.:*(a1::LazyNamedDimsArray, a2::LazyNamedDimsArray) = mul_lazy(a1, a2) +Base.:+(a1::LazyNamedDimsArray, a2::LazyNamedDimsArray) = add_lazy(a1, a2) +Base.:-(a1::LazyNamedDimsArray, a2::LazyNamedDimsArray) = sub_lazy(a1, a2) +Base.:*(a1::Number, a2::LazyNamedDimsArray) = mul_lazy(a1, a2) +Base.:*(a1::LazyNamedDimsArray, a2::Number) = mul_lazy(a1, a2) +Base.:/(a1::LazyNamedDimsArray, a2::Number) = div_lazy(a1, a2) +Base.:-(a::LazyNamedDimsArray) = sub_lazy(a) + +struct SymbolicArray{T, N, Name, Axes <: NTuple{N, AbstractUnitRange{<:Integer}}} <: AbstractArray{T, N} + name::Name + axes::Axes + function SymbolicArray{T}(name, ax::Tuple{Vararg{AbstractUnitRange{<:Integer}}}) where {T} + N = length(ax) + return new{T, N, typeof(name), typeof(ax)}(name, ax) + end +end +function SymbolicArray(name, ax::Tuple{Vararg{AbstractUnitRange{<:Integer}}}) + return SymbolicArray{Any}(name, ax) +end +function SymbolicArray{T}(name, ax::AbstractUnitRange...) where {T} + return SymbolicArray{T}(name, ax) +end +function SymbolicArray(name, ax::AbstractUnitRange...) 
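+    # Default to an element type of `Any` when none is specified, matching the tuple
+    # constructor above.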
+ return SymbolicArray{Any}(name, ax) +end +symname(a::SymbolicArray) = getfield(a, :name) +Base.axes(a::SymbolicArray) = getfield(a, :axes) +Base.size(a::SymbolicArray) = length.(axes(a)) +function Base.:(==)(a::SymbolicArray, b::SymbolicArray) + return symname(a) == symname(b) && axes(a) == axes(b) +end +function Base.hash(a::SymbolicArray, h::UInt64) + h = hash(:SymbolicArray, h) + h = hash(symname(a), h) + return hash(size(a), h) +end +function Base.getindex(a::SymbolicArray{<:Any, N}, I::Vararg{Int, N}) where {N} + return error("Indexing into SymbolicArray not supported.") +end +function Base.setindex!(a::SymbolicArray{<:Any, N}, value, I::Vararg{Int, N}) where {N} + return error("Indexing into SymbolicArray not supported.") +end +function Base.show(io::IO, mime::MIME"text/plain", a::SymbolicArray) + Base.summary(io, a) + println(io, ":") + print(io, repr(symname(a))) + return nothing +end +function Base.show(io::IO, a::SymbolicArray) + print(io, "SymbolicArray(", symname(a), ", ", size(a), ")") + return nothing +end +using AbstractTrees: AbstractTrees +function AbstractTrees.printnode(io::IO, a::SymbolicArray) + print(io, repr(symname(a))) + return nothing +end +const SymbolicNamedDimsArray{T, N, Parent <: SymbolicArray{T, N}, DimNames} = + NamedDimsArray{T, N, Parent, DimNames} +function symnameddims(name) + return lazy(NamedDimsArray(SymbolicArray(name), ())) +end +function AbstractTrees.printnode(io::IO, a::SymbolicNamedDimsArray) + print(io, symname(dename(a))) + if ndims(a) > 0 + print(io, "[", join(dimnames(a), ","), "]") + end + return nothing +end +printnode_nameddims(io::IO, a::SymbolicNamedDimsArray) = AbstractTrees.printnode(io, a) +function Base.:(==)(a::SymbolicNamedDimsArray, b::SymbolicNamedDimsArray) + return issetequal(inds(a), inds(b)) && dename(a) == dename(b) +end +Base.:*(a::SymbolicNamedDimsArray, b::SymbolicNamedDimsArray) = lazy(a) * lazy(b) +Base.:*(a::SymbolicNamedDimsArray, b::LazyNamedDimsArray) = lazy(a) * b +Base.:*(a::LazyNamedDimsArray, b::SymbolicNamedDimsArray) = a * lazy(b) + +end diff --git a/src/tensornetwork.jl b/src/tensornetwork.jl index 3fd794b..c7d1479 100644 --- a/src/tensornetwork.jl +++ b/src/tensornetwork.jl @@ -1,74 +1,74 @@ using DataGraphs: DataGraphs, AbstractDataGraph, DataGraph using Dictionaries: AbstractDictionary, Indices, dictionary using Graphs: AbstractSimpleGraph -using NamedDimsArrays: AbstractNamedDimsArray, dimnames, nameddimsarray +using NamedDimsArrays: AbstractNamedDimsArray, dimnames using NamedGraphs: NamedGraphs, NamedEdge, NamedGraph, vertextype using NamedGraphs.GraphsExtensions: arranged_edges, vertextype function _TensorNetwork end -struct TensorNetwork{V,VD,UG<:AbstractGraph{V},Tensors<:AbstractDictionary{V,VD}} <: - AbstractTensorNetwork{V,VD} - underlying_graph::UG - tensors::Tensors - global @inline function _TensorNetwork( - underlying_graph::UG, tensors::Tensors - ) where {V,VD,UG<:AbstractGraph{V},Tensors<:AbstractDictionary{V,VD}} - # This assumes the tensor connectivity matches the graph structure. - return new{V,VD,UG,Tensors}(underlying_graph, tensors) - end +struct TensorNetwork{V, VD, UG <: AbstractGraph{V}, Tensors <: AbstractDictionary{V, VD}} <: + AbstractTensorNetwork{V, VD} + underlying_graph::UG + tensors::Tensors + global @inline function _TensorNetwork( + underlying_graph::UG, tensors::Tensors + ) where {V, VD, UG <: AbstractGraph{V}, Tensors <: AbstractDictionary{V, VD}} + # This assumes the tensor connectivity matches the graph structure. 
+ return new{V, VD, UG, Tensors}(underlying_graph, tensors) + end end DataGraphs.underlying_graph(tn::TensorNetwork) = getfield(tn, :underlying_graph) DataGraphs.vertex_data(tn::TensorNetwork) = getfield(tn, :tensors) function DataGraphs.underlying_graph_type(type::Type{<:TensorNetwork}) - return fieldtype(type, :underlying_graph) + return fieldtype(type, :underlying_graph) end # Determine the graph structure from the tensors. function TensorNetwork(t::AbstractDictionary) - g = NamedGraph(eachindex(t)) - for v1 in vertices(g) - for v2 in vertices(g) - if v1 ≠ v2 - if !isdisjoint(dimnames(t[v1]), dimnames(t[v2])) - add_edge!(g, v1 => v2) + g = NamedGraph(eachindex(t)) + for v1 in vertices(g) + for v2 in vertices(g) + if v1 ≠ v2 + if !isdisjoint(dimnames(t[v1]), dimnames(t[v2])) + add_edge!(g, v1 => v2) + end + end end - end end - end - return _TensorNetwork(g, t) + return _TensorNetwork(g, t) end function TensorNetwork(tensors::AbstractDict) - return TensorNetwork(Dictionary(tensors)) + return TensorNetwork(Dictionary(tensors)) end function TensorNetwork(graph::AbstractGraph, tensors::AbstractDictionary) - tn = TensorNetwork(tensors) - arranged_edges(tn) ⊆ arranged_edges(graph) || - error("The edges in the tensors do not match the graph structure.") - for e in setdiff(arranged_edges(graph), arranged_edges(tn)) - insert_trivial_link!(tn, e) - end - return tn + tn = TensorNetwork(tensors) + arranged_edges(tn) ⊆ arranged_edges(graph) || + error("The edges in the tensors do not match the graph structure.") + for e in setdiff(arranged_edges(graph), arranged_edges(tn)) + insert_trivial_link!(tn, e) + end + return tn end function TensorNetwork(graph::AbstractGraph, tensors::AbstractDict) - return TensorNetwork(graph, Dictionary(tensors)) + return TensorNetwork(graph, Dictionary(tensors)) end function TensorNetwork(f, graph::AbstractGraph) - return TensorNetwork(graph, Dict(v => f(v) for v in vertices(graph))) + return TensorNetwork(graph, Dict(v => f(v) for v in vertices(graph))) end function Base.copy(tn::TensorNetwork) - TensorNetwork(copy(underlying_graph(tn)), copy(vertex_data(tn))) + return TensorNetwork(copy(underlying_graph(tn)), copy(vertex_data(tn))) end TensorNetwork(tn::TensorNetwork) = copy(tn) TensorNetwork{V}(tn::TensorNetwork{V}) where {V} = copy(tn) function TensorNetwork{V}(tn::TensorNetwork) where {V} - g′ = convert_vertextype(V, underlying_graph(tn)) - d = vertex_data(tn) - d′ = dictionary(V(k) => d[k] for k in eachindex(d)) - return TensorNetwork(g′, d′) + g′ = convert_vertextype(V, underlying_graph(tn)) + d = vertex_data(tn) + d′ = dictionary(V(k) => d[k] for k in eachindex(d)) + return TensorNetwork(g′, d′) end NamedGraphs.convert_vertextype(::Type{V}, tn::TensorNetwork{V}) where {V} = tn diff --git a/test/Project.toml b/test/Project.toml index 94f32e3..4c12286 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,4 +1,5 @@ [deps] +AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" Dictionaries = "85a47980-9c8c-11e8-2b9f-f7ca1fa99fb4" Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6" @@ -8,16 +9,23 @@ NamedDimsArrays = "60cbd0c0-df58-4cb7-918c-6f5607b73fde" NamedGraphs = "678767b0-92e7-4007-89e4-4527a8725b19" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb" +TensorOperations = "6aa20fa7-93e2-5fca-9bc0-fbd0db3c71a2" +TermInterface = "8ea1fca8-c5ef-4a55-8b96-4e9afe9c9a3c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +WrappedUnions = 
"325db55a-9c6c-5b90-b1a2-ec87e7a38c44" [compat] +AbstractTrees = "0.4.5" Aqua = "0.8.14" Dictionaries = "0.4.5" Graphs = "1.13.1" -ITensorBase = "0.2.12" +ITensorBase = "0.3" ITensorNetworksNext = "0.1.1" -NamedDimsArrays = "0.7.14" -NamedGraphs = "0.6.8" +NamedDimsArrays = "0.8" +NamedGraphs = "0.6.8, 0.7" SafeTestsets = "0.1" Suppressor = "0.2.8" +TermInterface = "2" +TensorOperations = "5.3.1" Test = "1.10" +WrappedUnions = "0.3" diff --git a/test/runtests.jl b/test/runtests.jl index 98b2d2b..0008050 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -6,60 +6,62 @@ using Suppressor: Suppressor const pat = r"(?:--group=)(\w+)" arg_id = findfirst(contains(pat), ARGS) const GROUP = uppercase( - if isnothing(arg_id) - get(ENV, "GROUP", "ALL") - else - only(match(pat, ARGS[arg_id]).captures) - end, + if isnothing(arg_id) + get(ENV, "GROUP", "ALL") + else + only(match(pat, ARGS[arg_id]).captures) + end, ) "match files of the form `test_*.jl`, but exclude `*setup*.jl`" function istestfile(fn) - return endswith(fn, ".jl") && startswith(basename(fn), "test_") && !contains(fn, "setup") + return endswith(fn, ".jl") && startswith(basename(fn), "test_") && !contains(fn, "setup") end "match files of the form `*.jl`, but exclude `*_notest.jl` and `*setup*.jl`" function isexamplefile(fn) - return endswith(fn, ".jl") && !endswith(fn, "_notest.jl") && !contains(fn, "setup") + return endswith(fn, ".jl") && !endswith(fn, "_notest.jl") && !contains(fn, "setup") end @time begin - # tests in groups based on folder structure - for testgroup in filter(isdir, readdir(@__DIR__)) - if GROUP == "ALL" || GROUP == uppercase(testgroup) - groupdir = joinpath(@__DIR__, testgroup) - for file in filter(istestfile, readdir(groupdir)) - filename = joinpath(groupdir, file) - @eval @safetestset $file begin - include($filename) + # tests in groups based on folder structure + for testgroup in filter(isdir, readdir(@__DIR__)) + if GROUP == "ALL" || GROUP == uppercase(testgroup) + groupdir = joinpath(@__DIR__, testgroup) + for file in filter(istestfile, readdir(groupdir)) + filename = joinpath(groupdir, file) + @eval @safetestset $file begin + include($filename) + end + end end - end end - end - # single files in top folder - for file in filter(istestfile, readdir(@__DIR__)) - (file == basename(@__FILE__)) && continue # exclude this file to avoid infinite recursion - @eval @safetestset $file begin - include($file) + # single files in top folder + for file in filter(istestfile, readdir(@__DIR__)) + (file == basename(@__FILE__)) && continue # exclude this file to avoid infinite recursion + @eval @safetestset $file begin + include($file) + end end - end - # test examples - examplepath = joinpath(@__DIR__, "..", "examples") - for (root, _, files) in walkdir(examplepath) - contains(chopprefix(root, @__DIR__), "setup") && continue - for file in filter(isexamplefile, files) - filename = joinpath(root, file) - @eval begin - @safetestset $file begin - $(Expr( - :macrocall, - GlobalRef(Suppressor, Symbol("@suppress")), - LineNumberNode(@__LINE__, @__FILE__), - :(include($filename)), - )) + # test examples + examplepath = joinpath(@__DIR__, "..", "examples") + for (root, _, files) in walkdir(examplepath) + contains(chopprefix(root, @__DIR__), "setup") && continue + for file in filter(isexamplefile, files) + filename = joinpath(root, file) + @eval begin + @safetestset $file begin + $( + Expr( + :macrocall, + GlobalRef(Suppressor, Symbol("@suppress")), + LineNumberNode(@__LINE__, @__FILE__), + :(include($filename)), + ) + ) + end + end 
diff --git a/test/test_aqua.jl b/test/test_aqua.jl
index 34bfff1..0afead5 100644
--- a/test/test_aqua.jl
+++ b/test/test_aqua.jl
@@ -3,5 +3,5 @@ using Aqua: Aqua
 using Test: @testset

 @testset "Code quality (Aqua.jl)" begin
-  Aqua.test_all(ITensorNetworksNext)
+    Aqua.test_all(ITensorNetworksNext)
 end
diff --git a/test/test_basics.jl b/test/test_basics.jl
index 59e5e35..0c9d803 100644
--- a/test/test_basics.jl
+++ b/test/test_basics.jl
@@ -8,56 +8,56 @@ using NamedGraphs.NamedGraphGenerators: named_grid
 using Test: @test, @testset

 @testset "ITensorNetworksNext" begin
-  @testset "Construct TensorNetwork product state" begin
-    dims = (3, 3)
-    g = named_grid(dims)
-    s = Dict(v => Index(2) for v in vertices(g))
-    tn = TensorNetwork(g) do v
-      return randn(s[v])
-    end
-    @test nv(tn) == 9
-    @test ne(tn) == ne(g)
-    @test issetequal(vertices(tn), vertices(g))
-    @test issetequal(arranged_edges(tn), arranged_edges(g))
-    for v in vertices(tn)
-      @test siteinds(tn, v) == [s[v]]
-    end
-    for v1 in vertices(tn)
-      for v2 in vertices(tn)
-        v1 == v2 && continue
-        haslink = !isempty(linkinds(tn, v1 => v2))
-        @test haslink == has_edge(tn, v1 => v2)
-      end
-    end
-    for e in edges(tn)
-      @test isone(length(only(linkinds(tn, e))))
-    end
-  end
-  @testset "Construct TensorNetwork partition function" begin
-    dims = (3, 3)
-    g = named_grid(dims)
-    l = Dict(e => Index(2) for e in edges(g))
-    l = merge(l, Dict(reverse(e) => l[e] for e in edges(g)))
-    tn = TensorNetwork(g) do v
-      is = map(e -> l[e], incident_edges(g, v))
-      return randn(Tuple(is))
-    end
-    @test nv(tn) == 9
-    @test ne(tn) == ne(g)
-    @test issetequal(vertices(tn), vertices(g))
-    @test issetequal(arranged_edges(tn), arranged_edges(g))
-    for v in vertices(tn)
-      @test isempty(siteinds(tn, v))
-    end
-    for v1 in vertices(tn)
-      for v2 in vertices(tn)
-        v1 == v2 && continue
-        haslink = !isempty(linkinds(tn, v1 => v2))
-        @test haslink == has_edge(tn, v1 => v2)
-      end
-    end
-    for e in edges(tn)
-      @test only(linkinds(tn, e)) == l[e]
-    end
-  end
+    @testset "Construct TensorNetwork product state" begin
+        dims = (3, 3)
+        g = named_grid(dims)
+        s = Dict(v => Index(2) for v in vertices(g))
+        tn = TensorNetwork(g) do v
+            return randn(s[v])
+        end
+        @test nv(tn) == 9
+        @test ne(tn) == ne(g)
+        @test issetequal(vertices(tn), vertices(g))
+        @test issetequal(arranged_edges(tn), arranged_edges(g))
+        for v in vertices(tn)
+            @test siteinds(tn, v) == [s[v]]
+        end
+        for v1 in vertices(tn)
+            for v2 in vertices(tn)
+                v1 == v2 && continue
+                haslink = !isempty(linkinds(tn, v1 => v2))
+                @test haslink == has_edge(tn, v1 => v2)
+            end
+        end
+        for e in edges(tn)
+            @test isone(length(only(linkinds(tn, e))))
+        end
+    end
+    @testset "Construct TensorNetwork partition function" begin
+        dims = (3, 3)
+        g = named_grid(dims)
+        l = Dict(e => Index(2) for e in edges(g))
+        l = merge(l, Dict(reverse(e) => l[e] for e in edges(g)))
+        tn = TensorNetwork(g) do v
+            is = map(e -> l[e], incident_edges(g, v))
+            return randn(Tuple(is))
+        end
+        @test nv(tn) == 9
+        @test ne(tn) == ne(g)
+        @test issetequal(vertices(tn), vertices(g))
+        @test issetequal(arranged_edges(tn), arranged_edges(g))
+        for v in vertices(tn)
+            @test isempty(siteinds(tn, v))
+        end
+        for v1 in vertices(tn)
+            for v2 in vertices(tn)
+                v1 == v2 && continue
+                haslink = !isempty(linkinds(tn, v1 => v2))
+                @test haslink == has_edge(tn, v1 => v2)
+            end
+        end
+        for e in edges(tn)
+            @test only(linkinds(tn, e)) == l[e]
+        end
+    end
 end
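`test/test_basics.jl` above pins down the index conventions: `siteinds` returns the dangling indices of a vertex tensor, `linkinds` the indices shared across an edge, and edges present in the graph but absent from the tensors get trivial dimension-1 links. A compact sketch of those queries on a product state:

```julia
using Graphs: edges, vertices
using ITensorBase: Index
using ITensorNetworksNext: TensorNetwork, linkinds, siteinds
using NamedGraphs.NamedGraphGenerators: named_grid

g = named_grid((2, 2))
s = Dict(v => Index(2) for v in vertices(g))
tn = TensorNetwork(g) do v
    return randn(s[v])
end
@assert siteinds(tn, (1, 1)) == [s[(1, 1)]]     # dangling index at a vertex
for e in edges(g)
    @assert length(only(linkinds(tn, e))) == 1  # inserted links are trivial
end
```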
diff --git a/test/test_beliefpropagation.jl b/test/test_beliefpropagation.jl
new file mode 100644
index 0000000..81ee722
--- /dev/null
+++ b/test/test_beliefpropagation.jl
@@ -0,0 +1,43 @@
+using Dictionaries: Dictionary
+using ITensorBase: Index
+using ITensorNetworksNext: ITensorNetworksNext, BeliefPropagationCache, TensorNetwork, adapt_messages, default_message, default_messages, edge_scalars, messages, setmessages!, factors, freenergy,
+    partitionfunction
+using Graphs: edges, vertices
+using NamedGraphs.NamedGraphGenerators: named_grid, named_comb_tree
+using NamedGraphs.GraphsExtensions: arranged_edges, incident_edges
+using Test: @test, @testset
+
+@testset "BeliefPropagation" begin
+
+    # Chain of tensors
+    dims = (4, 1)
+    g = named_grid(dims)
+    l = Dict(e => Index(2) for e in edges(g))
+    l = merge(l, Dict(reverse(e) => l[e] for e in edges(g)))
+    tn = TensorNetwork(g) do v
+        is = map(e -> l[e], incident_edges(g, v))
+        return randn(Tuple(is))
+    end
+
+    bpc = BeliefPropagationCache(tn)
+    bpc = ITensorNetworksNext.update(bpc; maxiter = 1)
+    z_bp = partitionfunction(bpc)
+    z_exact = reduce(*, [tn[v] for v in vertices(g)])[]
+    @test abs(z_bp - z_exact) / abs(z_exact) <= 1.0e3 * eps(Float64)
+
+    # Tree of tensors
+    dims = (4, 3)
+    g = named_comb_tree(dims)
+    l = Dict(e => Index(3) for e in edges(g))
+    l = merge(l, Dict(reverse(e) => l[e] for e in edges(g)))
+    tn = TensorNetwork(g) do v
+        is = map(e -> l[e], incident_edges(g, v))
+        return randn(Tuple(is))
+    end
+
+    bpc = BeliefPropagationCache(tn)
+    bpc = ITensorNetworksNext.update(bpc; maxiter = 10)
+    z_bp = partitionfunction(bpc)
+    z_exact = reduce(*, [tn[v] for v in vertices(g)])[]
+    @test abs(z_bp - z_exact) / abs(z_exact) <= 1.0e3 * eps(Float64)
+end
\ No newline at end of file
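Both cases above are loop-free, where belief propagation is exact up to roundoff (hence the relative-error checks). On a graph with cycles the BP fixed point only approximates the partition function, so a test there would compare against an exact contraction rather than demand near-equality. A sketch using only APIs from this diff (the `maxiter = 30` cutoff is an arbitrary choice for illustration):

```julia
using Graphs: edges
using ITensorBase: Index
using ITensorNetworksNext: ITensorNetworksNext, BeliefPropagationCache, TensorNetwork,
    contract_network, partitionfunction
using NamedGraphs.GraphsExtensions: incident_edges
using NamedGraphs.NamedGraphGenerators: named_grid
using TensorOperations: TensorOperations  # loads the exact-contraction extension

g = named_grid((3, 3))  # has loops, unlike the chain and comb tree above
l = Dict(e => Index(2) for e in edges(g))
l = merge(l, Dict(reverse(e) => l[e] for e in edges(g)))
tn = TensorNetwork(g) do v
    return randn(Tuple(map(e -> l[e], incident_edges(g, v))))
end

bpc = ITensorNetworksNext.update(BeliefPropagationCache(tn); maxiter = 30)
z_bp = partitionfunction(bpc)
z_exact = contract_network(tn; alg = "exact", sequence_alg = "optimal")[]
@show z_bp z_exact  # close, but generally not equal on loopy graphs
```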
diff --git a/test/test_contract_network.jl b/test/test_contract_network.jl
new file mode 100644
index 0000000..2b7b945
--- /dev/null
+++ b/test/test_contract_network.jl
@@ -0,0 +1,39 @@
+using Graphs: edges
+using NamedGraphs.GraphsExtensions: arranged_edges, incident_edges
+using NamedGraphs.NamedGraphGenerators: named_grid
+using ITensorBase: Index, ITensor
+using ITensorNetworksNext:
+    TensorNetwork, linkinds, siteinds, contract_network
+using TensorOperations: TensorOperations
+using Test: @test, @testset
+
+@testset "contract_network" begin
+    @testset "Contract Vectors of ITensors" begin
+        i, j, k = Index(2), Index(2), Index(5)
+        A = ITensor([1.0 1.0; 0.5 1.0], i, j)
+        B = ITensor([2.0, 1.0], i)
+        C = ITensor([5.0, 1.0], j)
+        D = ITensor([-2.0, 3.0, 4.0, 5.0, 1.0], k)
+
+        ABCD_1 = contract_network([A, B, C, D]; alg = "exact", sequence_alg = "leftassociative")
+        ABCD_2 = contract_network([A, B, C, D]; alg = "exact", sequence_alg = "optimal")
+
+        @test ABCD_1 ≈ ABCD_2
+    end
+
+    @testset "Contract Two-Dimensional Network" begin
+        dims = (4, 4)
+        g = named_grid(dims)
+        l = Dict(e => Index(2) for e in edges(g))
+        l = merge(l, Dict(reverse(e) => l[e] for e in edges(g)))
+        tn = TensorNetwork(g) do v
+            is = map(e -> l[e], incident_edges(g, v))
+            return randn(Tuple(is))
+        end
+
+        z1 = contract_network(tn; alg = "exact", sequence_alg = "optimal")[]
+        z2 = contract_network(tn; alg = "exact", sequence_alg = "leftassociative")[]
+
+        @test abs(z1 - z2) / abs(z1) <= 1.0e3 * eps(Float64)
+    end
+end
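The two `sequence_alg` options agree on the value (up to roundoff, hence `≈` and the relative-error check above) but not on cost: contraction order can change the operation count by orders of magnitude, which is why an `"optimal"` sequence search exists alongside the trivial `"leftassociative"` one. A toy matrix-chain count showing the gap (standard cost model, nothing package-specific):

```julia
# Multiplying an (m×n) by an (n×p) matrix costs m*n*p scalar multiplications.
mul_cost(m, n, p) = m * n * p

# Chain A(1000×10) * B(10×1000) * C(1000×10):
left  = mul_cost(1000, 10, 1000) + mul_cost(1000, 1000, 10)  # (A*B)*C = 20_000_000
right = mul_cost(10, 1000, 10) + mul_cost(1000, 10, 10)      # A*(B*C) = 200_000
@assert right < left  # the optimal order is 100× cheaper than left-associative here
```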
diff --git a/test/test_iterators.jl b/test/test_iterators.jl
new file mode 100644
index 0000000..456ebcf
--- /dev/null
+++ b/test/test_iterators.jl
@@ -0,0 +1,161 @@
+using Test: @test, @testset, @test_throws
+import ITensorNetworksNext as ITensorNetworks
+using .ITensorNetworks: SweepIterator, RegionIterator, islaststep, state, increment!, compute!, eachregion
+
+module TestIteratorUtils
+
+    import ITensorNetworksNext as ITensorNetworks
+    using .ITensorNetworks
+
+    struct TestProblem <: ITensorNetworks.AbstractProblem
+        data::Vector{Int}
+    end
+    ITensorNetworks.region_plan(::TestProblem) = [:a => (; val = 1), :b => (; val = 2)]
+    function ITensorNetworks.compute!(iter::ITensorNetworks.RegionIterator{<:TestProblem})
+        kwargs = ITensorNetworks.region_kwargs(iter)
+        push!(ITensorNetworks.problem(iter).data, kwargs.val)
+        return iter
+    end
+
+    mutable struct TestIterator <: ITensorNetworks.AbstractNetworkIterator
+        state::Int
+        max::Int
+        output::Vector{Int}
+    end
+
+    ITensorNetworks.increment!(TI::TestIterator) = TI.state += 1
+    Base.length(TI::TestIterator) = TI.max
+    ITensorNetworks.state(TI::TestIterator) = TI.state
+    function ITensorNetworks.compute!(TI::TestIterator)
+        push!(TI.output, ITensorNetworks.state(TI))
+        return TI
+    end
+
+    mutable struct SquareAdapter <: ITensorNetworks.AbstractNetworkIterator
+        parent::TestIterator
+    end
+
+    Base.length(SA::SquareAdapter) = length(SA.parent)
+    ITensorNetworks.increment!(SA::SquareAdapter) = ITensorNetworks.increment!(SA.parent)
+    ITensorNetworks.state(SA::SquareAdapter) = ITensorNetworks.state(SA.parent)
+    function ITensorNetworks.compute!(SA::SquareAdapter)
+        ITensorNetworks.compute!(SA.parent)
+        return last(SA.parent.output)^2
+    end
+
+end
+
+@testset "Iterators" begin
+    @testset "`AbstractNetworkIterator` Interface" begin
+        @testset "Edge cases" begin
+            TI = TestIteratorUtils.TestIterator(1, 1, [])
+            cb = []
+            @test islaststep(TI)
+            for _ in TI
+                @test islaststep(TI)
+                push!(cb, state(TI))
+            end
+            @test length(cb) == 1
+            @test length(TI.output) == 1
+            @test only(cb) == 1
+
+            prob = TestIteratorUtils.TestProblem([])
+            @test_throws BoundsError SweepIterator(prob, 0)
+            @test_throws BoundsError RegionIterator(prob, [], 1)
+        end
+
+        TI = TestIteratorUtils.TestIterator(1, 4, [])
+
+        @test !islaststep(TI)
+
+        # The first iterate should compute only, without incrementing.
+        rv, st = iterate(TI)
+        @test !islaststep(TI)
+        @test !st
+        @test rv === TI
+        @test length(TI.output) == 1
+        @test only(TI.output) == 1
+        @test state(TI) == 1
+
+        rv, st = iterate(TI, st)
+        @test !islaststep(TI)
+        @test !st
+        @test length(TI.output) == 2
+        @test state(TI) == 2
+        @test TI.output == [1, 2]
+
+        increment!(TI)
+        @test !islaststep(TI)
+        @test state(TI) == 3
+        @test length(TI.output) == 2
+        @test TI.output == [1, 2]
+
+        compute!(TI)
+        @test !islaststep(TI)
+        @test state(TI) == 3
+        @test length(TI.output) == 3
+        @test TI.output == [1, 2, 3]
+
+        # Final step
+        iterate(TI, false)
+        @test islaststep(TI)
+        @test state(TI) == 4
+        @test length(TI.output) == 4
+        @test TI.output == [1, 2, 3, 4]
+
+        @test iterate(TI, false) === nothing
+
+        TI = TestIteratorUtils.TestIterator(1, 5, [])
+
+        cb = []
+
+        for _ in TI
+            @test length(cb) == length(TI.output) - 1
+            @test cb == TI.output[1:(end - 1)]
+            push!(cb, state(TI))
+            @test cb == TI.output
+        end
+
+        @test islaststep(TI)
+        @test length(TI.output) == 5
+        @test length(cb) == 5
+        @test cb == TI.output
+    end
+
+    @testset "Adapters" begin
+        TI = TestIteratorUtils.TestIterator(1, 5, [])
+        SA = TestIteratorUtils.SquareAdapter(TI)
+
+        @testset "Generic" begin
+            i = 0
+            for rv in SA
+                i += 1
+                @test rv isa Int
+                @test rv == i^2
+                @test state(SA) == i
+            end
+
+            @test islaststep(SA)
+
+            TI = TestIteratorUtils.TestIterator(1, 5, [])
+            SA = TestIteratorUtils.SquareAdapter(TI)
+
+            SA_c = collect(SA)
+
+            @test SA_c isa Vector
+            @test length(SA_c) == 5
+            @test SA_c == [1, 4, 9, 16, 25]
+        end
+    end
+end
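Read together, the assertions above fix the `AbstractNetworkIterator` contract: the first `iterate` call only runs `compute!`, every subsequent step runs `increment!` and then `compute!`, `state` reports the current step, and iteration ends once `islaststep` holds, i.e. once `state` reaches `length`. A minimal conforming iterator in the same spirit as `TestIterator` (the type `CountingIterator` is hypothetical; the driving `Base.iterate` methods come from the package's generic fallback, not from this sketch):

```julia
import ITensorNetworksNext as ITensorNetworks

mutable struct CountingIterator <: ITensorNetworks.AbstractNetworkIterator
    state::Int  # current step, starting at 1
    max::Int    # total number of steps
end

Base.length(it::CountingIterator) = it.max
ITensorNetworks.state(it::CountingIterator) = it.state
ITensorNetworks.increment!(it::CountingIterator) = it.state += 1
ITensorNetworks.compute!(it::CountingIterator) = println("step ", it.state)

for _ in CountingIterator(1, 3)  # prints "step 1", "step 2", "step 3"
end
```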
diff --git a/test/test_lazynameddimsarrays.jl b/test/test_lazynameddimsarrays.jl
new file mode 100644
index 0000000..cc86fdc
--- /dev/null
+++ b/test/test_lazynameddimsarrays.jl
@@ -0,0 +1,122 @@
+using AbstractTrees: AbstractTrees, print_tree, printnode
+using Base.Broadcast: materialize
+using ITensorNetworksNext.LazyNamedDimsArrays: LazyNamedDimsArrays, LazyNamedDimsArray,
+    Mul, SymbolicArray, ismul, lazy, substitute, symnameddims
+using NamedDimsArrays: NamedDimsArray, @names, dename, dimnames, inds, nameddims, namedoneto
+using TermInterface: arguments, arity, children, head, iscall, isexpr, maketerm, operation,
+    sorted_arguments, sorted_children
+using Test: @test, @test_throws, @testset
+using WrappedUnions: unwrap
+
+@testset "LazyNamedDimsArrays" begin
+    function sprint_namespaced(x)
+        context = (:module => LazyNamedDimsArrays)
+        module_prefix = "ITensorNetworksNext.LazyNamedDimsArrays."
+        return replace(sprint(show, MIME"text/plain"(), x; context), module_prefix => "")
+    end
+    @testset "Basics" begin
+        i, j, k, l = namedoneto.(2, (:i, :j, :k, :l))
+        a1 = randn(i, j)
+        a2 = randn(j, k)
+        a3 = randn(k, l)
+        l1, l2, l3 = lazy.((a1, a2, a3))
+        for li in (l1, l2, l3)
+            @test li isa LazyNamedDimsArray
+            @test unwrap(li) isa NamedDimsArray
+            @test inds(li) == inds(unwrap(li))
+            @test copy(li) == unwrap(li)
+            @test materialize(li) == unwrap(li)
+        end
+        l = l1 * l2 * l3
+        @test copy(l) ≈ a1 * a2 * a3
+        @test materialize(l) ≈ a1 * a2 * a3
+        @test issetequal(inds(l), symdiff(inds.((a1, a2, a3))...))
+        @test unwrap(l) isa Mul
+        @test ismul(unwrap(l))
+        @test unwrap(l).arguments == [l1 * l2, l3]
+        # TermInterface.jl
+        @test operation(unwrap(l)) ≡ *
+        @test arguments(unwrap(l)) == [l1 * l2, l3]
+    end
+
+    @testset "TermInterface" begin
+        a1 = nameddims(randn(2, 2), (:i, :j))
+        a2 = nameddims(randn(2, 2), (:j, :k))
+        a3 = nameddims(randn(2, 2), (:k, :l))
+        l1, l2, l3 = lazy.((a1, a2, a3))
+
+        @test_throws ErrorException arguments(l1)
+        @test_throws ErrorException arity(l1)
+        @test_throws ErrorException children(l1)
+        @test_throws ErrorException head(l1)
+        @test !iscall(l1)
+        @test !isexpr(l1)
+        @test_throws ErrorException operation(l1)
+        @test_throws ErrorException sorted_arguments(l1)
+        @test_throws ErrorException sorted_children(l1)
+        @test AbstractTrees.children(l1) ≡ ()
+        @test AbstractTrees.nodevalue(l1) ≡ a1
+        @test sprint(show, l1) == sprint(show, a1)
+        # TODO: Fix this test, it is basically correct but the type parameters
+        # print in a different way.
+        # @test sprint_namespaced(l1) ==
+        #     replace(sprint_namespaced(a1), "NamedDimsArray" => "LazyNamedDimsArray")
+        @test sprint(printnode, l1) == "[:i, :j]"
+        @test sprint(print_tree, l1) == "[:i, :j]\n"
+
+        l = l1 * l2 * l3
+        @test arguments(l) == [l1 * l2, l3]
+        @test arity(l) == 2
+        @test children(l) == [l1 * l2, l3]
+        @test head(l) ≡ *
+        @test iscall(l)
+        @test isexpr(l)
+        @test l == maketerm(LazyNamedDimsArray, *, [l1 * l2, l3], nothing)
+        @test operation(l) ≡ *
+        @test sorted_arguments(l) == [l1 * l2, l3]
+        @test sorted_children(l) == [l1 * l2, l3]
+        @test AbstractTrees.children(l) == [l1 * l2, l3]
+        @test AbstractTrees.nodevalue(l) ≡ *
+        @test sprint(show, l) == "(([:i, :j] * [:j, :k]) * [:k, :l])"
+        @test sprint_namespaced(l) ==
+            "named(Base.OneTo(2), :i)×named(Base.OneTo(2), :l) " *
+            "LazyNamedDimsArray{Float64, …}:\n(([:i, :j] * [:j, :k]) * [:k, :l])"
+        @test sprint(printnode, l) == "(([:i, :j] * [:j, :k]) * [:k, :l])"
+        @test sprint(print_tree, l) ==
+            "(([:i, :j] * [:j, :k]) * [:k, :l])\n" *
+            "├─ ([:i, :j] * [:j, :k])\n" *
+            "│  ├─ [:i, :j]\n│  └─ [:j, :k]\n" *
+            "└─ [:k, :l]\n"
+    end
+
+    @testset "symnameddims" begin
+        a1, a2, a3 = symnameddims.((:a1, :a2, :a3))
+        @test a1 isa LazyNamedDimsArray
+        @test unwrap(a1) isa NamedDimsArray
+        @test dename(a1) isa SymbolicArray
+        @test dename(unwrap(a1)) isa SymbolicArray
+        @test dename(unwrap(a1)) == SymbolicArray(:a1)
+        @test inds(a1) == ()
+        @test dimnames(a1) == ()
+
+        ex = a1 * a2 * a3
+        @test copy(ex) == ex
+        @test arguments(ex) == [a1 * a2, a3]
+        @test operation(ex) ≡ *
+        @test sprint(show, ex) == "((a1 * a2) * a3)"
+        @test sprint_namespaced(ex) ==
+            "0-dimensional LazyNamedDimsArray{Any, …}:\n((a1 * a2) * a3)"
+    end
+
+    @testset "substitute" begin
+        s = symnameddims.((:a1, :a2, :a3))
+        i = @names i[1:4]
+        a = (randn(2, 2)[i[1], i[2]], randn(2, 2)[i[2], i[3]], randn(2, 2)[i[3], i[4]])
+        l = lazy.(a)
+
+        seq = s[1] * (s[2] * s[3])
+        net = substitute(seq, s .=> l)
+        @test net == l[1] * (l[2] * l[3])
+        @test arguments(net) == [l[1], l[2] * l[3]]
+    end
+end