Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 25 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,31 @@
name = "ITensorNetworksNext"
uuid = "302f2e75-49f0-4526-aef7-d8ba550cb06c"
authors = ["ITensor developers <[email protected]> and contributors"]
version = "0.1.0"
version = "0.1.1"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
BackendSelection = "680c2d7c-f67a-4cc9-ae9c-da132b1447a5"
DataGraphs = "b5a273c3-7e6c-41f6-98bd-8d7f1525a36a"
Dictionaries = "85a47980-9c8c-11e8-2b9f-f7ca1fa99fb4"
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
NamedDimsArrays = "60cbd0c0-df58-4cb7-918c-6f5607b73fde"
NamedGraphs = "678767b0-92e7-4007-89e4-4527a8725b19"
SimpleTraits = "699a6c99-e7fa-54fc-8d76-47d257e15c1d"
SplitApplyCombine = "03a91e81-4c3e-53e1-a0a4-9c0c8f19dd66"

[compat]
Adapt = "4.3.0"
BackendSelection = "0.1.6"
DataGraphs = "0.2.7"
Dictionaries = "0.4.5"
Graphs = "1.13.1"
LinearAlgebra = "1.10"
MacroTools = "0.5.16"
NamedDimsArrays = "0.7.13"
NamedGraphs = "0.6.9"
SimpleTraits = "0.9.5"
SplitApplyCombine = "1.2.3"
julia = "1.10"
3 changes: 2 additions & 1 deletion src/ITensorNetworksNext.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
module ITensorNetworksNext

# Write your package code here.
include("abstracttensornetwork.jl")
include("tensornetwork.jl")

end
279 changes: 279 additions & 0 deletions src/abstracttensornetwork.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,279 @@
using Adapt: Adapt, adapt, adapt_structure
using BackendSelection: @Algorithm_str, Algorithm
using DataGraphs:
DataGraphs,
AbstractDataGraph,
edge_data,
underlying_graph,
underlying_graph_type,
vertex_data
using Dictionaries: Dictionary
using Graphs:
Graphs,
AbstractEdge,
AbstractGraph,
Graph,
add_edge!,
add_vertex!,
bfs_tree,
center,
dst,
edges,
edgetype,
ne,
neighbors,
nv,
rem_edge!,
src,
vertices
using LinearAlgebra: LinearAlgebra, factorize
using MacroTools: @capture
using NamedDimsArrays: dimnames
using NamedGraphs: NamedGraphs, NamedGraph, not_implemented, steiner_tree
using NamedGraphs.GraphsExtensions:
⊔, directed_graph, incident_edges, rem_edges!, rename_vertices, vertextype
using SplitApplyCombine: flatten

abstract type AbstractTensorNetwork{V,VD} <: AbstractDataGraph{V,VD,Nothing} end

function Graphs.rem_edge!(tn::AbstractTensorNetwork, e)
rem_edge!(underlying_graph(tn), e)
return tn
end

# TODO: Define a generic fallback for `AbstractDataGraph`?
DataGraphs.edge_data_eltype(::Type{<:AbstractTensorNetwork}) = error("No edge data")

# Graphs.jl overloads
function Graphs.weights(graph::AbstractTensorNetwork)
V = vertextype(graph)
es = Tuple.(edges(graph))
ws = Dictionary{Tuple{V,V},Float64}(es, undef)
for e in edges(graph)
w = log2(dim(commoninds(graph, e)))
ws[(src(e), dst(e))] = w
end
return ws
end

# Copy
Base.copy(tn::AbstractTensorNetwork) = error("Not implemented")

# Iteration
Base.iterate(tn::AbstractTensorNetwork, args...) = iterate(vertex_data(tn), args...)

# TODO: This contrasts with the `DataGraphs.AbstractDataGraph` definition,
# where it is defined as the `vertextype`. Does that cause problems or should it be changed?
Base.eltype(tn::AbstractTensorNetwork) = eltype(vertex_data(tn))

# Overload if needed
Graphs.is_directed(::Type{<:AbstractTensorNetwork}) = false

# Derived interface, may need to be overloaded
function DataGraphs.underlying_graph_type(G::Type{<:AbstractTensorNetwork})
return underlying_graph_type(data_graph_type(G))
end

# AbstractDataGraphs overloads
function DataGraphs.vertex_data(graph::AbstractTensorNetwork, args...)
return error("Not implemented")
end
function DataGraphs.edge_data(graph::AbstractTensorNetwork, args...)
return error("Not implemented")
end

DataGraphs.underlying_graph(tn::AbstractTensorNetwork) = error("Not implemented")
function NamedGraphs.vertex_positions(tn::AbstractTensorNetwork)
return NamedGraphs.vertex_positions(underlying_graph(tn))
end
function NamedGraphs.ordered_vertices(tn::AbstractTensorNetwork)
return NamedGraphs.ordered_vertices(underlying_graph(tn))
end

function Adapt.adapt_structure(to, tn::AbstractTensorNetwork)
# TODO: Define and use:
#
# @preserve_graph map_vertex_data(adapt(to), tn)
#
# or just:
#
# @preserve_graph map(adapt(to), tn)
return map_vertex_data_preserve_graph(adapt(to), tn)
end

