Skip to content

Commit 6347b7d

Browse files
committed
TensorNetwork type
1 parent eb06f66 commit 6347b7d

File tree

8 files changed

+1756
-3
lines changed

8 files changed

+1756
-3
lines changed

Project.toml

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,29 @@ uuid = "302f2e75-49f0-4526-aef7-d8ba550cb06c"
33
authors = ["ITensor developers <[email protected]> and contributors"]
44
version = "0.1.0"
55

6+
[deps]
7+
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
8+
BackendSelection = "680c2d7c-f67a-4cc9-ae9c-da132b1447a5"
9+
DataGraphs = "b5a273c3-7e6c-41f6-98bd-8d7f1525a36a"
10+
Dictionaries = "85a47980-9c8c-11e8-2b9f-f7ca1fa99fb4"
11+
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
12+
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
13+
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
14+
NamedDimsArrays = "60cbd0c0-df58-4cb7-918c-6f5607b73fde"
15+
NamedGraphs = "678767b0-92e7-4007-89e4-4527a8725b19"
16+
SimpleTraits = "699a6c99-e7fa-54fc-8d76-47d257e15c1d"
17+
SplitApplyCombine = "03a91e81-4c3e-53e1-a0a4-9c0c8f19dd66"
18+
619
[compat]
20+
Adapt = "4.3.0"
21+
BackendSelection = "0.1.6"
22+
DataGraphs = "0.2.7"
23+
Dictionaries = "0.4.5"
24+
Graphs = "1.13.1"
25+
LinearAlgebra = "1.11.0"
26+
MacroTools = "0.5.16"
27+
NamedDimsArrays = "0.7.13"
28+
NamedGraphs = "0.6.8"
29+
SimpleTraits = "0.9.5"
30+
SplitApplyCombine = "1.2.3"
731
julia = "1.10"

src/ITensorNetworksNext.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
module ITensorNetworksNext
22

3-
# Write your package code here.
3+
include("abstractindsnetwork.jl")
4+
include("indsnetwork.jl")
5+
include("abstracttensornetwork.jl")
6+
include("tensornetwork.jl")
47

58
end

