|
| 1 | +## using ITensors: IndexSet |
| 2 | +using DataGraphs: DataGraphs, AbstractDataGraph, edge_data, vertex_data |
| 3 | +using Graphs: Graphs, AbstractEdge |
| 4 | +## using ITensors: ITensors, unioninds, uniqueinds |
| 5 | +## using .ITensorsExtensions: ITensorsExtensions, promote_indtype |
| 6 | +using NamedGraphs: NamedGraphs |
| 7 | +using NamedGraphs.GraphsExtensions: incident_edges, rename_vertices |
| 8 | + |
| 9 | +# TODO: Define as `AbstractAxesNetwork`? |
| 10 | +abstract type AbstractIndsNetwork{V,I} <: AbstractDataGraph{V,Vector{I},Vector{I}} end |
| 11 | + |
| 12 | +indtype(is::AbstractIndsNetwork) = indtype(typeof(is)) |
| 13 | +indtype(::Type{<:AbstractIndsNetwork{<:Any,I}}) where {I} = I |
| 14 | + |
| 15 | +# Field access |
| 16 | +data_graph(graph::AbstractIndsNetwork) = not_implemented() |
| 17 | + |
| 18 | +# Overload if needed |
| 19 | +Graphs.is_directed(::Type{<:AbstractIndsNetwork}) = false |
| 20 | + |
| 21 | +# AbstractDataGraphs overloads |
| 22 | +function DataGraphs.vertex_data(graph::AbstractIndsNetwork, args...) |
| 23 | + return vertex_data(data_graph(graph), args...) |
| 24 | +end |
| 25 | +function DataGraphs.edge_data(graph::AbstractIndsNetwork, args...) |
| 26 | + return edge_data(data_graph(graph), args...) |
| 27 | +end |
| 28 | + |
| 29 | +# TODO: Define a generic fallback for `AbstractDataGraph`? |
| 30 | +DataGraphs.edge_data_eltype(::Type{<:AbstractIndsNetwork{V,I}}) where {V,I} = Vector{I} |
| 31 | + |
| 32 | +## TODO: Bring these back. |
| 33 | +## function indsnetwork_getindex(is::AbstractIndsNetwork, index) |
| 34 | +## return get(data_graph(is), index, indtype(is)[]) |
| 35 | +## end |
| 36 | +## |
| 37 | +## function Base.getindex(is::AbstractIndsNetwork, index) |
| 38 | +## return indsnetwork_getindex(is, index) |
| 39 | +## end |
| 40 | +## |
| 41 | +## function Base.getindex(is::AbstractIndsNetwork, index::Pair) |
| 42 | +## return indsnetwork_getindex(is, index) |
| 43 | +## end |
| 44 | +## |
| 45 | +## function Base.getindex(is::AbstractIndsNetwork, index::AbstractEdge) |
| 46 | +## return indsnetwork_getindex(is, index) |
| 47 | +## end |
| 48 | +## |
| 49 | +## function indsnetwork_setindex!(is::AbstractIndsNetwork, value, index) |
| 50 | +## data_graph(is)[index] = value |
| 51 | +## return is |
| 52 | +## end |
| 53 | +## |
| 54 | +## function Base.setindex!(is::AbstractIndsNetwork, value, index) |
| 55 | +## indsnetwork_setindex!(is, value, index) |
| 56 | +## return is |
| 57 | +## end |
| 58 | +## |
| 59 | +## function Base.setindex!(is::AbstractIndsNetwork, value, index::Pair) |
| 60 | +## indsnetwork_setindex!(is, value, index) |
| 61 | +## return is |
| 62 | +## end |
| 63 | +## |
| 64 | +## function Base.setindex!(is::AbstractIndsNetwork, value, index::AbstractEdge) |
| 65 | +## indsnetwork_setindex!(is, value, index) |
| 66 | +## return is |
| 67 | +## end |
| 68 | +## |
| 69 | +## function Base.setindex!(is::AbstractIndsNetwork, value::Index, index) |
| 70 | +## indsnetwork_setindex!(is, value, index) |
| 71 | +## return is |
| 72 | +## end |
| 73 | + |
| 74 | +# |
| 75 | +# Index access |
| 76 | +# |
| 77 | + |
| 78 | +function uniqueinds(is::AbstractIndsNetwork, edge::AbstractEdge) |
| 79 | + # TODO: Replace with `is[v]` once `getindex(::IndsNetwork, ...)` is smarter. |
| 80 | + inds = get(is, src(edge), indtype(is)[]) |
| 81 | + for ei in setdiff(incident_edges(is, src(edge)), [edge]) |
| 82 | + # TODO: Replace with `is[v]` once `getindex(::IndsNetwork, ...)` is smarter. |
| 83 | + inds = unioninds(inds, get(is, ei, indtype(is)[])) |
| 84 | + end |
| 85 | + return inds |
| 86 | +end |
| 87 | + |
| 88 | +function uniqueinds(is::AbstractIndsNetwork, edge::Pair) |
| 89 | + return uniqueinds(is, edgetype(is)(edge)) |
| 90 | +end |
| 91 | + |
| 92 | +function Base.union(is1::AbstractIndsNetwork, is2::AbstractIndsNetwork; kwargs...) |
| 93 | + return IndsNetwork(union(data_graph(is1), data_graph(is2); kwargs...)) |
| 94 | +end |
| 95 | + |
| 96 | +function NamedGraphs.rename_vertices(f::Function, tn::AbstractIndsNetwork) |
| 97 | + return IndsNetwork(rename_vertices(f, data_graph(tn))) |
| 98 | +end |
| 99 | + |
| 100 | +# |
| 101 | +# Convenience functions |
| 102 | +# |
| 103 | + |
| 104 | +## function ITensorsExtensions.promote_indtypeof(is::AbstractIndsNetwork) |
| 105 | +## sitetype = mapreduce(promote_indtype, vertices(is); init=Index{Int}) do v |
| 106 | +## # TODO: Replace with `is[v]` once `getindex(::IndsNetwork, ...)` is smarter. |
| 107 | +## return mapreduce(typeof, promote_indtype, get(is, v, Index[]); init=Index{Int}) |
| 108 | +## end |
| 109 | +## linktype = mapreduce(promote_indtype, edges(is); init=Index{Int}) do e |
| 110 | +## # TODO: Replace with `is[e]` once `getindex(::IndsNetwork, ...)` is smarter. |
| 111 | +## return mapreduce(typeof, promote_indtype, get(is, e, Index[]); init=Index{Int}) |
| 112 | +## end |
| 113 | +## return promote_indtype(sitetype, linktype) |
| 114 | +## end |
| 115 | + |
| 116 | +function union_all_inds(is_in::AbstractIndsNetwork...) |
| 117 | + @assert all(map(ug -> ug == underlying_graph(is_in[1]), underlying_graph.(is_in))) |
| 118 | + is_out = IndsNetwork(underlying_graph(is_in[1])) |
| 119 | + for v in vertices(is_out) |
| 120 | + # TODO: Remove this check. |
| 121 | + if any(isassigned(is, v) for is in is_in) |
| 122 | + # TODO: Change `get` to `getindex`. |
| 123 | + is_out[v] = unioninds([get(is, v, indtype(is)[]) for is in is_in]...) |
| 124 | + end |
| 125 | + end |
| 126 | + for e in edges(is_out) |
| 127 | + # TODO: Remove this check. |
| 128 | + if any(isassigned(is, e) for is in is_in) |
| 129 | + # TODO: Change `get` to `getindex`. |
| 130 | + is_out[e] = unioninds([get(is, e, indtype(is)[]) for is in is_in]...) |
| 131 | + end |
| 132 | + end |
| 133 | + return is_out |
| 134 | +end |
| 135 | + |
| 136 | +function insert_linkinds( |
| 137 | + indsnetwork::AbstractIndsNetwork, |
| 138 | + edges=edges(indsnetwork); |
| 139 | + link_space=trivial_space(indsnetwork), |
| 140 | +) |
| 141 | + indsnetwork = copy(indsnetwork) |
| 142 | + for e in edges |
| 143 | + # TODO: Change to check if it is empty. |
| 144 | + if !isassigned(indsnetwork, e) |
| 145 | + if !isnothing(link_space) |
| 146 | + iₑ = indtype(indsnetwork)(link_space, edge_tag(e)) |
| 147 | + # TODO: Allow setting with just a single axis. |
| 148 | + indsnetwork[e] = [iₑ] |
| 149 | + else |
| 150 | + indsnetwork[e] = [] |
| 151 | + end |
| 152 | + end |
| 153 | + end |
| 154 | + return indsnetwork |
| 155 | +end |
0 commit comments