function linkinds(tn::AbstractTensorNetwork, edge::Pair)
return linkinds(tn, edgetype(tn)(edge))
end
function linkinds(tn::AbstractTensorNetwork, edge::AbstractEdge)
return nameddimsindices(tn[src(edge)]) ∩ nameddimsindices(tn[dst(edge)])
end
function linkaxes(tn::AbstractTensorNetwork, edge::Pair)
return linkaxes(tn, edgetype(tn)(edge))
end
function linkaxes(tn::AbstractTensorNetwork, edge::AbstractEdge)
return axes(tn[src(edge)]) ∩ axes(tn[dst(edge)])
end
function linknames(tn::AbstractTensorNetwork, edge::Pair)
return linknames(tn, edgetype(tn)(edge))
end
function linknames(tn::AbstractTensorNetwork, edge::AbstractEdge)
return dimnames(tn[src(edge)]) ∩ dimnames(tn[dst(edge)])
end

function siteinds(tn::AbstractTensorNetwork, v)
s = nameddimsindices(tn[v])
for v′ in neighbors(tn, v)
s = setdiff(s, nameddimsindices(tn[v′]))
end
return s
end
function siteaxes(tn::AbstractTensorNetwork, edge::AbstractEdge)
s = axes(tn[src(edge)]) ∩ axes(tn[dst(edge)])
for v′ in neighbors(tn, v)
s = setdiff(s, axes(tn[v′]))
end
return s
end
function sitenames(tn::AbstractTensorNetwork, edge::AbstractEdge)
s = dimnames(tn[src(edge)]) ∩ dimnames(tn[dst(edge)])
for v′ in neighbors(tn, v)
s = setdiff(s, dimnames(tn[v′]))
end
return s
end

function setindex_preserve_graph!(tn::AbstractTensorNetwork, value, vertex)
vertex_data(tn)[vertex] = value
return tn
end

# TODO: Move to `BaseExtensions` module.
function is_setindex!_expr(expr::Expr)
return is_assignment_expr(expr) && is_getindex_expr(first(expr.args))
end
is_setindex!_expr(x) = false
is_getindex_expr(expr::Expr) = (expr.head === :ref)
is_getindex_expr(x) = false
is_assignment_expr(expr::Expr) = (expr.head === :(=))
is_assignment_expr(expr) = false

# TODO: Define this in terms of a function mapping
# preserve_graph_function(::typeof(setindex!)) = setindex!_preserve_graph
# preserve_graph_function(::typeof(map_vertex_data)) = map_vertex_data_preserve_graph
# Also allow annotating codeblocks like `@views`.
macro preserve_graph(expr)
if !is_setindex!_expr(expr)
error(
"preserve_graph must be used with setindex! syntax (as @preserve_graph a[i,j,...] = value)",
)
end
@capture(expr, array_[indices__] = value_)
return :(setindex_preserve_graph!($(esc(array)), $(esc(value)), $(esc.(indices)...)))
end

# Update the graph of the TensorNetwork `tn` to include
# edges that should exist based on the tensor connectivity.
function add_missing_edges!(tn::AbstractTensorNetwork)
foreach(v -> add_missing_edges!(tn, v), vertices(tn))
return tn
end

# Update the graph of the TensorNetwork `tn` to include
# edges that should be incident to the vertex `v`
# based on the tensor connectivity.
function add_missing_edges!(tn::AbstractTensorNetwork, v)
for v′ in vertices(tn)
if v ≠ v′
e = v => v′
if !isempty(linkinds(tn, e))
add_edge!(tn, e)
end
end
end
return tn
end

# Fix the edges of the TensorNetwork `tn` to match
# the tensor connectivity.
function fix_edges!(tn::AbstractTensorNetwork)
foreach(v -> fix_edges!(tn, v), vertices(tn))
return tn
end
# Fix the edges of the TensorNetwork `tn` to match
# the tensor connectivity at vertex `v`.
function fix_edges!(tn::AbstractTensorNetwork, v)
rem_incident_edges!(tn, v)
rem_edges!(tn, incident_edges(tn, v))
add_missing_edges!(tn, v)
return tn
end

# Customization point.
using NamedDimsArrays: AbstractNamedUnitRange, namedunitrange, nametype, randname
function trivial_unitrange(type::Type{<:AbstractUnitRange})
return Base.oneto(one(eltype(type)))
end
function rand_trivial_namedunitrange(
::Type{<:AbstractNamedUnitRange{<:Any,R,N}}
) where {R,N}
return namedunitrange(trivial_unitrange(R), randname(N))
end

dag(x) = x

using NamedDimsArrays: nameddimsindices
function insert_trivial_link!(tn, e)
add_edge!(tn, e)
l = rand_trivial_namedunitrange(eltype(nameddimsindices(tn[src(e)])))
x = similar(tn[src(e)], (l,))
x[1] = 1
@preserve_graph tn[src(e)] = tn[src(e)] * x
@preserve_graph tn[dst(e)] = tn[dst(e)] * dag(x)
return tn
end