src/abstractindsnetwork.jl

Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,155 @@
1+
## using ITensors: IndexSet
2+
using DataGraphs: DataGraphs, AbstractDataGraph, edge_data, vertex_data
3+
using Graphs: Graphs, AbstractEdge
4+
## using ITensors: ITensors, unioninds, uniqueinds
5+
## using .ITensorsExtensions: ITensorsExtensions, promote_indtype
6+
using NamedGraphs: NamedGraphs
7+
using NamedGraphs.GraphsExtensions: incident_edges, rename_vertices
8+
9+
# TODO: Define as `AbstractAxesNetwork`?
10+
abstract type AbstractIndsNetwork{V,I} <: AbstractDataGraph{V,Vector{I},Vector{I}} end
11+
12+
indtype(is::AbstractIndsNetwork) = indtype(typeof(is))
13+
indtype(::Type{<:AbstractIndsNetwork{<:Any,I}}) where {I} = I
14+
15+
# Field access
16+
data_graph(graph::AbstractIndsNetwork) = not_implemented()
17+
18+
# Overload if needed
19+
Graphs.is_directed(::Type{<:AbstractIndsNetwork}) = false
20+
21+
# AbstractDataGraphs overloads
22+
function DataGraphs.vertex_data(graph::AbstractIndsNetwork, args...)
23+
return vertex_data(data_graph(graph), args...)
24+
end
25+
function DataGraphs.edge_data(graph::AbstractIndsNetwork, args...)
26+
return edge_data(data_graph(graph), args...)
27+
end
28+
29+
# TODO: Define a generic fallback for `AbstractDataGraph`?
30+
DataGraphs.edge_data_eltype(::Type{<:AbstractIndsNetwork{V,I}}) where {V,I} = Vector{I}
31+
32+
## TODO: Bring these back.
33+
## function indsnetwork_getindex(is::AbstractIndsNetwork, index)
34+
## return get(data_graph(is), index, indtype(is)[])
35+
## end
36+
##
37+
## function Base.getindex(is::AbstractIndsNetwork, index)
38+
## return indsnetwork_getindex(is, index)
39+
## end
40+
##
41+
## function Base.getindex(is::AbstractIndsNetwork, index::Pair)
42+
## return indsnetwork_getindex(is, index)
43+
## end
44+
##
45+
## function Base.getindex(is::AbstractIndsNetwork, index::AbstractEdge)
46+
## return indsnetwork_getindex(is, index)
47+
## end
48+
##
49+
## function indsnetwork_setindex!(is::AbstractIndsNetwork, value, index)
50+
## data_graph(is)[index] = value
51+
## return is
52+
## end
53+
##
54+
## function Base.setindex!(is::AbstractIndsNetwork, value, index)
55+
## indsnetwork_setindex!(is, value, index)
56+
## return is
57+
## end
58+
##
59+
## function Base.setindex!(is::AbstractIndsNetwork, value, index::Pair)
60+
## indsnetwork_setindex!(is, value, index)
61+
## return is
62+
## end
63+
##
64+
## function Base.setindex!(is::AbstractIndsNetwork, value, index::AbstractEdge)
65+
## indsnetwork_setindex!(is, value, index)
66+
## return is
67+
## end
68+
##
69+
## function Base.setindex!(is::AbstractIndsNetwork, value::Index, index)
70+
## indsnetwork_setindex!(is, value, index)
71+
## return is
72+
## end
73+
74+
#
75+
# Index access
76+
#
77+
78+
function uniqueinds(is::AbstractIndsNetwork, edge::AbstractEdge)
79+
# TODO: Replace with `is[v]` once `getindex(::IndsNetwork, ...)` is smarter.
80+
inds = get(is, src(edge), indtype(is)[])
81+
for ei in setdiff(incident_edges(is, src(edge)), [edge])
82+
# TODO: Replace with `is[v]` once `getindex(::IndsNetwork, ...)` is smarter.
83+
inds = unioninds(inds, get(is, ei, indtype(is)[]))
84+
end
85+
return inds
86+
end
87+
88+
function uniqueinds(is::AbstractIndsNetwork, edge::Pair)
89+
return uniqueinds(is, edgetype(is)(edge))
90+
end
91+
92+
function Base.union(is1::AbstractIndsNetwork, is2::AbstractIndsNetwork; kwargs...)
93+
return IndsNetwork(union(data_graph(is1), data_graph(is2); kwargs...))
94+
end
95+
96+
function NamedGraphs.rename_vertices(f::Function, tn::AbstractIndsNetwork)
97+
return IndsNetwork(rename_vertices(f, data_graph(tn)))
98+
end
99+
100+
#
101+
# Convenience functions
102+
#
103+
104+
## function ITensorsExtensions.promote_indtypeof(is::AbstractIndsNetwork)
105+
## sitetype = mapreduce(promote_indtype, vertices(is); init=Index{Int}) do v
106+
## # TODO: Replace with `is[v]` once `getindex(::IndsNetwork, ...)` is smarter.
107+
## return mapreduce(typeof, promote_indtype, get(is, v, Index[]); init=Index{Int})
108+
## end
109+
## linktype = mapreduce(promote_indtype, edges(is); init=Index{Int}) do e
110+
## # TODO: Replace with `is[e]` once `getindex(::IndsNetwork, ...)` is smarter.
111+
## return mapreduce(typeof, promote_indtype, get(is, e, Index[]); init=Index{Int})
112+
## end
113+
## return promote_indtype(sitetype, linktype)
114+
## end
115+
116+
function union_all_inds(is_in::AbstractIndsNetwork...)
117+
@assert all(map(ug -> ug == underlying_graph(is_in[1]), underlying_graph.(is_in)))
118+
is_out = IndsNetwork(underlying_graph(is_in[1]))
119+
for v in vertices(is_out)
120+
# TODO: Remove this check.
121+
if any(isassigned(is, v) for is in is_in)
122+
# TODO: Change `get` to `getindex`.
123+
is_out[v] = unioninds([get(is, v, indtype(is)[]) for is in is_in]...)
124+
end
125+
end
126+
for e in edges(is_out)
127+
# TODO: Remove this check.
128+
if any(isassigned(is, e) for is in is_in)
129+
# TODO: Change `get` to `getindex`.
130+
is_out[e] = unioninds([get(is, e, indtype(is)[]) for is in is_in]...)
131+
end
132+
end
133+
return is_out
134+
end
135+
136+
function insert_linkinds(
137+
indsnetwork::AbstractIndsNetwork,
138+
edges=edges(indsnetwork);
139+
link_space=trivial_space(indsnetwork),
140+
)
141+
indsnetwork = copy(indsnetwork)
142+
for e in edges
143+
# TODO: Change to check if it is empty.
144+
if !isassigned(indsnetwork, e)
145+
if !isnothing(link_space)
146+
iₑ = indtype(indsnetwork)(link_space, edge_tag(e))
147+
# TODO: Allow setting with just a single axis.
148+
indsnetwork[e] = [iₑ]
149+
else
150+
indsnetwork[e] = []
151+
end
152+
end
153+
end
154+
return indsnetwork
155+
end

0 commit comments

Comments
 (0)