diff --git a/Project.toml b/Project.toml index e2f24b9..5e0b116 100644 --- a/Project.toml +++ b/Project.toml @@ -1,13 +1,15 @@ name = "ITensorNetworksNext" uuid = "302f2e75-49f0-4526-aef7-d8ba550cb06c" authors = ["ITensor developers and contributors"] -version = "0.1.15" +version = "0.2.0" [deps] AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" BackendSelection = "680c2d7c-f67a-4cc9-ae9c-da132b1447a5" +Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa" DataGraphs = "b5a273c3-7e6c-41f6-98bd-8d7f1525a36a" +DerivableInterfaces = "6c5e35bf-e59e-4898-b73c-732dcc4ba65f" DiagonalArrays = "74fd4be6-21e2-4f6f-823a-4360d37c7a77" Dictionaries = "85a47980-9c8c-11e8-2b9f-f7ca1fa99fb4" Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6" @@ -31,7 +33,9 @@ ITensorNetworksNextTensorOperationsExt = "TensorOperations" AbstractTrees = "0.4.5" Adapt = "4.3" BackendSelection = "0.1.6" +Combinatorics = "1" DataGraphs = "0.2.7" +DerivableInterfaces = "0.5.5" DiagonalArrays = "0.3.23" Dictionaries = "0.4.5" Graphs = "1.13.1" diff --git a/docs/Project.toml b/docs/Project.toml index a9141bc..266b345 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -6,4 +6,4 @@ ITensorNetworksNext = "302f2e75-49f0-4526-aef7-d8ba550cb06c" [compat] Documenter = "1" Literate = "2" -ITensorNetworksNext = "0.1" +ITensorNetworksNext = "0.2" diff --git a/examples/Project.toml b/examples/Project.toml index 3c061c3..1e3b0ad 100644 --- a/examples/Project.toml +++ b/examples/Project.toml @@ -2,4 +2,4 @@ ITensorNetworksNext = "302f2e75-49f0-4526-aef7-d8ba550cb06c" [compat] -ITensorNetworksNext = "0.1" +ITensorNetworksNext = "0.2" diff --git a/ext/ITensorNetworksNextTensorOperationsExt/ITensorNetworksNextTensorOperationsExt.jl b/ext/ITensorNetworksNextTensorOperationsExt/ITensorNetworksNextTensorOperationsExt.jl index f3b90bf..dbbbbc8 100644 --- a/ext/ITensorNetworksNextTensorOperationsExt/ITensorNetworksNextTensorOperationsExt.jl +++ b/ext/ITensorNetworksNextTensorOperationsExt/ITensorNetworksNextTensorOperationsExt.jl @@ -1,16 +1,27 @@ module ITensorNetworksNextTensorOperationsExt using BackendSelection: @Algorithm_str, Algorithm +using ITensorNetworksNext: ITensorNetworksNext, contraction_order +using ITensorNetworksNext.LazyNamedDimsArrays: symnameddims, substitute using NamedDimsArrays: inds -using ITensorNetworksNext: ITensorNetworksNext, contraction_sequence_to_expr using TensorOperations: TensorOperations, optimaltree -function ITensorNetworksNext.contraction_sequence(::Algorithm"optimal", tn::Vector{<:AbstractArray}) - network = collect.(inds.(tn)) - #Converting dims to Float64 to minimize overflow issues +function contraction_order_to_expr(ord) + return ord isa AbstractVector ? prod(contraction_order_to_expr, ord) : symnameddims(ord) +end + +function ITensorNetworksNext.contraction_order(alg::Algorithm"optimal", tn) + ts = [tn[i] for i in keys(tn)] + network = collect.(inds.(ts)) + # Converting dims to Float64 to minimize overflow issues inds_to_dims = Dict(i => Float64(length(i)) for i in unique(reduce(vcat, network))) - seq, _ = optimaltree(network, inds_to_dims) - return contraction_sequence_to_expr(seq) + order, _ = optimaltree(network, inds_to_dims) + # TODO: Map the integer indices back to the original tensor network vertices. 
+ expr = contraction_order_to_expr(order) + verts = collect(keys(tn)) + sym(i) = symnameddims(verts[i], Tuple(inds(tn[verts[i]]))) + subs = Dict(symnameddims(i) => sym(i) for i in eachindex(verts)) + return substitute(expr, subs) end end diff --git a/src/LazyNamedDimsArrays/LazyNamedDimsArrays.jl b/src/LazyNamedDimsArrays/LazyNamedDimsArrays.jl index f497e0f..daa056e 100644 --- a/src/LazyNamedDimsArrays/LazyNamedDimsArrays.jl +++ b/src/LazyNamedDimsArrays/LazyNamedDimsArrays.jl @@ -8,410 +8,6 @@ include("lazyinterface.jl") include("lazybroadcast.jl") include("lazynameddimsarray.jl") include("symbolicnameddimsarray.jl") - -## using AbstractTrees: AbstractTrees -## using WrappedUnions: @wrapped, unwrap -## using NamedDimsArrays: NamedDimsArrays, AbstractNamedDimsArray, AbstractNamedDimsArrayStyle, -## NamedDimsArray, dename, dimnames, inds -## using TermInterface: TermInterface, arguments, iscall, maketerm, operation, sorted_arguments -## using TypeParameterAccessors: unspecify_type_parameters - -## # Defined to avoid type piracy. -## # TODO: Define a proper hash function -## # in NamedDimsArrays.jl, maybe one that is -## # independent of the order of dimensions. -## function _hash(a::NamedDimsArray, h::UInt64) -## h = hash(:NamedDimsArray, h) -## h = hash(dename(a), h) -## for i in inds(a) -## h = hash(i, h) -## end -## return h -## end -## function _hash(x, h::UInt64) -## return hash(x, h) -## end -## -## # Custom version of `AbstractTrees.printnode` to -## # avoid type piracy when overloading on `AbstractNamedDimsArray`. -## printnode_nameddims(io::IO, x) = AbstractTrees.printnode(io, x) -## function printnode_nameddims(io::IO, a::AbstractNamedDimsArray) -## show(io, collect(dimnames(a))) -## return nothing -## end - -## # Generic lazy functionality. -## function maketerm_lazy(type::Type, head, args, metadata) -## if head ≡ * -## return type(maketerm(Mul, head, args, metadata)) -## else -## return error("Only mul supported right now.") -## end -## end -## function getindex_lazy(a::AbstractArray, I...) -## u = unwrap(a) -## if !iscall(u) -## return u[I...] 
-## else -## return error("Indexing into expression not supported.") -## end -## end -## function arguments_lazy(a) -## u = unwrap(a) -## if !iscall(u) -## return error("No arguments.") -## elseif ismul(u) -## return arguments(u) -## else -## return error("Variant not supported.") -## end -## end -## function children_lazy(a) -## return arguments(a) -## end -## function head_lazy(a) -## return operation(a) -## end -## function iscall_lazy(a) -## return iscall(unwrap(a)) -## end -## function isexpr_lazy(a) -## return iscall(a) -## end -## function operation_lazy(a) -## u = unwrap(a) -## if !iscall(u) -## return error("No operation.") -## elseif ismul(u) -## return operation(u) -## else -## return error("Variant not supported.") -## end -## end -## function sorted_arguments_lazy(a) -## u = unwrap(a) -## if !iscall(u) -## return error("No arguments.") -## elseif ismul(u) -## return sorted_arguments(u) -## else -## return error("Variant not supported.") -## end -## end -## function sorted_children_lazy(a) -## return sorted_arguments(a) -## end -## ismul_lazy(a) = ismul(unwrap(a)) -## function abstracttrees_children_lazy(a) -## if !iscall(a) -## return () -## else -## return arguments(a) -## end -## end -## function nodevalue_lazy(a) -## if !iscall(a) -## return unwrap(a) -## else -## return operation(a) -## end -## end -## using Base.Broadcast: materialize -## function materialize_lazy(a) -## u = unwrap(a) -## if !iscall(u) -## return u -## elseif ismul(u) -## return mapfoldl(materialize, operation(u), arguments(u)) -## else -## return error("Variant not supported.") -## end -## end -## copy_lazy(a) = materialize(a) -## function equals_lazy(a1, a2) -## u1, u2 = unwrap.((a1, a2)) -## if !iscall(u1) && !iscall(u2) -## return u1 == u2 -## elseif ismul(u1) && ismul(u2) -## return arguments(u1) == arguments(u2) -## else -## return false -## end -## end -## function hash_lazy(a, h::UInt64) -## h = hash(Symbol(unspecify_type_parameters(typeof(a))), h) -## # Use `_hash`, which defines a custom hash for NamedDimsArray. -## return _hash(unwrap(a), h) -## end -## function map_arguments_lazy(f, a) -## u = unwrap(a) -## if !iscall(u) -## return error("No arguments to map.") -## elseif ismul(u) -## return lazy(map_arguments(f, u)) -## else -## return error("Variant not supported.") -## end -## end -## function substitute_lazy(a, substitutions::AbstractDict) -## haskey(substitutions, a) && return substitutions[a] -## !iscall(a) && return a -## return map_arguments(arg -> substitute(arg, substitutions), a) -## end -## function substitute_lazy(a, substitutions) -## return substitute(a, Dict(substitutions)) -## end -## function printnode_lazy(io, a) -## # Use `printnode_nameddims` to avoid type piracy, -## # since it overloads on `AbstractNamedDimsArray`. 
-## return printnode_nameddims(io, unwrap(a)) -## end -## function show_lazy(io::IO, a) -## if !iscall(a) -## return show(io, unwrap(a)) -## else -## return AbstractTrees.printnode(io, a) -## end -## end -## function show_lazy(io::IO, mime::MIME"text/plain", a) -## summary(io, a) -## println(io, ":") -## if !iscall(a) -## show(io, mime, unwrap(a)) -## return nothing -## else -## show(io, a) -## return nothing -## end -## end -## add_lazy(a1, a2) = error("Not implemented.") -## sub_lazy(a) = error("Not implemented.") -## sub_lazy(a1, a2) = error("Not implemented.") -## function mul_lazy(a) -## u = unwrap(a) -## if !iscall(u) -## return lazy(Mul([a])) -## elseif ismul(u) -## return a -## else -## return error("Variant not supported.") -## end -## end -## # Note that this is nested by default. -## mul_lazy(a1, a2) = lazy(Mul([a1, a2])) -## mul_lazy(a1::Number, a2) = error("Not implemented.") -## mul_lazy(a1, a2::Number) = error("Not implemented.") -## mul_lazy(a1::Number, a2::Number) = a1 * a2 -## div_lazy(a1, a2::Number) = error("Not implemented.") -## -## # NamedDimsArrays.jl interface. -## function inds_lazy(a) -## u = unwrap(a) -## if !iscall(u) -## return inds(u) -## elseif ismul(u) -## return mapreduce(inds, symdiff, arguments(u)) -## else -## return error("Variant not supported.") -## end -## end -## function dename_lazy(a) -## u = unwrap(a) -## if !iscall(u) -## return dename(u) -## else -## return error("Variant not supported.") -## end -## end - -## # Lazy broadcasting. -## struct LazyNamedDimsArrayStyle <: AbstractNamedDimsArrayStyle{Any} end -## function Broadcast.broadcasted(::LazyNamedDimsArrayStyle, f, as...) -## return error("Arbitrary broadcasting not supported for LazyNamedDimsArray.") -## end -## # Linear operations. -## Broadcast.broadcasted(::LazyNamedDimsArrayStyle, ::typeof(+), a1, a2) = a1 + a2 -## Broadcast.broadcasted(::LazyNamedDimsArrayStyle, ::typeof(-), a1, a2) = a1 - a2 -## Broadcast.broadcasted(::LazyNamedDimsArrayStyle, ::typeof(*), c::Number, a) = c * a -## Broadcast.broadcasted(::LazyNamedDimsArrayStyle, ::typeof(*), a, c::Number) = a * c -## Broadcast.broadcasted(::LazyNamedDimsArrayStyle, ::typeof(*), a::Number, b::Number) = a * b -## Broadcast.broadcasted(::LazyNamedDimsArrayStyle, ::typeof(/), a, c::Number) = a / c -## Broadcast.broadcasted(::LazyNamedDimsArrayStyle, ::typeof(-), a) = -a - -## # Generic functionality for Applied types, like `Mul`, `Add`, etc. 
-## ismul(a) = operation(a) ≡ * -## head_applied(a) = operation(a) -## iscall_applied(a) = true -## isexpr_applied(a) = iscall(a) -## function show_applied(io::IO, a) -## args = map(arg -> sprint(AbstractTrees.printnode, arg), arguments(a)) -## print(io, "(", join(args, " $(operation(a)) "), ")") -## return nothing -## end -## sorted_arguments_applied(a) = arguments(a) -## children_applied(a) = arguments(a) -## sorted_children_applied(a) = sorted_arguments(a) -## function maketerm_applied(type, head, args, metadata) -## term = type(args) -## @assert head ≡ operation(term) -## return term -## end -## map_arguments_applied(f, a) = unspecify_type_parameters(typeof(a))(map(f, arguments(a))) -## function hash_applied(a, h::UInt64) -## h = hash(Symbol(unspecify_type_parameters(typeof(a))), h) -## for arg in arguments(a) -## h = hash(arg, h) -## end -## return h -## end -## -## abstract type Applied end -## TermInterface.head(a::Applied) = head_applied(a) -## TermInterface.iscall(a::Applied) = iscall_applied(a) -## TermInterface.isexpr(a::Applied) = isexpr_applied(a) -## Base.show(io::IO, a::Applied) = show_applied(io, a) -## TermInterface.sorted_arguments(a::Applied) = sorted_arguments_applied(a) -## TermInterface.children(a::Applied) = children_applied(a) -## TermInterface.sorted_children(a::Applied) = sorted_children_applied(a) -## function TermInterface.maketerm(type::Type{<:Applied}, head, args, metadata) -## return maketerm_applied(type, head, args, metadata) -## end -## map_arguments(f, a::Applied) = map_arguments_applied(f, a) -## Base.hash(a::Applied, h::UInt64) = hash_applied(a, h) -## -## struct Mul{A} <: Applied -## arguments::Vector{A} -## end -## TermInterface.arguments(m::Mul) = getfield(m, :arguments) -## TermInterface.operation(m::Mul) = * - -## @wrapped struct LazyNamedDimsArray{ -## T, A <: AbstractNamedDimsArray{T}, -## } <: AbstractNamedDimsArray{T, Any} -## union::Union{A, Mul{LazyNamedDimsArray{T, A}}} -## end -## function LazyNamedDimsArray(a::AbstractNamedDimsArray) -## # Use `eltype(typeof(a))` for arrays that have different -## # runtime and compile time eltypes, like `ITensor`. -## return LazyNamedDimsArray{eltype(typeof(a)), typeof(a)}(a) -## end -## function LazyNamedDimsArray(a::Mul{LazyNamedDimsArray{T, A}}) where {T, A} -## return LazyNamedDimsArray{T, A}(a) -## end -## lazy(a::LazyNamedDimsArray) = a -## lazy(a::AbstractNamedDimsArray) = LazyNamedDimsArray(a) -## lazy(a::Mul{<:LazyNamedDimsArray}) = LazyNamedDimsArray(a) -## -## NamedDimsArrays.inds(a::LazyNamedDimsArray) = inds_lazy(a) -## NamedDimsArrays.dename(a::LazyNamedDimsArray) = dename_lazy(a) -## -## # Broadcasting -## function Base.BroadcastStyle(::Type{<:LazyNamedDimsArray}) -## return LazyNamedDimsArrayStyle() -## end -## -## # Derived functionality. -## function TermInterface.maketerm(type::Type{LazyNamedDimsArray}, head, args, metadata) -## return maketerm_lazy(type, head, args, metadata) -## end -## Base.getindex(a::LazyNamedDimsArray, I::Int...) = getindex_lazy(a, I...) 
-## TermInterface.arguments(a::LazyNamedDimsArray) = arguments_lazy(a) -## TermInterface.children(a::LazyNamedDimsArray) = children_lazy(a) -## TermInterface.head(a::LazyNamedDimsArray) = head_lazy(a) -## TermInterface.iscall(a::LazyNamedDimsArray) = iscall_lazy(a) -## TermInterface.isexpr(a::LazyNamedDimsArray) = isexpr_lazy(a) -## TermInterface.operation(a::LazyNamedDimsArray) = operation_lazy(a) -## TermInterface.sorted_arguments(a::LazyNamedDimsArray) = sorted_arguments_lazy(a) -## AbstractTrees.children(a::LazyNamedDimsArray) = abstracttrees_children_lazy(a) -## TermInterface.sorted_children(a::LazyNamedDimsArray) = sorted_children_lazy(a) -## ismul(a::LazyNamedDimsArray) = ismul_lazy(a) -## AbstractTrees.nodevalue(a::LazyNamedDimsArray) = nodevalue_lazy(a) -## Base.Broadcast.materialize(a::LazyNamedDimsArray) = materialize_lazy(a) -## Base.copy(a::LazyNamedDimsArray) = copy_lazy(a) -## Base.:(==)(a1::LazyNamedDimsArray, a2::LazyNamedDimsArray) = equals_lazy(a1, a2) -## Base.hash(a::LazyNamedDimsArray, h::UInt64) = hash_lazy(a, h) -## map_arguments(f, a::LazyNamedDimsArray) = map_arguments_lazy(f, a) -## substitute(a::LazyNamedDimsArray, substitutions) = substitute_lazy(a, substitutions) -## AbstractTrees.printnode(io::IO, a::LazyNamedDimsArray) = printnode_lazy(io, a) -## printnode_nameddims(io::IO, a::LazyNamedDimsArray) = printnode_lazy(io, a) -## Base.show(io::IO, a::LazyNamedDimsArray) = show_lazy(io, a) -## Base.show(io::IO, mime::MIME"text/plain", a::LazyNamedDimsArray) = show_lazy(io, mime, a) -## Base.:*(a::LazyNamedDimsArray) = mul_lazy(a) -## Base.:*(a1::LazyNamedDimsArray, a2::LazyNamedDimsArray) = mul_lazy(a1, a2) -## Base.:+(a1::LazyNamedDimsArray, a2::LazyNamedDimsArray) = add_lazy(a1, a2) -## Base.:-(a1::LazyNamedDimsArray, a2::LazyNamedDimsArray) = sub_lazy(a1, a2) -## Base.:*(a1::Number, a2::LazyNamedDimsArray) = mul_lazy(a1, a2) -## Base.:*(a1::LazyNamedDimsArray, a2::Number) = mul_lazy(a1, a2) -## Base.:/(a1::LazyNamedDimsArray, a2::Number) = div_lazy(a1, a2) -## Base.:-(a::LazyNamedDimsArray) = sub_lazy(a) - -## struct SymbolicArray{T, N, Name, Axes <: NTuple{N, AbstractUnitRange{<:Integer}}} <: AbstractArray{T, N} -## name::Name -## axes::Axes -## function SymbolicArray{T}(name, ax::Tuple{Vararg{AbstractUnitRange{<:Integer}}}) where {T} -## N = length(ax) -## return new{T, N, typeof(name), typeof(ax)}(name, ax) -## end -## end -## function SymbolicArray(name, ax::Tuple{Vararg{AbstractUnitRange{<:Integer}}}) -## return SymbolicArray{Any}(name, ax) -## end -## function SymbolicArray{T}(name, ax::AbstractUnitRange...) where {T} -## return SymbolicArray{T}(name, ax) -## end -## function SymbolicArray(name, ax::AbstractUnitRange...) 
-## return SymbolicArray{Any}(name, ax) -## end -## symname(a::SymbolicArray) = getfield(a, :name) -## Base.axes(a::SymbolicArray) = getfield(a, :axes) -## Base.size(a::SymbolicArray) = length.(axes(a)) -## function Base.:(==)(a::SymbolicArray, b::SymbolicArray) -## return symname(a) == symname(b) && axes(a) == axes(b) -## end -## function Base.hash(a::SymbolicArray, h::UInt64) -## h = hash(:SymbolicArray, h) -## h = hash(symname(a), h) -## return hash(size(a), h) -## end -## function Base.getindex(a::SymbolicArray{<:Any, N}, I::Vararg{Int, N}) where {N} -## return error("Indexing into SymbolicArray not supported.") -## end -## function Base.setindex!(a::SymbolicArray{<:Any, N}, value, I::Vararg{Int, N}) where {N} -## return error("Indexing into SymbolicArray not supported.") -## end -## function Base.show(io::IO, mime::MIME"text/plain", a::SymbolicArray) -## Base.summary(io, a) -## println(io, ":") -## print(io, repr(symname(a))) -## return nothing -## end -## function Base.show(io::IO, a::SymbolicArray) -## print(io, "SymbolicArray(", symname(a), ", ", size(a), ")") -## return nothing -## end -## using AbstractTrees: AbstractTrees -## function AbstractTrees.printnode(io::IO, a::SymbolicArray) -## print(io, repr(symname(a))) -## return nothing -## end -## const SymbolicNamedDimsArray{T, N, Parent <: SymbolicArray{T, N}, DimNames} = -## NamedDimsArray{T, N, Parent, DimNames} -## function symnameddims(name) -## return lazy(NamedDimsArray(SymbolicArray(name), ())) -## end -## function AbstractTrees.printnode(io::IO, a::SymbolicNamedDimsArray) -## print(io, symname(dename(a))) -## if ndims(a) > 0 -## print(io, "[", join(dimnames(a), ","), "]") -## end -## return nothing -## end -## printnode_nameddims(io::IO, a::SymbolicNamedDimsArray) = AbstractTrees.printnode(io, a) -## function Base.:(==)(a::SymbolicNamedDimsArray, b::SymbolicNamedDimsArray) -## return issetequal(inds(a), inds(b)) && dename(a) == dename(b) -## end -## Base.:*(a::SymbolicNamedDimsArray, b::SymbolicNamedDimsArray) = lazy(a) * lazy(b) -## Base.:*(a::SymbolicNamedDimsArray, b::LazyNamedDimsArray) = lazy(a) * b -## Base.:*(a::LazyNamedDimsArray, b::SymbolicNamedDimsArray) = a * lazy(b) +include("evaluation_order.jl") end diff --git a/src/LazyNamedDimsArrays/applied.jl b/src/LazyNamedDimsArrays/applied.jl index 59ff3d6..e486e71 100644 --- a/src/LazyNamedDimsArrays/applied.jl +++ b/src/LazyNamedDimsArrays/applied.jl @@ -3,7 +3,7 @@ using TermInterface: TermInterface, arguments, iscall, operation using TypeParameterAccessors: unspecify_type_parameters # Generic functionality for Applied types, like `Mul`, `Add`, etc. -ismul(a) = operation(a) ≡ * +ismul(a) = iscall(a) && operation(a) ≡ * head_applied(a) = operation(a) iscall_applied(a) = true isexpr_applied(a) = iscall(a) diff --git a/src/LazyNamedDimsArrays/evaluation_order.jl b/src/LazyNamedDimsArrays/evaluation_order.jl new file mode 100644 index 0000000..bb3cc70 --- /dev/null +++ b/src/LazyNamedDimsArrays/evaluation_order.jl @@ -0,0 +1,112 @@ +using NamedDimsArrays: dename, inds +using TermInterface: arguments, arity, operation + +# The time complexity of evaluating `f(args...)`. +function time_complexity(f, args...) + return error("Not implemented.") +end +# The space complexity of evaluating `f(args...)`. +function space_complexity(f, args...) + return error("Not implemented.") +end +# The space complexity of `args`. +function input_space_complexity(f, args...) 
+ return error("Not implemented.") +end + +using NamedDimsArrays: AbstractNamedDimsArray +function time_complexity( + ::typeof(*), t1::AbstractNamedDimsArray, t2::AbstractNamedDimsArray + ) + return prod(length ∘ dename, (inds(t1) ∪ inds(t2))) +end +function time_complexity( + ::typeof(+), t1::AbstractNamedDimsArray, t2::AbstractNamedDimsArray + ) + @assert issetequal(inds(t1), inds(t2)) + return prod(dename, size(t1)) +end +function time_complexity(::typeof(*), c::Number, t::AbstractNamedDimsArray) + return prod(dename, size(t)) +end +function time_complexity(::typeof(*), t::AbstractNamedDimsArray, c::Number) + return time_complexity(*, c, t) +end + +function evaluation_time_complexity(a) + t = Ref(0) + opwalk(a) do f + return function (args...) + t[] += time_complexity(f, args...) + return f(args...) + end + end + return t[] +end + +# The workspace complexity of evaluating expression. +function evaluation_space_complexity(a) + # TODO: Walk the expression and call `space_complexity` on each node. + return error("Not implemented.") +end +# The complexity of storing the arguments of the expression. +function argument_space_complexity(a) + # TODO: Walk the expression and call `input_space_complexity` on each node. + return error("Not implemented.") +end + +# Flatten a nested expression down to a flat expression, +# removing information about the order of operations. +function flatten_expression(a) + if !iscall(a) + return a + elseif ismul(a) + flattened_arguments = mapreduce(vcat, arguments(a)) do arg + return ismul(arg) ? arguments(arg) : [arg] + end + return lazy(Mul(flattened_arguments)) + else + return error("Variant not supported.") + end +end + +function optimize_evaluation_order(alg, a) + return optimize_evaluation_order_flattened(alg, flatten_expression(a)) +end + +function optimize_evaluation_order_flattened(alg, a) + if !iscall(a) + return a + elseif ismul(a) + return optimize_contraction_order_flattened(alg, a) + else + # TODO: Recurse into other operations, calling `optimize_evaluation_order_flattened`. + return error("Variant not supported.") + end +end + +function optimize_evaluation_order( + a; alg = default_optimize_evaluation_order_alg(a) + ) + return optimize_evaluation_order(alg, a) +end + +using BackendSelection: @Algorithm_str, Algorithm +default_optimize_evaluation_order_alg(a) = Algorithm"eager"() + +function optimize_contraction_order_flattened(alg, a) + return error("`alg = $alg` not supported.") +end + +using Combinatorics: combinations +function optimize_contraction_order_flattened(alg::Algorithm"eager", a) + @assert ismul(a) + arity(a) in (1, 2) && return a + a1, a2 = argmin(combinations(arguments(a), 2)) do (a1, a2) + # Penalize outer product contractions. + isdisjoint(inds(a1), inds(a2)) && return typemax(Int) + return time_complexity(*, a1, a2) + end + contracted_arguments = [filter(∉((a1, a2)), arguments(a)); [a1 * a2]] + return optimize_contraction_order_flattened(alg, lazy(Mul(contracted_arguments))) +end diff --git a/src/LazyNamedDimsArrays/lazyinterface.jl b/src/LazyNamedDimsArrays/lazyinterface.jl index 6603fa4..65eca03 100644 --- a/src/LazyNamedDimsArrays/lazyinterface.jl +++ b/src/LazyNamedDimsArrays/lazyinterface.jl @@ -4,6 +4,22 @@ using WrappedUnions: unwrap lazy(x) = error("Not defined.") +# Walk the expression `ex`, modifying the +# operations by `opmap` and the arguments by `argmap`. +function walk(opmap, argmap, ex) + if !iscall(ex) + return argmap(ex) + else + return mapfoldl((args...) 
-> walk(opmap, argmap, args...), opmap(operation(ex)), arguments(ex)) + end +end +# Walk the expression `ex`, modifying the +# operations by `opmap`. +opwalk(opmap, a) = walk(opmap, identity, a) +# Walk the expression `ex`, modifying the +# arguments by `argmap`. +argwalk(argmap, a) = walk(identity, argmap, a) + # Generic lazy functionality. function maketerm_lazy(type::Type, head, args, metadata) if head ≡ * @@ -80,17 +96,8 @@ function nodevalue_lazy(a) return operation(a) end end +materialize_lazy(a) = argwalk(unwrap, a) using Base.Broadcast: materialize -function materialize_lazy(a) - u = unwrap(a) - if !iscall(u) - return u - elseif ismul(u) - return mapfoldl(materialize, operation(u), arguments(u)) - else - return error("Variant not supported.") - end -end copy_lazy(a) = materialize(a) function equals_lazy(a1, a2) u1, u2 = unwrap.((a1, a2)) @@ -102,6 +109,16 @@ function equals_lazy(a1, a2) return false end end +function isequal_lazy(a1, a2) + u1, u2 = unwrap.((a1, a2)) + if !iscall(u1) && !iscall(u2) + return isequal(u1, u2) + elseif ismul(u1) && ismul(u2) + return isequal(arguments(u1), arguments(u2)) + else + return false + end +end function hash_lazy(a, h::UInt64) h = hash(Symbol(unspecify_type_parameters(typeof(a))), h) # Use `_hash`, which defines a custom hash for NamedDimsArray. diff --git a/src/LazyNamedDimsArrays/lazynameddimsarray.jl b/src/LazyNamedDimsArrays/lazynameddimsarray.jl index ed92c80..b0ed86a 100644 --- a/src/LazyNamedDimsArrays/lazynameddimsarray.jl +++ b/src/LazyNamedDimsArrays/lazynameddimsarray.jl @@ -6,13 +6,18 @@ using WrappedUnions: @wrapped } <: AbstractNamedDimsArray{T, Any} union::Union{A, Mul{LazyNamedDimsArray{T, A}}} end + +parenttype(::Type{LazyNamedDimsArray{<:Any, A}}) where {A} = A +parenttype(::Type{LazyNamedDimsArray{T}}) where {T} = AbstractNamedDimsArray{T} +parenttype(::Type{LazyNamedDimsArray}) = AbstractNamedDimsArray + function LazyNamedDimsArray(a::AbstractNamedDimsArray) # Use `eltype(typeof(a))` for arrays that have different # runtime and compile time eltypes, like `ITensor`. return LazyNamedDimsArray{eltype(typeof(a)), typeof(a)}(a) end -function LazyNamedDimsArray(a::Mul{LazyNamedDimsArray{T, A}}) where {T, A} - return LazyNamedDimsArray{T, A}(a) +function LazyNamedDimsArray(a::Mul{L}) where {L <: LazyNamedDimsArray} + return LazyNamedDimsArray{eltype(L), parenttype(L)}(a) end lazy(a::LazyNamedDimsArray) = a lazy(a::AbstractNamedDimsArray) = LazyNamedDimsArray(a) @@ -45,6 +50,7 @@ AbstractTrees.nodevalue(a::LazyNamedDimsArray) = nodevalue_lazy(a) Base.Broadcast.materialize(a::LazyNamedDimsArray) = materialize_lazy(a) Base.copy(a::LazyNamedDimsArray) = copy_lazy(a) Base.:(==)(a1::LazyNamedDimsArray, a2::LazyNamedDimsArray) = equals_lazy(a1, a2) +Base.isequal(a1::LazyNamedDimsArray, a2::LazyNamedDimsArray) = isequal_lazy(a1, a2) Base.hash(a::LazyNamedDimsArray, h::UInt64) = hash_lazy(a, h) map_arguments(f, a::LazyNamedDimsArray) = map_arguments_lazy(f, a) substitute(a::LazyNamedDimsArray, substitutions) = substitute_lazy(a, substitutions) diff --git a/src/LazyNamedDimsArrays/symbolicarray.jl b/src/LazyNamedDimsArrays/symbolicarray.jl index 6953d93..d9e020b 100644 --- a/src/LazyNamedDimsArrays/symbolicarray.jl +++ b/src/LazyNamedDimsArrays/symbolicarray.jl @@ -1,3 +1,4 @@ +# TODO: Allow dynamic/unknown number of dimensions by supporting vector axes. 
struct SymbolicArray{T, N, Name, Axes <: NTuple{N, AbstractUnitRange{<:Integer}}} <: AbstractArray{T, N} name::Name axes::Axes @@ -9,12 +10,6 @@ end function SymbolicArray(name, ax::Tuple{Vararg{AbstractUnitRange{<:Integer}}}) return SymbolicArray{Any}(name, ax) end -function SymbolicArray{T}(name, ax::AbstractUnitRange...) where {T} - return SymbolicArray{T}(name, ax) -end -function SymbolicArray(name, ax::AbstractUnitRange...) - return SymbolicArray{Any}(name, ax) -end symname(a::SymbolicArray) = getfield(a, :name) Base.axes(a::SymbolicArray) = getfield(a, :axes) Base.size(a::SymbolicArray) = length.(axes(a)) @@ -32,6 +27,12 @@ end function Base.setindex!(a::SymbolicArray{<:Any, N}, value, I::Vararg{Int, N}) where {N} return error("Indexing into SymbolicArray not supported.") end +using DerivableInterfaces: DerivableInterfaces +DerivableInterfaces.permuteddims(a::SymbolicArray, p) = permutedims(a, p) +function Base.permutedims(a::SymbolicArray, p) + @assert ndims(a) == length(p) && isperm(p) + return SymbolicArray(symname(a), ntuple(i -> axes(a)[p[i]], ndims(a))) +end function Base.show(io::IO, mime::MIME"text/plain", a::SymbolicArray) Base.summary(io, a) println(io, ":") diff --git a/src/LazyNamedDimsArrays/symbolicnameddimsarray.jl b/src/LazyNamedDimsArrays/symbolicnameddimsarray.jl index 495869a..a215319 100644 --- a/src/LazyNamedDimsArrays/symbolicnameddimsarray.jl +++ b/src/LazyNamedDimsArrays/symbolicnameddimsarray.jl @@ -1,10 +1,11 @@ -using NamedDimsArrays: NamedDimsArray, dename, inds +using NamedDimsArrays: NamedDimsArray, dename, inds, nameddims const SymbolicNamedDimsArray{T, N, Parent <: SymbolicArray{T, N}, DimNames} = NamedDimsArray{T, N, Parent, DimNames} -function symnameddims(name) - return lazy(NamedDimsArray(SymbolicArray(name), ())) +function symnameddims(name, dims) + return lazy(nameddims(SymbolicArray(name, dename.(dims)), dims)) end +symnameddims(name) = symnameddims(name, ()) using AbstractTrees: AbstractTrees function AbstractTrees.printnode(io::IO, a::SymbolicNamedDimsArray) print(io, symname(dename(a))) diff --git a/src/abstracttensornetwork.jl b/src/abstracttensornetwork.jl index aed941b..e566752 100644 --- a/src/abstracttensornetwork.jl +++ b/src/abstracttensornetwork.jl @@ -40,6 +40,7 @@ Base.copy(tn::AbstractTensorNetwork) = error("Not implemented") # Iteration Base.iterate(tn::AbstractTensorNetwork, args...) = iterate(vertex_data(tn), args...) +Base.keys(tn::AbstractTensorNetwork) = vertices(tn) # TODO: This contrasts with the `DataGraphs.AbstractDataGraph` definition, # where it is defined as the `vertextype`. Does that cause problems or should it be changed? diff --git a/src/contract_network.jl b/src/contract_network.jl index 67d69e0..5651515 100644 --- a/src/contract_network.jl +++ b/src/contract_network.jl @@ -1,47 +1,50 @@ using BackendSelection: @Algorithm_str, Algorithm -using ITensorNetworksNext.LazyNamedDimsArrays: substitute, materialize, lazy, +using Base.Broadcast: materialize +using ITensorNetworksNext.LazyNamedDimsArrays: lazy, optimize_evaluation_order, substitute, symnameddims -#Algorithmic defaults -default_sequence_alg(::Algorithm"exact") = "leftassociative" -default_sequence(::Algorithm"exact") = nothing -function set_default_kwargs(alg::Algorithm"exact") - sequence = get(alg, :sequence, nothing) - sequence_alg = get(alg, :sequence_alg, default_sequence_alg(alg)) - return Algorithm("exact"; sequence, sequence_alg) +# This is related to `MatrixAlgebraKit.select_algorithm`. +# TODO: Define this in BackendSelection.jl. 
+backend_value(::Algorithm{alg}) where {alg} = alg +using BackendSelection: parameters +function merge_parameters(alg::Algorithm; kwargs...) + return Algorithm(backend_value(alg); merge(parameters(alg), kwargs)...) end +to_algorithm(alg::Algorithm; kwargs...) = merge_parameters(alg; kwargs...) +to_algorithm(alg; kwargs...) = Algorithm(alg; kwargs...) -function contraction_sequence_to_expr(seq) - if seq isa AbstractVector - return prod(contraction_sequence_to_expr, seq) - else - return symnameddims(seq) - end -end - -function contraction_sequence(::Algorithm"leftassociative", tn::Vector{<:AbstractArray}) - return prod(symnameddims, 1:length(tn)) +# `contract_network` +contract_network(alg::Algorithm, tn) = error("Not implemented.") +function default_kwargs(::typeof(contract_network), tn) + return (; alg = Algorithm"exact"(; order_alg = Algorithm"eager"())) end - -function contraction_sequence(tn::Vector{<:AbstractArray}; sequence_alg = default_sequence_alg(Algorithm("exact"))) - return contraction_sequence(Algorithm(sequence_alg), tn) +function contract_network(tn; alg = default_kwargs(contract_network, tn).alg, kwargs...) + return contract_network(to_algorithm(alg; kwargs...), tn) end -function contract_network(alg::Algorithm"exact", tn::Vector{<:AbstractArray}) - if !isnothing(alg.sequence) - sequence = alg.sequence - else - sequence = contraction_sequence(tn; sequence_alg = alg.sequence_alg) +# `contract_network(::Algorithm"exact", ...)` +function contract_network(alg::Algorithm"exact", tn) + order = @something begin + get(alg, :order, nothing) + contraction_order( + tn; alg = get(alg, :order_alg, default_kwargs(contraction_order, tn).alg) + ) end - - sequence = substitute(sequence, Dict(symnameddims(i) => lazy(tn[i]) for i in 1:length(tn))) - return materialize(sequence) + syms_to_ts = Dict(symnameddims(i, Tuple(inds(tn[i]))) => lazy(tn[i]) for i in eachindex(tn)) + tn_expression = substitute(order, syms_to_ts) + return materialize(tn_expression) end -function contract_network(alg::Algorithm"exact", tn::AbstractTensorNetwork) - return contract_network(alg, [tn[v] for v in vertices(tn)]) +# `contraction_order` +function contraction_order end +default_kwargs(::typeof(contraction_order), tn) = (; alg = Algorithm"eager"()) +function contraction_order(tn; alg = default_kwargs(contraction_order, tn).alg, kwargs...) + return contraction_order(to_algorithm(alg; kwargs...), tn) end - -function contract_network(tn; alg, kwargs...) 
- return contract_network(set_default_kwargs(Algorithm(alg; kwargs...)), tn) +function contraction_order(alg::Algorithm"left_associative", tn) + return prod(i -> symnameddims(i, Tuple(inds(tn[i]))), keys(tn)) +end +function contraction_order(alg::Algorithm, tn) + s = contraction_order(tn; alg = Algorithm"left_associative"()) + return optimize_evaluation_order(s; alg) end diff --git a/src/tensornetwork.jl b/src/tensornetwork.jl index c7d1479..582eec6 100644 --- a/src/tensornetwork.jl +++ b/src/tensornetwork.jl @@ -1,9 +1,10 @@ +using Combinatorics: combinations using DataGraphs: DataGraphs, AbstractDataGraph, DataGraph using Dictionaries: AbstractDictionary, Indices, dictionary using Graphs: AbstractSimpleGraph using NamedDimsArrays: AbstractNamedDimsArray, dimnames using NamedGraphs: NamedGraphs, NamedEdge, NamedGraph, vertextype -using NamedGraphs.GraphsExtensions: arranged_edges, vertextype +using NamedGraphs.GraphsExtensions: add_edges!, arrange_edge, arranged_edges, vertextype function _TensorNetwork end @@ -18,6 +19,10 @@ struct TensorNetwork{V, VD, UG <: AbstractGraph{V}, Tensors <: AbstractDictionar return new{V, VD, UG, Tensors}(underlying_graph, tensors) end end +# This assumes the tensor connectivity matches the graph structure. +function _TensorNetwork(graph::AbstractGraph, tensors) + return _TensorNetwork(graph, Dictionary(keys(tensors), values(tensors))) +end DataGraphs.underlying_graph(tn::TensorNetwork) = getfield(tn, :underlying_graph) DataGraphs.vertex_data(tn::TensorNetwork) = getfield(tn, :tensors) @@ -25,38 +30,54 @@ function DataGraphs.underlying_graph_type(type::Type{<:TensorNetwork}) return fieldtype(type, :underlying_graph) end -# Determine the graph structure from the tensors. -function TensorNetwork(t::AbstractDictionary) - g = NamedGraph(eachindex(t)) - for v1 in vertices(g) - for v2 in vertices(g) - if v1 ≠ v2 - if !isdisjoint(dimnames(t[v1]), dimnames(t[v2])) - add_edge!(g, v1 => v2) - end +# For a collection of tensors, return the edges implied by shared indices +# as a list of `edgetype` edges of keys/vertices. +function tensornetwork_edges(edgetype::Type, tensors) + # We need to collect the keys since in the case of `tensors::AbstractDictionary`, + # `keys(tensors)::AbstractIndices`, which is indexed by `keys(tensors)` rather + # than `1:length(keys(tensors))`, which is assumed by `combinations`. + verts = collect(keys(tensors)) + return filter( + !isnothing, map(combinations(verts, 2)) do (v1, v2) + if !isdisjoint(inds(tensors[v1]), inds(tensors[v2])) + return arrange_edge(edgetype(v1, v2)) end + return nothing end - end - return _TensorNetwork(g, t) + ) end -function TensorNetwork(tensors::AbstractDict) - return TensorNetwork(Dictionary(tensors)) +tensornetwork_edges(tensors) = tensornetwork_edges(NamedEdge, tensors) + +function TensorNetwork(f::Base.Callable, graph::AbstractGraph) + tensors = Dictionary(vertices(graph), f.(vertices(graph))) + return TensorNetwork(graph, tensors) +end +function TensorNetwork(graph::AbstractGraph, tensors) + tn = _TensorNetwork(graph, tensors) + fix_links!(tn) + return tn end -function TensorNetwork(graph::AbstractGraph, tensors::AbstractDictionary) - tn = TensorNetwork(tensors) - arranged_edges(tn) ⊆ arranged_edges(graph) || +# Insert trivial links for missing edges, and also check +# the vertices and edges are consistent between the graph and tensors. 
+function fix_links!(tn::AbstractTensorNetwork) + graph = underlying_graph(tn) + tensors = vertex_data(tn) + @assert issetequal(vertices(graph), keys(tensors)) "Graph vertices and tensor keys must match." + tn_edges = tensornetwork_edges(edgetype(graph), tensors) + tn_edges ⊆ arranged_edges(graph) || error("The edges in the tensors do not match the graph structure.") for e in setdiff(arranged_edges(graph), tn_edges) insert_trivial_link!(tn, e) end return tn end -function TensorNetwork(graph::AbstractGraph, tensors::AbstractDict) - return TensorNetwork(graph, Dictionary(tensors)) -end -function TensorNetwork(f, graph::AbstractGraph) - return TensorNetwork(graph, Dict(v => f(v) for v in vertices(graph))) + +# Determine the graph structure from the tensors. +function TensorNetwork(tensors) + graph = NamedGraph(keys(tensors)) + add_edges!(graph, tensornetwork_edges(tensors)) + return _TensorNetwork(graph, tensors) end function Base.copy(tn::TensorNetwork) @@ -65,10 +86,9 @@ end TensorNetwork(tn::TensorNetwork) = copy(tn) TensorNetwork{V}(tn::TensorNetwork{V}) where {V} = copy(tn) function TensorNetwork{V}(tn::TensorNetwork) where {V} - g′ = convert_vertextype(V, underlying_graph(tn)) - d = vertex_data(tn) - d′ = dictionary(V(k) => d[k] for k in eachindex(d)) - return TensorNetwork(g′, d′) + g = convert_vertextype(V, underlying_graph(tn)) + d = dictionary(V(k) => tn[k] for k in keys(tn)) + return TensorNetwork(g, d) end NamedGraphs.convert_vertextype(::Type{V}, tn::TensorNetwork{V}) where {V} = tn diff --git a/test/Project.toml b/test/Project.toml index 7a8e233..031574e 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -23,7 +23,7 @@ DiagonalArrays = "0.3.23" Dictionaries = "0.4.5" Graphs = "1.13.1" ITensorBase = "0.3" -ITensorNetworksNext = "0.1.1" +ITensorNetworksNext = "0.2" NamedDimsArrays = "0.8" NamedGraphs = "0.6.8, 0.7" QuadGK = "2.11.2" diff --git a/test/test_contract_network.jl b/test/test_contract_network.jl index 2b7b945..c9abfdd 100644 --- a/test/test_contract_network.jl +++ b/test/test_contract_network.jl @@ -2,8 +2,7 @@ using Graphs: edges using NamedGraphs.GraphsExtensions: arranged_edges, incident_edges using NamedGraphs.NamedGraphGenerators: named_grid using ITensorBase: Index, ITensor -using ITensorNetworksNext: - TensorNetwork, linkinds, siteinds, contract_network +using ITensorNetworksNext: TensorNetwork, linkinds, siteinds, contract_network using TensorOperations: TensorOperations using Test: @test, @testset @@ -15,10 +14,11 @@ using Test: @test, @testset C = ITensor([5.0, 1.0], j) D = ITensor([-2.0, 3.0, 4.0, 5.0, 1.0], k) - ABCD_1 = contract_network([A, B, C, D]; alg = "exact", sequence_alg = "leftassociative") - ABCD_2 = contract_network([A, B, C, D]; alg = "exact", sequence_alg = "optimal") + ABCD_1 = contract_network([A, B, C, D]; order_alg = "left_associative") + ABCD_2 = contract_network([A, B, C, D]; order_alg = "eager") + ABCD_3 = contract_network([A, B, C, D]; order_alg = "optimal") - @test ABCD_1 == ABCD_2 + @test ABCD_1 == ABCD_2 == ABCD_3 end @testset "Contract One Dimensional Network" begin @@ -31,9 +31,11 @@ return randn(Tuple(is)) end - z1 = contract_network(tn; alg = "exact", sequence_alg = "optimal")[] - z2 = contract_network(tn; alg = "exact", sequence_alg = "leftassociative")[] + z1 = contract_network(tn; order_alg = "left_associative")[] + z2 = contract_network(tn; order_alg = "eager")[] + z3 = contract_network(tn; order_alg = "optimal")[] @test abs(z1 -
z2) / abs(z1) <= 1.0e3 * eps(Float64) + @test abs(z1 - z3) / abs(z1) <= 1.0e3 * eps(Float64) end end diff --git a/test/test_lazynameddimsarrays.jl b/test/test_lazynameddimsarrays.jl index cc86fdc..2738d59 100644 --- a/test/test_lazynameddimsarrays.jl +++ b/test/test_lazynameddimsarrays.jl @@ -95,7 +95,7 @@ using WrappedUnions: unwrap @test unwrap(a1) isa NamedDimsArray @test dename(a1) isa SymbolicArray @test dename(unwrap(a1)) isa SymbolicArray - @test dename(unwrap(a1)) == SymbolicArray(:a1) + @test dename(unwrap(a1)) == SymbolicArray(:a1, ()) @test inds(a1) == () @test dimnames(a1) == () diff --git a/test/test_tensornetworkgenerators.jl b/test/test_tensornetworkgenerators.jl index 152e67b..1fa6bb5 100644 --- a/test/test_tensornetworkgenerators.jl +++ b/test/test_tensornetworkgenerators.jl @@ -82,8 +82,7 @@ end @test issetequal(is, inds(tn[v])) @test tn[v] ≠ δ(Tuple(is)) end - # TODO: Use eager contraction sequence finding. - z = contract_network(tn; alg = "exact")[] + z = contract_network(tn)[] f = -log(z) / (β * nv(g)) f_analytic = TestUtils.f_1d_ising(β, 4; periodic) @test f ≈ f_analytic @@ -104,8 +103,7 @@ end @test issetequal(is, inds(tn[v])) @test tn[v] ≠ δ(Tuple(is)) end - # TODO: Use eager contraction sequence finding. - z = contract_network(tn; alg = "exact")[] + z = contract_network(tn)[] f = -log(z) / (β * nv(g)) f_inf = TestUtils.f_2d_ising(β) @test f ≈ f_inf rtol = 1.0e-1
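
A minimal usage sketch of the reworked API above (assuming `tn` is an already-constructed `TensorNetwork`, or a `Vector` of tensors with named indices; all function names are taken from the diff):

using Base.Broadcast: materialize
using NamedDimsArrays: inds
using ITensorNetworksNext: contract_network, contraction_order
using ITensorNetworksNext.LazyNamedDimsArrays: lazy, substitute, symnameddims

# One-shot contraction with the defaults (exact contraction, eager ordering):
x = contract_network(tn)

# Equivalent two-step version, mirroring `contract_network(::Algorithm"exact", tn)`:
# build a symbolic contraction order, then substitute the actual tensors for the
# symbolic placeholders and materialize the resulting lazy expression.
# `alg` can also be "left_associative", or "optimal" when TensorOperations is loaded.
order = contraction_order(tn; alg = "eager")
subs = Dict(symnameddims(v, Tuple(inds(tn[v]))) => lazy(tn[v]) for v in keys(tn))
x = materialize(substitute(order, subs))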