function Base.setindex!(tn::AbstractTensorNetwork, value, v)
@preserve_graph tn[v] = value
fix_edges!(tn, v)
return tn
end
using NamedGraphs.OrdinalIndexing: OrdinalSuffixedInteger
# Fix ambiguity error.
function Base.setindex!(graph::AbstractTensorNetwork, value, vertex::OrdinalSuffixedInteger)
graph[vertices(graph)[vertex]] = value
return graph
end
# Fix ambiguity error.
function Base.setindex!(tn::AbstractTensorNetwork, value, edge::AbstractEdge)
return error("No edge data.")
end
# Fix ambiguity error.
function Base.setindex!(tn::AbstractTensorNetwork, value, edge::Pair)
return error("No edge data.")
end
using NamedGraphs.OrdinalIndexing: OrdinalSuffixedInteger
# Fix ambiguity error.
function Base.setindex!(
tn::AbstractTensorNetwork,
value,
edge::Pair{<:OrdinalSuffixedInteger,<:OrdinalSuffixedInteger},
)
return error("No edge data.")
end

function Base.show(io::IO, mime::MIME"text/plain", graph::AbstractTensorNetwork)
println(io, "$(typeof(graph)) with $(nv(graph)) vertices:")
show(io, mime, vertices(graph))
println(io, "\n")
println(io, "and $(ne(graph)) edge(s):")
for e in edges(graph)
show(io, mime, e)
println(io)
end
println(io)
println(io, "with vertex data:")
show(io, mime, axes.(vertex_data(graph)))
return nothing
end

Base.show(io::IO, graph::AbstractTensorNetwork) = show(io, MIME"text/plain"(), graph)
75 changes: 75 additions & 0 deletions src/tensornetwork.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
using DataGraphs: DataGraphs, AbstractDataGraph, DataGraph
using Dictionaries: AbstractDictionary, Indices, dictionary
using Graphs: AbstractSimpleGraph
using NamedDimsArrays: AbstractNamedDimsArray, dimnames, nameddimsarray
using NamedGraphs: NamedGraphs, NamedEdge, NamedGraph, vertextype
using NamedGraphs.GraphsExtensions: arranged_edges, vertextype

function _TensorNetwork end

struct TensorNetwork{V,VD,UG<:AbstractGraph{V},Tensors<:AbstractDictionary{V,VD}} <:
AbstractTensorNetwork{V,VD}
underlying_graph::UG
tensors::Tensors
global @inline function _TensorNetwork(
underlying_graph::UG, tensors::Tensors
) where {V,VD,UG<:AbstractGraph{V},Tensors<:AbstractDictionary{V,VD}}
# This assumes the tensor connectivity matches the graph structure.
return new{V,VD,UG,Tensors}(underlying_graph, tensors)
end
end

DataGraphs.underlying_graph(tn::TensorNetwork) = getfield(tn, :underlying_graph)
DataGraphs.vertex_data(tn::TensorNetwork) = getfield(tn, :tensors)
function DataGraphs.underlying_graph_type(type::Type{<:TensorNetwork})
return fieldtype(type, :underlying_graph)
end

# Determine the graph structure from the tensors.
function TensorNetwork(t::AbstractDictionary)
g = NamedGraph(eachindex(t))
for v1 in vertices(g)
for v2 in vertices(g)
if v1 ≠ v2
if !isdisjoint(dimnames(t[v1]), dimnames(t[v2]))
add_edge!(g, v1 => v2)
end
end
end
end
return _TensorNetwork(g, t)
end
function TensorNetwork(tensors::AbstractDict)
return TensorNetwork(Dictionary(tensors))
end

function TensorNetwork(graph::AbstractGraph, tensors::AbstractDictionary)
tn = TensorNetwork(tensors)
arranged_edges(tn) ⊆ arranged_edges(graph) ||
error("The edges in the tensors do not match the graph structure.")
for e in setdiff(arranged_edges(graph), arranged_edges(tn))
insert_trivial_link!(tn, e)
end
return tn
end
function TensorNetwork(graph::AbstractGraph, tensors::AbstractDict)
return TensorNetwork(graph, Dictionary(tensors))
end
function TensorNetwork(f, graph::AbstractGraph)
return TensorNetwork(graph, Dict(v => f(v) for v in vertices(graph)))
end

function Base.copy(tn::TensorNetwork)
TensorNetwork(copy(underlying_graph(tn)), copy(vertex_data(tn)))
end
TensorNetwork(tn::TensorNetwork) = copy(tn)
TensorNetwork{V}(tn::TensorNetwork{V}) where {V} = copy(tn)
function TensorNetwork{V}(tn::TensorNetwork) where {V}
g′ = convert_vertextype(V, underlying_graph(tn))
d = vertex_data(tn)
d′ = dictionary(V(k) => d[k] for k in eachindex(d))
return TensorNetwork(g′, d′)
end

NamedGraphs.convert_vertextype(::Type{V}, tn::TensorNetwork{V}) where {V} = tn
NamedGraphs.convert_vertextype(V::Type, tn::TensorNetwork) = TensorNetwork{V}(tn)
Loading
Loading