From 638b94bbff1d0241f263a86f5288711ae55a18cd Mon Sep 17 00:00:00 2001 From: mtfishman Date: Tue, 7 Oct 2025 17:32:34 -0400 Subject: [PATCH 1/2] Better printing, equality, symbolic arrays --- Project.toml | 4 +- src/ITensorNetworksNext.jl | 1 + src/lazynameddimsarrays.jl | 122 +++++++++++++++++++++++++++---- src/symbolicarrays.jl | 44 +++++++++++ test/Project.toml | 1 + test/test_lazynameddimsarrays.jl | 44 ++++++++++- 6 files changed, 197 insertions(+), 19 deletions(-) create mode 100644 src/symbolicarrays.jl diff --git a/Project.toml b/Project.toml index b77134e..5cfd81d 100644 --- a/Project.toml +++ b/Project.toml @@ -1,9 +1,10 @@ name = "ITensorNetworksNext" uuid = "302f2e75-49f0-4526-aef7-d8ba550cb06c" authors = ["ITensor developers and contributors"] -version = "0.1.7" +version = "0.1.8" [deps] +AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" BackendSelection = "680c2d7c-f67a-4cc9-ae9c-da132b1447a5" DataGraphs = "b5a273c3-7e6c-41f6-98bd-8d7f1525a36a" @@ -19,6 +20,7 @@ TermInterface = "8ea1fca8-c5ef-4a55-8b96-4e9afe9c9a3c" WrappedUnions = "325db55a-9c6c-5b90-b1a2-ec87e7a38c44" [compat] +AbstractTrees = "0.4.5" Adapt = "4.3" BackendSelection = "0.1.6" DataGraphs = "0.2.7" diff --git a/src/ITensorNetworksNext.jl b/src/ITensorNetworksNext.jl index 35c9e59..fad134d 100644 --- a/src/ITensorNetworksNext.jl +++ b/src/ITensorNetworksNext.jl @@ -1,5 +1,6 @@ module ITensorNetworksNext +include("symbolicarrays.jl") include("lazynameddimsarrays.jl") include("abstracttensornetwork.jl") include("tensornetwork.jl") diff --git a/src/lazynameddimsarrays.jl b/src/lazynameddimsarrays.jl index 3561cb6..36c22e2 100644 --- a/src/lazynameddimsarrays.jl +++ b/src/lazynameddimsarrays.jl @@ -1,14 +1,49 @@ module LazyNamedDimsArrays +using AbstractTrees: AbstractTrees using WrappedUnions: @wrapped, unwrap using NamedDimsArrays: NamedDimsArrays, AbstractNamedDimsArray, AbstractNamedDimsArrayStyle, + NamedDimsArray, dename, + dimnames, inds +using ..SymbolicArrays: SymbolicArrays, SymbolicArray using TermInterface: TermInterface, arguments, iscall, maketerm, operation, sorted_arguments +const SymbolicNamedDimsArray{T, N, Parent <: SymbolicArray{T, N}, DimNames} = + NamedDimsArray{T, N, Parent, DimNames} +function symnameddims(name) + return lazy(NamedDimsArray(SymbolicArray(name), ())) +end +function printnode(io::IO, a::SymbolicNamedDimsArray) + print(io, SymbolicArrays.name(dename(a))) + print(io, "[", join(dimnames(a), ","), "]") + return nothing +end +function Base.:(==)(a::SymbolicNamedDimsArray, b::SymbolicNamedDimsArray) + return issetequal(inds(a), inds(b)) && dename(a) == dename(b) +end +function Base.:*(a::SymbolicNamedDimsArray, b::SymbolicNamedDimsArray) + return lazy(a) * lazy(b) +end +function Base.:*(a::SymbolicNamedDimsArray, b::LazyNamedDimsArray) + return lazy(a) * b +end +function Base.:*(a::LazyNamedDimsArray, b::SymbolicNamedDimsArray) + return a * lazy(b) +end + +# Custom version of `AbstractTrees.printnode` to +# avoid type piracy when overloading on `AbstractNamedDimsArray`. +printnode(io::IO, x) = AbstractTrees.printnode(io, x) +function printnode(io::IO, a::AbstractNamedDimsArray) + show(io, collect(dimnames(a))) + return nothing +end + struct Mul{A} arguments::Vector{A} end @@ -21,6 +56,13 @@ TermInterface.maketerm(::Type{Mul}, head::typeof(*), args, metadata) = Mul(args) TermInterface.operation(m::Mul) = * TermInterface.sorted_arguments(m::Mul) = arguments(m) TermInterface.sorted_children(m::Mul) = sorted_arguments(a) +ismul(x) = false +ismul(m::Mul) = true +function Base.show(io::IO, m::Mul) + args = map(arg -> sprint(printnode, arg), arguments(m)) + print(io, "(", join(args, " $(operation(m)) "), ")") + return nothing +end @wrapped struct LazyNamedDimsArray{ T, A <: AbstractNamedDimsArray{T}, @@ -30,9 +72,9 @@ end function NamedDimsArrays.inds(a::LazyNamedDimsArray) u = unwrap(a) - if u isa AbstractNamedDimsArray + if !iscall(u) return inds(u) - elseif u isa Mul + elseif ismul(u) return mapreduce(inds, symdiff, arguments(u)) else return error("Variant not supported.") @@ -40,10 +82,8 @@ function NamedDimsArrays.inds(a::LazyNamedDimsArray) end function NamedDimsArrays.dename(a::LazyNamedDimsArray) u = unwrap(a) - if u isa AbstractNamedDimsArray + if !iscall(u) return dename(u) - elseif u isa Mul - return dename(materialize(a), inds(a)) else return error("Variant not supported.") end @@ -51,9 +91,9 @@ end function TermInterface.arguments(a::LazyNamedDimsArray) u = unwrap(a) - if u isa AbstractNamedDimsArray + if !iscall(u) return error("No arguments.") - elseif u isa Mul + elseif ismul(u) return arguments(u) else return error("Variant not supported.") @@ -75,14 +115,14 @@ function TermInterface.maketerm(::Type{LazyNamedDimsArray}, head, args, metadata if head ≡ * return LazyNamedDimsArray(maketerm(Mul, head, args, metadata)) else - return error("Only product terms supported right now.") + return error("Only mul supported right now.") end end function TermInterface.operation(a::LazyNamedDimsArray) u = unwrap(a) - if u isa AbstractNamedDimsArray + if !iscall(u) return error("No operation.") - elseif u isa Mul + elseif ismul(u) return operation(u) else return error("Variant not supported.") @@ -90,9 +130,9 @@ function TermInterface.operation(a::LazyNamedDimsArray) end function TermInterface.sorted_arguments(a::LazyNamedDimsArray) u = unwrap(a) - if u isa AbstractNamedDimsArray + if !iscall(u) return error("No arguments.") - elseif u isa Mul + elseif ismul(u) return sorted_arguments(u) else return error("Variant not supported.") @@ -101,13 +141,29 @@ end function TermInterface.sorted_children(a::LazyNamedDimsArray) return sorted_arguments(a) end +ismul(a::LazyNamedDimsArray) = ismul(unwrap(a)) + +function AbstractTrees.children(a::LazyNamedDimsArray) + if !iscall(a) + return () + else + return arguments(a) + end +end +function AbstractTrees.nodevalue(a::LazyNamedDimsArray) + if !iscall(a) + return unwrap(a) + else + return operation(a) + end +end using Base.Broadcast: materialize function Base.Broadcast.materialize(a::LazyNamedDimsArray) u = unwrap(a) - if u isa AbstractNamedDimsArray + if !iscall(u) return u - elseif u isa Mul + elseif ismul(u) return mapfoldl(materialize, operation(u), arguments(u)) else return error("Variant not supported.") @@ -115,11 +171,45 @@ function Base.Broadcast.materialize(a::LazyNamedDimsArray) end Base.copy(a::LazyNamedDimsArray) = materialize(a) +function Base.:(==)(a1::LazyNamedDimsArray, a2::LazyNamedDimsArray) + u1, u2 = unwrap.((a1, a2)) + if !iscall(u1) && !iscall(u2) + return u1 == u2 + elseif ismul(u1) && ismul(u2) + return arguments(u1) == arguments(u2) + else + return false + end +end + +function printnode(io::IO, a::LazyNamedDimsArray) + return printnode(io, unwrap(a)) +end +function AbstractTrees.printnode(io::IO, a::LazyNamedDimsArray) + return printnode(io, a) +end +function Base.show(io::IO, a::LazyNamedDimsArray) + if !iscall(a) + return show(io, unwrap(a)) + else + return printnode(io, a) + end +end +function Base.show(io::IO, mime::MIME"text/plain", a::LazyNamedDimsArray) + if !iscall(a) + @invoke show(io, mime, a::AbstractNamedDimsArray) + return nothing + else + show(io, a) + return nothing + end +end + function Base.:*(a::LazyNamedDimsArray) u = unwrap(a) - if u isa AbstractNamedDimsArray + if !iscall(u) return LazyNamedDimsArray(Mul([lazy(u)])) - elseif u isa Mul + elseif ismul(u) return a else return error("Variant not supported.") diff --git a/src/symbolicarrays.jl b/src/symbolicarrays.jl new file mode 100644 index 0000000..ecfd4b3 --- /dev/null +++ b/src/symbolicarrays.jl @@ -0,0 +1,44 @@ +module SymbolicArrays + +using AbstractTrees: AbstractTrees + +struct SymbolicArray{T, N, Name, Axes <: NTuple{N, AbstractUnitRange{<:Integer}}} <: AbstractArray{T, N} + name::Name + axes::Axes + function SymbolicArray{T}(name, ax::Tuple{Vararg{AbstractUnitRange{<:Integer}}}) where {T} + N = length(ax) + return new{T, N, typeof(name), typeof(ax)}(name, ax) + end +end +function SymbolicArray(name, ax::Tuple{Vararg{AbstractUnitRange{<:Integer}}}) + return SymbolicArray{Any}(name, ax) +end +function SymbolicArray{T}(name, ax::AbstractUnitRange...) where {T} + return SymbolicArray{T}(name, ax) +end +function SymbolicArray(name, ax::AbstractUnitRange...) + return SymbolicArray{Any}(name, ax) +end +name(a::SymbolicArray) = getfield(a, :name) +Base.axes(a::SymbolicArray) = getfield(a, :axes) +Base.size(a::SymbolicArray) = length.(axes(a)) +function Base.:(==)(a::SymbolicArray, b::SymbolicArray) + return name(a) == name(b) && axes(a) == axes(b) +end +function Base.show(io::IO, mime::MIME"text/plain", a::SymbolicArray) + Base.summary(io, a) + println(io, ":") + print(io, repr(name(a))) + return nothing +end +function Base.show(io::IO, a::SymbolicArray) + print(io, "SymbolicArray(", name(a), ", ", size(a), ")") + return nothing +end + +function AbstractTrees.printnode(io::IO, a::SymbolicArray) + print(io, repr(name(a))) + return nothing +end + +end diff --git a/test/Project.toml b/test/Project.toml index 9646508..cef6740 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,4 +1,5 @@ [deps] +AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" Dictionaries = "85a47980-9c8c-11e8-2b9f-f7ca1fa99fb4" Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6" diff --git a/test/test_lazynameddimsarrays.jl b/test/test_lazynameddimsarrays.jl index 4c38c5e..67aae9d 100644 --- a/test/test_lazynameddimsarrays.jl +++ b/test/test_lazynameddimsarrays.jl @@ -1,6 +1,9 @@ +using AbstractTrees: AbstractTrees, print_tree, printnode using Base.Broadcast: materialize -using ITensorNetworksNext.LazyNamedDimsArrays: LazyNamedDimsArray, Mul, lazy -using NamedDimsArrays: NamedDimsArray, inds, nameddims +using ITensorNetworksNext.LazyNamedDimsArrays: + LazyNamedDimsArray, Mul, ismul, lazy, symnameddims +using ITensorNetworksNext.SymbolicArrays: SymbolicArray +using NamedDimsArrays: NamedDimsArray, dename, dimnames, inds, nameddims using TermInterface: arguments, arity, @@ -33,6 +36,7 @@ using WrappedUnions: unwrap @test materialize(l) ≈ a1 * a2 * a3 @test issetequal(inds(l), symdiff(inds.((a1, a2, a3))...)) @test unwrap(l) isa Mul + @test ismul(unwrap(l)) @test unwrap(l).arguments == [l1 * l2, l3] # TermInterface.jl @test operation(unwrap(l)) ≡ * @@ -54,6 +58,15 @@ using WrappedUnions: unwrap @test_throws ErrorException operation(l1) @test_throws ErrorException sorted_arguments(l1) @test_throws ErrorException sorted_children(l1) + @test AbstractTrees.children(l1) ≡ () + @test AbstractTrees.nodevalue(l1) ≡ a1 + @test sprint(show, l1) == sprint(show, a1) + # TODO: Fix this test, it is basically correct but the type parameters + # print in a different way. + # @test sprint(show, MIME"text/plain"(), l1) == + # replace(sprint(show, MIME"text/plain"(), a1), "NamedDimsArray" => "LazyNamedDimsArray") + @test sprint(printnode, l1) == "[:i, :j]" + @test sprint(print_tree, l1) == "[:i, :j]\n" l = l1 * l2 * l3 @test arguments(l) == [l1 * l2, l3] @@ -66,5 +79,32 @@ using WrappedUnions: unwrap @test operation(l) ≡ * @test sorted_arguments(l) == [l1 * l2, l3] @test sorted_children(l) == [l1 * l2, l3] + @test AbstractTrees.children(l) == [l1 * l2, l3] + @test AbstractTrees.nodevalue(l) ≡ * + @test sprint(show, l) == "(([:i, :j] * [:j, :k]) * [:k, :l])" + @test sprint(show, MIME"text/plain"(), l) == "(([:i, :j] * [:j, :k]) * [:k, :l])" + @test sprint(printnode, l) == "(([:i, :j] * [:j, :k]) * [:k, :l])" + @test sprint(print_tree, l) == + "(([:i, :j] * [:j, :k]) * [:k, :l])\n├─ ([:i, :j] * [:j, :k])\n│ ├─ [:i, :j]\n│ └─ [:j, :k]\n└─ [:k, :l]\n" + end + + @testset "symnameddims" begin + a = symnameddims(:a) + b = symnameddims(:b) + c = symnameddims(:c) + @test a isa LazyNamedDimsArray + @test unwrap(a) isa NamedDimsArray + @test dename(a) isa SymbolicArray + @test dename(unwrap(a)) isa SymbolicArray + @test dename(unwrap(a)) == SymbolicArray(:a) + @test inds(a) == () + @test dimnames(a) == () + + ex = a * b * c + @test copy(ex) == ex + @test arguments(ex) == [a * b, c] + @test operation(ex) ≡ * + @test sprint(show, ex) == "((a[] * b[]) * c[])" + @test sprint(show, MIME"text/plain"(), ex) == "((a[] * b[]) * c[])" end end From 0fbf838077d5ac8f4e54be459c57bdf963db3b05 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Tue, 7 Oct 2025 17:50:51 -0400 Subject: [PATCH 2/2] Reorganize code --- src/ITensorNetworksNext.jl | 1 - src/lazynameddimsarrays.jl | 87 +++++++++++++++++++++++--------- src/symbolicarrays.jl | 44 ---------------- test/Project.toml | 1 + test/test_lazynameddimsarrays.jl | 7 ++- 5 files changed, 67 insertions(+), 73 deletions(-) delete mode 100644 src/symbolicarrays.jl diff --git a/src/ITensorNetworksNext.jl b/src/ITensorNetworksNext.jl index fad134d..35c9e59 100644 --- a/src/ITensorNetworksNext.jl +++ b/src/ITensorNetworksNext.jl @@ -1,6 +1,5 @@ module ITensorNetworksNext -include("symbolicarrays.jl") include("lazynameddimsarrays.jl") include("abstracttensornetwork.jl") include("tensornetwork.jl") diff --git a/src/lazynameddimsarrays.jl b/src/lazynameddimsarrays.jl index 36c22e2..4a038e8 100644 --- a/src/lazynameddimsarrays.jl +++ b/src/lazynameddimsarrays.jl @@ -10,32 +10,8 @@ using NamedDimsArrays: dename, dimnames, inds -using ..SymbolicArrays: SymbolicArrays, SymbolicArray using TermInterface: TermInterface, arguments, iscall, maketerm, operation, sorted_arguments -const SymbolicNamedDimsArray{T, N, Parent <: SymbolicArray{T, N}, DimNames} = - NamedDimsArray{T, N, Parent, DimNames} -function symnameddims(name) - return lazy(NamedDimsArray(SymbolicArray(name), ())) -end -function printnode(io::IO, a::SymbolicNamedDimsArray) - print(io, SymbolicArrays.name(dename(a))) - print(io, "[", join(dimnames(a), ","), "]") - return nothing -end -function Base.:(==)(a::SymbolicNamedDimsArray, b::SymbolicNamedDimsArray) - return issetequal(inds(a), inds(b)) && dename(a) == dename(b) -end -function Base.:*(a::SymbolicNamedDimsArray, b::SymbolicNamedDimsArray) - return lazy(a) * lazy(b) -end -function Base.:*(a::SymbolicNamedDimsArray, b::LazyNamedDimsArray) - return lazy(a) * b -end -function Base.:*(a::LazyNamedDimsArray, b::SymbolicNamedDimsArray) - return a * lazy(b) -end - # Custom version of `AbstractTrees.printnode` to # avoid type piracy when overloading on `AbstractNamedDimsArray`. printnode(io::IO, x) = AbstractTrees.printnode(io, x) @@ -281,4 +257,67 @@ function Broadcast.broadcasted(::LazyNamedDimsArrayStyle, ::typeof(-), a) return -a end +struct SymbolicArray{T, N, Name, Axes <: NTuple{N, AbstractUnitRange{<:Integer}}} <: AbstractArray{T, N} + name::Name + axes::Axes + function SymbolicArray{T}(name, ax::Tuple{Vararg{AbstractUnitRange{<:Integer}}}) where {T} + N = length(ax) + return new{T, N, typeof(name), typeof(ax)}(name, ax) + end +end +function SymbolicArray(name, ax::Tuple{Vararg{AbstractUnitRange{<:Integer}}}) + return SymbolicArray{Any}(name, ax) +end +function SymbolicArray{T}(name, ax::AbstractUnitRange...) where {T} + return SymbolicArray{T}(name, ax) +end +function SymbolicArray(name, ax::AbstractUnitRange...) + return SymbolicArray{Any}(name, ax) +end +symname(a::SymbolicArray) = getfield(a, :name) +Base.axes(a::SymbolicArray) = getfield(a, :axes) +Base.size(a::SymbolicArray) = length.(axes(a)) +function Base.:(==)(a::SymbolicArray, b::SymbolicArray) + return symname(a) == symname(b) && axes(a) == axes(b) +end +function Base.show(io::IO, mime::MIME"text/plain", a::SymbolicArray) + Base.summary(io, a) + println(io, ":") + print(io, repr(symname(a))) + return nothing +end +function Base.show(io::IO, a::SymbolicArray) + print(io, "SymbolicArray(", symname(a), ", ", size(a), ")") + return nothing +end +using AbstractTrees: AbstractTrees +function AbstractTrees.printnode(io::IO, a::SymbolicArray) + print(io, repr(symname(a))) + return nothing +end +const SymbolicNamedDimsArray{T, N, Parent <: SymbolicArray{T, N}, DimNames} = + NamedDimsArray{T, N, Parent, DimNames} +function symnameddims(name) + return lazy(NamedDimsArray(SymbolicArray(name), ())) +end +function printnode(io::IO, a::SymbolicNamedDimsArray) + print(io, symname(dename(a))) + if ndims(a) > 0 + print(io, "[", join(dimnames(a), ","), "]") + end + return nothing +end +function Base.:(==)(a::SymbolicNamedDimsArray, b::SymbolicNamedDimsArray) + return issetequal(inds(a), inds(b)) && dename(a) == dename(b) +end +function Base.:*(a::SymbolicNamedDimsArray, b::SymbolicNamedDimsArray) + return lazy(a) * lazy(b) +end +function Base.:*(a::SymbolicNamedDimsArray, b::LazyNamedDimsArray) + return lazy(a) * b +end +function Base.:*(a::LazyNamedDimsArray, b::SymbolicNamedDimsArray) + return a * lazy(b) +end + end diff --git a/src/symbolicarrays.jl b/src/symbolicarrays.jl deleted file mode 100644 index ecfd4b3..0000000 --- a/src/symbolicarrays.jl +++ /dev/null @@ -1,44 +0,0 @@ -module SymbolicArrays - -using AbstractTrees: AbstractTrees - -struct SymbolicArray{T, N, Name, Axes <: NTuple{N, AbstractUnitRange{<:Integer}}} <: AbstractArray{T, N} - name::Name - axes::Axes - function SymbolicArray{T}(name, ax::Tuple{Vararg{AbstractUnitRange{<:Integer}}}) where {T} - N = length(ax) - return new{T, N, typeof(name), typeof(ax)}(name, ax) - end -end -function SymbolicArray(name, ax::Tuple{Vararg{AbstractUnitRange{<:Integer}}}) - return SymbolicArray{Any}(name, ax) -end -function SymbolicArray{T}(name, ax::AbstractUnitRange...) where {T} - return SymbolicArray{T}(name, ax) -end -function SymbolicArray(name, ax::AbstractUnitRange...) - return SymbolicArray{Any}(name, ax) -end -name(a::SymbolicArray) = getfield(a, :name) -Base.axes(a::SymbolicArray) = getfield(a, :axes) -Base.size(a::SymbolicArray) = length.(axes(a)) -function Base.:(==)(a::SymbolicArray, b::SymbolicArray) - return name(a) == name(b) && axes(a) == axes(b) -end -function Base.show(io::IO, mime::MIME"text/plain", a::SymbolicArray) - Base.summary(io, a) - println(io, ":") - print(io, repr(name(a))) - return nothing -end -function Base.show(io::IO, a::SymbolicArray) - print(io, "SymbolicArray(", name(a), ", ", size(a), ")") - return nothing -end - -function AbstractTrees.printnode(io::IO, a::SymbolicArray) - print(io, repr(name(a))) - return nothing -end - -end diff --git a/test/Project.toml b/test/Project.toml index cef6740..5a5ce6a 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -14,6 +14,7 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" WrappedUnions = "325db55a-9c6c-5b90-b1a2-ec87e7a38c44" [compat] +AbstractTrees = "0.4.5" Aqua = "0.8.14" Dictionaries = "0.4.5" Graphs = "1.13.1" diff --git a/test/test_lazynameddimsarrays.jl b/test/test_lazynameddimsarrays.jl index 67aae9d..40735ec 100644 --- a/test/test_lazynameddimsarrays.jl +++ b/test/test_lazynameddimsarrays.jl @@ -1,8 +1,7 @@ using AbstractTrees: AbstractTrees, print_tree, printnode using Base.Broadcast: materialize using ITensorNetworksNext.LazyNamedDimsArrays: - LazyNamedDimsArray, Mul, ismul, lazy, symnameddims -using ITensorNetworksNext.SymbolicArrays: SymbolicArray + LazyNamedDimsArray, Mul, SymbolicArray, ismul, lazy, symnameddims using NamedDimsArrays: NamedDimsArray, dename, dimnames, inds, nameddims using TermInterface: arguments, @@ -104,7 +103,7 @@ using WrappedUnions: unwrap @test copy(ex) == ex @test arguments(ex) == [a * b, c] @test operation(ex) ≡ * - @test sprint(show, ex) == "((a[] * b[]) * c[])" - @test sprint(show, MIME"text/plain"(), ex) == "((a[] * b[]) * c[])" + @test sprint(show, ex) == "((a * b) * c)" + @test sprint(show, MIME"text/plain"(), ex) == "((a * b) * c)" end end