diff --git a/Project.toml b/Project.toml index b77134e..5cfd81d 100644 --- a/Project.toml +++ b/Project.toml @@ -1,9 +1,10 @@ name = "ITensorNetworksNext" uuid = "302f2e75-49f0-4526-aef7-d8ba550cb06c" authors = ["ITensor developers and contributors"] -version = "0.1.7" +version = "0.1.8" [deps] +AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" BackendSelection = "680c2d7c-f67a-4cc9-ae9c-da132b1447a5" DataGraphs = "b5a273c3-7e6c-41f6-98bd-8d7f1525a36a" @@ -19,6 +20,7 @@ TermInterface = "8ea1fca8-c5ef-4a55-8b96-4e9afe9c9a3c" WrappedUnions = "325db55a-9c6c-5b90-b1a2-ec87e7a38c44" [compat] +AbstractTrees = "0.4.5" Adapt = "4.3" BackendSelection = "0.1.6" DataGraphs = "0.2.7" diff --git a/src/lazynameddimsarrays.jl b/src/lazynameddimsarrays.jl index 3561cb6..4a038e8 100644 --- a/src/lazynameddimsarrays.jl +++ b/src/lazynameddimsarrays.jl @@ -1,14 +1,25 @@ module LazyNamedDimsArrays +using AbstractTrees: AbstractTrees using WrappedUnions: @wrapped, unwrap using NamedDimsArrays: NamedDimsArrays, AbstractNamedDimsArray, AbstractNamedDimsArrayStyle, + NamedDimsArray, dename, + dimnames, inds using TermInterface: TermInterface, arguments, iscall, maketerm, operation, sorted_arguments +# Custom version of `AbstractTrees.printnode` to +# avoid type piracy when overloading on `AbstractNamedDimsArray`. +printnode(io::IO, x) = AbstractTrees.printnode(io, x) +function printnode(io::IO, a::AbstractNamedDimsArray) + show(io, collect(dimnames(a))) + return nothing +end + struct Mul{A} arguments::Vector{A} end @@ -21,6 +32,13 @@ TermInterface.maketerm(::Type{Mul}, head::typeof(*), args, metadata) = Mul(args) TermInterface.operation(m::Mul) = * TermInterface.sorted_arguments(m::Mul) = arguments(m) TermInterface.sorted_children(m::Mul) = sorted_arguments(a) +ismul(x) = false +ismul(m::Mul) = true +function Base.show(io::IO, m::Mul) + args = map(arg -> sprint(printnode, arg), arguments(m)) + print(io, "(", join(args, " $(operation(m)) "), ")") + return nothing +end @wrapped struct LazyNamedDimsArray{ T, A <: AbstractNamedDimsArray{T}, @@ -30,9 +48,9 @@ end function NamedDimsArrays.inds(a::LazyNamedDimsArray) u = unwrap(a) - if u isa AbstractNamedDimsArray + if !iscall(u) return inds(u) - elseif u isa Mul + elseif ismul(u) return mapreduce(inds, symdiff, arguments(u)) else return error("Variant not supported.") @@ -40,10 +58,8 @@ function NamedDimsArrays.inds(a::LazyNamedDimsArray) end function NamedDimsArrays.dename(a::LazyNamedDimsArray) u = unwrap(a) - if u isa AbstractNamedDimsArray + if !iscall(u) return dename(u) - elseif u isa Mul - return dename(materialize(a), inds(a)) else return error("Variant not supported.") end @@ -51,9 +67,9 @@ end function TermInterface.arguments(a::LazyNamedDimsArray) u = unwrap(a) - if u isa AbstractNamedDimsArray + if !iscall(u) return error("No arguments.") - elseif u isa Mul + elseif ismul(u) return arguments(u) else return error("Variant not supported.") @@ -75,14 +91,14 @@ function TermInterface.maketerm(::Type{LazyNamedDimsArray}, head, args, metadata if head ≡ * return LazyNamedDimsArray(maketerm(Mul, head, args, metadata)) else - return error("Only product terms supported right now.") + return error("Only mul supported right now.") end end function TermInterface.operation(a::LazyNamedDimsArray) u = unwrap(a) - if u isa AbstractNamedDimsArray + if !iscall(u) return error("No operation.") - elseif u isa Mul + elseif ismul(u) return operation(u) else return error("Variant not supported.") @@ -90,9 +106,9 @@ function TermInterface.operation(a::LazyNamedDimsArray) end function TermInterface.sorted_arguments(a::LazyNamedDimsArray) u = unwrap(a) - if u isa AbstractNamedDimsArray + if !iscall(u) return error("No arguments.") - elseif u isa Mul + elseif ismul(u) return sorted_arguments(u) else return error("Variant not supported.") @@ -101,13 +117,29 @@ end function TermInterface.sorted_children(a::LazyNamedDimsArray) return sorted_arguments(a) end +ismul(a::LazyNamedDimsArray) = ismul(unwrap(a)) + +function AbstractTrees.children(a::LazyNamedDimsArray) + if !iscall(a) + return () + else + return arguments(a) + end +end +function AbstractTrees.nodevalue(a::LazyNamedDimsArray) + if !iscall(a) + return unwrap(a) + else + return operation(a) + end +end using Base.Broadcast: materialize function Base.Broadcast.materialize(a::LazyNamedDimsArray) u = unwrap(a) - if u isa AbstractNamedDimsArray + if !iscall(u) return u - elseif u isa Mul + elseif ismul(u) return mapfoldl(materialize, operation(u), arguments(u)) else return error("Variant not supported.") @@ -115,11 +147,45 @@ function Base.Broadcast.materialize(a::LazyNamedDimsArray) end Base.copy(a::LazyNamedDimsArray) = materialize(a) +function Base.:(==)(a1::LazyNamedDimsArray, a2::LazyNamedDimsArray) + u1, u2 = unwrap.((a1, a2)) + if !iscall(u1) && !iscall(u2) + return u1 == u2 + elseif ismul(u1) && ismul(u2) + return arguments(u1) == arguments(u2) + else + return false + end +end + +function printnode(io::IO, a::LazyNamedDimsArray) + return printnode(io, unwrap(a)) +end +function AbstractTrees.printnode(io::IO, a::LazyNamedDimsArray) + return printnode(io, a) +end +function Base.show(io::IO, a::LazyNamedDimsArray) + if !iscall(a) + return show(io, unwrap(a)) + else + return printnode(io, a) + end +end +function Base.show(io::IO, mime::MIME"text/plain", a::LazyNamedDimsArray) + if !iscall(a) + @invoke show(io, mime, a::AbstractNamedDimsArray) + return nothing + else + show(io, a) + return nothing + end +end + function Base.:*(a::LazyNamedDimsArray) u = unwrap(a) - if u isa AbstractNamedDimsArray + if !iscall(u) return LazyNamedDimsArray(Mul([lazy(u)])) - elseif u isa Mul + elseif ismul(u) return a else return error("Variant not supported.") @@ -191,4 +257,67 @@ function Broadcast.broadcasted(::LazyNamedDimsArrayStyle, ::typeof(-), a) return -a end +struct SymbolicArray{T, N, Name, Axes <: NTuple{N, AbstractUnitRange{<:Integer}}} <: AbstractArray{T, N} + name::Name + axes::Axes + function SymbolicArray{T}(name, ax::Tuple{Vararg{AbstractUnitRange{<:Integer}}}) where {T} + N = length(ax) + return new{T, N, typeof(name), typeof(ax)}(name, ax) + end +end +function SymbolicArray(name, ax::Tuple{Vararg{AbstractUnitRange{<:Integer}}}) + return SymbolicArray{Any}(name, ax) +end +function SymbolicArray{T}(name, ax::AbstractUnitRange...) where {T} + return SymbolicArray{T}(name, ax) +end +function SymbolicArray(name, ax::AbstractUnitRange...) + return SymbolicArray{Any}(name, ax) +end +symname(a::SymbolicArray) = getfield(a, :name) +Base.axes(a::SymbolicArray) = getfield(a, :axes) +Base.size(a::SymbolicArray) = length.(axes(a)) +function Base.:(==)(a::SymbolicArray, b::SymbolicArray) + return symname(a) == symname(b) && axes(a) == axes(b) +end +function Base.show(io::IO, mime::MIME"text/plain", a::SymbolicArray) + Base.summary(io, a) + println(io, ":") + print(io, repr(symname(a))) + return nothing +end +function Base.show(io::IO, a::SymbolicArray) + print(io, "SymbolicArray(", symname(a), ", ", size(a), ")") + return nothing +end +using AbstractTrees: AbstractTrees +function AbstractTrees.printnode(io::IO, a::SymbolicArray) + print(io, repr(symname(a))) + return nothing +end +const SymbolicNamedDimsArray{T, N, Parent <: SymbolicArray{T, N}, DimNames} = + NamedDimsArray{T, N, Parent, DimNames} +function symnameddims(name) + return lazy(NamedDimsArray(SymbolicArray(name), ())) +end +function printnode(io::IO, a::SymbolicNamedDimsArray) + print(io, symname(dename(a))) + if ndims(a) > 0 + print(io, "[", join(dimnames(a), ","), "]") + end + return nothing +end +function Base.:(==)(a::SymbolicNamedDimsArray, b::SymbolicNamedDimsArray) + return issetequal(inds(a), inds(b)) && dename(a) == dename(b) +end +function Base.:*(a::SymbolicNamedDimsArray, b::SymbolicNamedDimsArray) + return lazy(a) * lazy(b) +end +function Base.:*(a::SymbolicNamedDimsArray, b::LazyNamedDimsArray) + return lazy(a) * b +end +function Base.:*(a::LazyNamedDimsArray, b::SymbolicNamedDimsArray) + return a * lazy(b) +end + end diff --git a/test/Project.toml b/test/Project.toml index 9646508..5a5ce6a 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,4 +1,5 @@ [deps] +AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" Dictionaries = "85a47980-9c8c-11e8-2b9f-f7ca1fa99fb4" Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6" @@ -13,6 +14,7 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" WrappedUnions = "325db55a-9c6c-5b90-b1a2-ec87e7a38c44" [compat] +AbstractTrees = "0.4.5" Aqua = "0.8.14" Dictionaries = "0.4.5" Graphs = "1.13.1" diff --git a/test/test_lazynameddimsarrays.jl b/test/test_lazynameddimsarrays.jl index 4c38c5e..40735ec 100644 --- a/test/test_lazynameddimsarrays.jl +++ b/test/test_lazynameddimsarrays.jl @@ -1,6 +1,8 @@ +using AbstractTrees: AbstractTrees, print_tree, printnode using Base.Broadcast: materialize -using ITensorNetworksNext.LazyNamedDimsArrays: LazyNamedDimsArray, Mul, lazy -using NamedDimsArrays: NamedDimsArray, inds, nameddims +using ITensorNetworksNext.LazyNamedDimsArrays: + LazyNamedDimsArray, Mul, SymbolicArray, ismul, lazy, symnameddims +using NamedDimsArrays: NamedDimsArray, dename, dimnames, inds, nameddims using TermInterface: arguments, arity, @@ -33,6 +35,7 @@ using WrappedUnions: unwrap @test materialize(l) ≈ a1 * a2 * a3 @test issetequal(inds(l), symdiff(inds.((a1, a2, a3))...)) @test unwrap(l) isa Mul + @test ismul(unwrap(l)) @test unwrap(l).arguments == [l1 * l2, l3] # TermInterface.jl @test operation(unwrap(l)) ≡ * @@ -54,6 +57,15 @@ using WrappedUnions: unwrap @test_throws ErrorException operation(l1) @test_throws ErrorException sorted_arguments(l1) @test_throws ErrorException sorted_children(l1) + @test AbstractTrees.children(l1) ≡ () + @test AbstractTrees.nodevalue(l1) ≡ a1 + @test sprint(show, l1) == sprint(show, a1) + # TODO: Fix this test, it is basically correct but the type parameters + # print in a different way. + # @test sprint(show, MIME"text/plain"(), l1) == + # replace(sprint(show, MIME"text/plain"(), a1), "NamedDimsArray" => "LazyNamedDimsArray") + @test sprint(printnode, l1) == "[:i, :j]" + @test sprint(print_tree, l1) == "[:i, :j]\n" l = l1 * l2 * l3 @test arguments(l) == [l1 * l2, l3] @@ -66,5 +78,32 @@ using WrappedUnions: unwrap @test operation(l) ≡ * @test sorted_arguments(l) == [l1 * l2, l3] @test sorted_children(l) == [l1 * l2, l3] + @test AbstractTrees.children(l) == [l1 * l2, l3] + @test AbstractTrees.nodevalue(l) ≡ * + @test sprint(show, l) == "(([:i, :j] * [:j, :k]) * [:k, :l])" + @test sprint(show, MIME"text/plain"(), l) == "(([:i, :j] * [:j, :k]) * [:k, :l])" + @test sprint(printnode, l) == "(([:i, :j] * [:j, :k]) * [:k, :l])" + @test sprint(print_tree, l) == + "(([:i, :j] * [:j, :k]) * [:k, :l])\n├─ ([:i, :j] * [:j, :k])\n│ ├─ [:i, :j]\n│ └─ [:j, :k]\n└─ [:k, :l]\n" + end + + @testset "symnameddims" begin + a = symnameddims(:a) + b = symnameddims(:b) + c = symnameddims(:c) + @test a isa LazyNamedDimsArray + @test unwrap(a) isa NamedDimsArray + @test dename(a) isa SymbolicArray + @test dename(unwrap(a)) isa SymbolicArray + @test dename(unwrap(a)) == SymbolicArray(:a) + @test inds(a) == () + @test dimnames(a) == () + + ex = a * b * c + @test copy(ex) == ex + @test arguments(ex) == [a * b, c] + @test operation(ex) ≡ * + @test sprint(show, ex) == "((a * b) * c)" + @test sprint(show, MIME"text/plain"(), ex) == "((a * b) * c)" end end