diff --git a/Project.toml b/Project.toml index 5cfd81d..b944769 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "ITensorNetworksNext" uuid = "302f2e75-49f0-4526-aef7-d8ba550cb06c" authors = ["ITensor developers and contributors"] -version = "0.1.8" +version = "0.1.9" [deps] AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" @@ -17,6 +17,7 @@ NamedGraphs = "678767b0-92e7-4007-89e4-4527a8725b19" SimpleTraits = "699a6c99-e7fa-54fc-8d76-47d257e15c1d" SplitApplyCombine = "03a91e81-4c3e-53e1-a0a4-9c0c8f19dd66" TermInterface = "8ea1fca8-c5ef-4a55-8b96-4e9afe9c9a3c" +TypeParameterAccessors = "7e5a90cf-f82e-492e-a09b-e3e26432c138" WrappedUnions = "325db55a-9c6c-5b90-b1a2-ec87e7a38c44" [compat] @@ -33,5 +34,6 @@ NamedGraphs = "0.6.9, 0.7" SimpleTraits = "0.9.5" SplitApplyCombine = "1.2.3" TermInterface = "2" +TypeParameterAccessors = "0.4.4" WrappedUnions = "0.3" julia = "1.10" diff --git a/src/lazynameddimsarrays.jl b/src/lazynameddimsarrays.jl index 4a038e8..e1b4b27 100644 --- a/src/lazynameddimsarrays.jl +++ b/src/lazynameddimsarrays.jl @@ -11,61 +11,55 @@ using NamedDimsArrays: dimnames, inds using TermInterface: TermInterface, arguments, iscall, maketerm, operation, sorted_arguments +using TypeParameterAccessors: unspecify_type_parameters -# Custom version of `AbstractTrees.printnode` to -# avoid type piracy when overloading on `AbstractNamedDimsArray`. -printnode(io::IO, x) = AbstractTrees.printnode(io, x) -function printnode(io::IO, a::AbstractNamedDimsArray) - show(io, collect(dimnames(a))) - return nothing -end +lazy(x) = error("Not defined.") -struct Mul{A} - arguments::Vector{A} +generic_map(f, v) = map(f, v) +generic_map(f, v::AbstractDict) = Dict(eachindex(v) .=> map(f, values(v))) +generic_map(f, v::AbstractSet) = Set([f(x) for x in v]) + +# Defined to avoid type piracy. +# TODO: Define a proper hash function +# in NamedDimsArrays.jl, maybe one that is +# independent of the order of dimensions. +function _hash(a::NamedDimsArray, h::UInt64) + h = hash(:NamedDimsArray, h) + h = hash(dename(a), h) + for i in inds(a) + h = hash(i, h) + end + return h end -TermInterface.arguments(m::Mul) = getfield(m, :arguments) -TermInterface.children(m::Mul) = arguments(m) -TermInterface.head(m::Mul) = operation(m) -TermInterface.iscall(m::Mul) = true -TermInterface.isexpr(m::Mul) = iscall(m) -TermInterface.maketerm(::Type{Mul}, head::typeof(*), args, metadata) = Mul(args) -TermInterface.operation(m::Mul) = * -TermInterface.sorted_arguments(m::Mul) = arguments(m) -TermInterface.sorted_children(m::Mul) = sorted_arguments(a) -ismul(x) = false -ismul(m::Mul) = true -function Base.show(io::IO, m::Mul) - args = map(arg -> sprint(printnode, arg), arguments(m)) - print(io, "(", join(args, " $(operation(m)) "), ")") - return nothing +function _hash(x, h::UInt64) + return hash(x, h) end -@wrapped struct LazyNamedDimsArray{ - T, A <: AbstractNamedDimsArray{T}, - } <: AbstractNamedDimsArray{T, Any} - union::Union{A, Mul{LazyNamedDimsArray{T, A}}} +# Custom version of `AbstractTrees.printnode` to +# avoid type piracy when overloading on `AbstractNamedDimsArray`. +printnode_nameddims(io::IO, x) = AbstractTrees.printnode(io, x) +function printnode_nameddims(io::IO, a::AbstractNamedDimsArray) + show(io, collect(dimnames(a))) + return nothing end -function NamedDimsArrays.inds(a::LazyNamedDimsArray) - u = unwrap(a) - if !iscall(u) - return inds(u) - elseif ismul(u) - return mapreduce(inds, symdiff, arguments(u)) +# Generic lazy functionality. +function maketerm_lazy(type::Type, head, args, metadata) + if head ≡ * + return type(maketerm(Mul, head, args, metadata)) else - return error("Variant not supported.") + return error("Only mul supported right now.") end end -function NamedDimsArrays.dename(a::LazyNamedDimsArray) +function getindex_lazy(a::AbstractArray, I...) u = unwrap(a) if !iscall(u) - return dename(u) + return u[I...] else - return error("Variant not supported.") + return error("Indexing into expression not supported.") end end - -function TermInterface.arguments(a::LazyNamedDimsArray) +function arguments_lazy(a) u = unwrap(a) if !iscall(u) return error("No arguments.") @@ -75,26 +69,19 @@ function TermInterface.arguments(a::LazyNamedDimsArray) return error("Variant not supported.") end end -function TermInterface.children(a::LazyNamedDimsArray) +function children_lazy(a) return arguments(a) end -function TermInterface.head(a::LazyNamedDimsArray) +function head_lazy(a) return operation(a) end -function TermInterface.iscall(a::LazyNamedDimsArray) +function iscall_lazy(a) return iscall(unwrap(a)) end -function TermInterface.isexpr(a::LazyNamedDimsArray) +function isexpr_lazy(a) return iscall(a) end -function TermInterface.maketerm(::Type{LazyNamedDimsArray}, head, args, metadata) - if head ≡ * - return LazyNamedDimsArray(maketerm(Mul, head, args, metadata)) - else - return error("Only mul supported right now.") - end -end -function TermInterface.operation(a::LazyNamedDimsArray) +function operation_lazy(a) u = unwrap(a) if !iscall(u) return error("No operation.") @@ -104,7 +91,7 @@ function TermInterface.operation(a::LazyNamedDimsArray) return error("Variant not supported.") end end -function TermInterface.sorted_arguments(a::LazyNamedDimsArray) +function sorted_arguments_lazy(a) u = unwrap(a) if !iscall(u) return error("No arguments.") @@ -114,28 +101,26 @@ function TermInterface.sorted_arguments(a::LazyNamedDimsArray) return error("Variant not supported.") end end -function TermInterface.sorted_children(a::LazyNamedDimsArray) +function sorted_children_lazy(a) return sorted_arguments(a) end -ismul(a::LazyNamedDimsArray) = ismul(unwrap(a)) - -function AbstractTrees.children(a::LazyNamedDimsArray) +ismul_lazy(a) = ismul(unwrap(a)) +function abstracttrees_children_lazy(a) if !iscall(a) return () else return arguments(a) end end -function AbstractTrees.nodevalue(a::LazyNamedDimsArray) +function nodevalue_lazy(a) if !iscall(a) return unwrap(a) else return operation(a) end end - using Base.Broadcast: materialize -function Base.Broadcast.materialize(a::LazyNamedDimsArray) +function materialize_lazy(a) u = unwrap(a) if !iscall(u) return u @@ -145,9 +130,8 @@ function Base.Broadcast.materialize(a::LazyNamedDimsArray) return error("Variant not supported.") end end -Base.copy(a::LazyNamedDimsArray) = materialize(a) - -function Base.:(==)(a1::LazyNamedDimsArray, a2::LazyNamedDimsArray) +copy_lazy(a) = materialize(a) +function equals_lazy(a1, a2) u1, u2 = unwrap.((a1, a2)) if !iscall(u1) && !iscall(u2) return u1 == u2 @@ -157,105 +141,210 @@ function Base.:(==)(a1::LazyNamedDimsArray, a2::LazyNamedDimsArray) return false end end - -function printnode(io::IO, a::LazyNamedDimsArray) - return printnode(io, unwrap(a)) +function hash_lazy(a, h::UInt64) + h = hash(Symbol(unspecify_type_parameters(typeof(a))), h) + # Use `_hash`, which defines a custom hash for NamedDimsArray. + return _hash(unwrap(a), h) end -function AbstractTrees.printnode(io::IO, a::LazyNamedDimsArray) - return printnode(io, a) +function map_arguments_lazy(f, a) + u = unwrap(a) + if !iscall(u) + return error("No arguments to map.") + elseif ismul(u) + return lazy(map_arguments(f, u)) + else + return error("Variant not supported.") + end +end +function substitute_lazy(a, substitutions::AbstractDict) + haskey(substitutions, a) && return substitutions[a] + !iscall(a) && return a + return map_arguments(arg -> substitute(arg, substitutions), a) end -function Base.show(io::IO, a::LazyNamedDimsArray) +function substitute_lazy(a, substitutions) + return substitute(a, Dict(substitutions)) +end +function printnode_lazy(io, a) + # Use `printnode_nameddims` to avoid type piracy, + # since it overloads on `AbstractNamedDimsArray`. + return printnode_nameddims(io, unwrap(a)) +end +function show_lazy(io::IO, a) if !iscall(a) return show(io, unwrap(a)) else - return printnode(io, a) + return AbstractTrees.printnode(io, a) end end -function Base.show(io::IO, mime::MIME"text/plain", a::LazyNamedDimsArray) +function show_lazy(io::IO, mime::MIME"text/plain", a) + summary(io, a) + println(io, ":") if !iscall(a) - @invoke show(io, mime, a::AbstractNamedDimsArray) + show(io, mime, unwrap(a)) return nothing else show(io, a) return nothing end end - -function Base.:*(a::LazyNamedDimsArray) +add_lazy(a1, a2) = error("Not implemented.") +sub_lazy(a) = error("Not implemented.") +sub_lazy(a1, a2) = error("Not implemented.") +function mul_lazy(a) u = unwrap(a) if !iscall(u) - return LazyNamedDimsArray(Mul([lazy(u)])) + return lazy(Mul([a])) elseif ismul(u) return a else return error("Variant not supported.") end end +# Note that this is nested by default. +mul_lazy(a1, a2) = lazy(Mul([a1, a2])) +mul_lazy(a1::Number, a2) = error("Not implemented.") +mul_lazy(a1, a2::Number) = error("Not implemented.") +mul_lazy(a1::Number, a2::Number) = a1 * a2 +div_lazy(a1, a2::Number) = error("Not implemented.") -function Base.:*(a1::LazyNamedDimsArray, a2::LazyNamedDimsArray) - # Nested by default. - return LazyNamedDimsArray(Mul([a1, a2])) -end -function Base.:+(a1::LazyNamedDimsArray, a2::LazyNamedDimsArray) - return error("Not implemented.") +# NamedDimsArrays.jl interface. +function inds_lazy(a) + u = unwrap(a) + if !iscall(u) + return inds(u) + elseif ismul(u) + return mapreduce(inds, symdiff, arguments(u)) + else + return error("Variant not supported.") + end end -function Base.:-(a1::LazyNamedDimsArray, a2::LazyNamedDimsArray) - return error("Not implemented.") +function dename_lazy(a) + u = unwrap(a) + if !iscall(u) + return dename(u) + else + return error("Variant not supported.") + end end -function Base.:*(c::Number, a::LazyNamedDimsArray) - return error("Not implemented.") + +# Lazy broadcasting. +struct LazyNamedDimsArrayStyle <: AbstractNamedDimsArrayStyle{Any} end +function Broadcast.broadcasted(::LazyNamedDimsArrayStyle, f, as...) + return error("Arbitrary broadcasting not supported for LazyNamedDimsArray.") end -function Base.:*(a::LazyNamedDimsArray, c::Number) - return error("Not implemented.") +# Linear operations. +Broadcast.broadcasted(::LazyNamedDimsArrayStyle, ::typeof(+), a1, a2) = a1 + a2 +Broadcast.broadcasted(::LazyNamedDimsArrayStyle, ::typeof(-), a1, a2) = a1 - a2 +Broadcast.broadcasted(::LazyNamedDimsArrayStyle, ::typeof(*), c::Number, a) = c * a +Broadcast.broadcasted(::LazyNamedDimsArrayStyle, ::typeof(*), a, c::Number) = a * c +Broadcast.broadcasted(::LazyNamedDimsArrayStyle, ::typeof(*), a::Number, b::Number) = a * b +Broadcast.broadcasted(::LazyNamedDimsArrayStyle, ::typeof(/), a, c::Number) = a / c +Broadcast.broadcasted(::LazyNamedDimsArrayStyle, ::typeof(-), a) = -a + +# Generic functionality for Applied types, like `Mul`, `Add`, etc. +ismul(a) = operation(a) ≡ * +head_applied(a) = operation(a) +iscall_applied(a) = true +isexpr_applied(a) = iscall(a) +function show_applied(io::IO, a) + args = map(arg -> sprint(AbstractTrees.printnode, arg), arguments(a)) + print(io, "(", join(args, " $(operation(a)) "), ")") + return nothing end -function Base.:/(a::LazyNamedDimsArray, c::Number) - return error("Not implemented.") +sorted_arguments_applied(a) = arguments(a) +children_applied(a) = arguments(a) +sorted_children_applied(a) = sorted_arguments(a) +function maketerm_applied(type, head, args, metadata) + term = type(args) + @assert head ≡ operation(term) + return term +end +map_arguments_applied(f, a) = unspecify_type_parameters(typeof(a))(map(f, arguments(a))) +function hash_applied(a, h::UInt64) + h = hash(Symbol(unspecify_type_parameters(typeof(a))), h) + for arg in arguments(a) + h = hash(arg, h) + end + return h end -function Base.:-(a::LazyNamedDimsArray) - return error("Not implemented.") + +abstract type Applied end +TermInterface.head(a::Applied) = head_applied(a) +TermInterface.iscall(a::Applied) = iscall_applied(a) +TermInterface.isexpr(a::Applied) = isexpr_applied(a) +Base.show(io::IO, a::Applied) = show_applied(io, a) +TermInterface.sorted_arguments(a::Applied) = sorted_arguments_applied(a) +TermInterface.children(a::Applied) = children_applied(a) +TermInterface.sorted_children(a::Applied) = sorted_children_applied(a) +function TermInterface.maketerm(type::Type{<:Applied}, head, args, metadata) + return maketerm_applied(type, head, args, metadata) +end +map_arguments(f, a::Applied) = map_arguments_applied(f, a) +Base.hash(a::Applied, h::UInt64) = hash_applied(a, h) + +struct Mul{A} <: Applied + arguments::Vector{A} end +TermInterface.arguments(m::Mul) = getfield(m, :arguments) +TermInterface.operation(m::Mul) = * +@wrapped struct LazyNamedDimsArray{ + T, A <: AbstractNamedDimsArray{T}, + } <: AbstractNamedDimsArray{T, Any} + union::Union{A, Mul{LazyNamedDimsArray{T, A}}} +end function LazyNamedDimsArray(a::AbstractNamedDimsArray) return LazyNamedDimsArray{eltype(a), typeof(a)}(a) end function LazyNamedDimsArray(a::Mul{LazyNamedDimsArray{T, A}}) where {T, A} return LazyNamedDimsArray{T, A}(a) end -function lazy(a::AbstractNamedDimsArray) - return LazyNamedDimsArray(a) -end +lazy(a::LazyNamedDimsArray) = a +lazy(a::AbstractNamedDimsArray) = LazyNamedDimsArray(a) +lazy(a::Mul{<:LazyNamedDimsArray}) = LazyNamedDimsArray(a) + +NamedDimsArrays.inds(a::LazyNamedDimsArray) = inds_lazy(a) +NamedDimsArrays.dename(a::LazyNamedDimsArray) = dename_lazy(a) # Broadcasting -struct LazyNamedDimsArrayStyle <: AbstractNamedDimsArrayStyle{Any} end function Base.BroadcastStyle(::Type{<:LazyNamedDimsArray}) return LazyNamedDimsArrayStyle() end -function Broadcast.broadcasted(::LazyNamedDimsArrayStyle, f, as...) - return error("Arbitrary broadcasting not supported for LazyNamedDimsArray.") -end -# Linear operations. -function Broadcast.broadcasted(::LazyNamedDimsArrayStyle, ::typeof(+), a1, a2) - return a1 + a2 -end -function Broadcast.broadcasted(::LazyNamedDimsArrayStyle, ::typeof(-), a1, a2) - return a1 - a2 -end -function Broadcast.broadcasted(::LazyNamedDimsArrayStyle, ::typeof(*), c::Number, a) - return c * a -end -function Broadcast.broadcasted(::LazyNamedDimsArrayStyle, ::typeof(*), a, c::Number) - return a * c -end -# Fix ambiguity error. -function Broadcast.broadcasted(::LazyNamedDimsArrayStyle, ::typeof(*), a::Number, b::Number) - return a * b -end -function Broadcast.broadcasted(::LazyNamedDimsArrayStyle, ::typeof(/), a, c::Number) - return a / c -end -function Broadcast.broadcasted(::LazyNamedDimsArrayStyle, ::typeof(-), a) - return -a -end + +# Derived functionality. +function TermInterface.maketerm(type::Type{LazyNamedDimsArray}, head, args, metadata) + return maketerm_lazy(type, head, args, metadata) +end +Base.getindex(a::LazyNamedDimsArray, I::Int...) = getindex_lazy(a, I...) +TermInterface.arguments(a::LazyNamedDimsArray) = arguments_lazy(a) +TermInterface.children(a::LazyNamedDimsArray) = children_lazy(a) +TermInterface.head(a::LazyNamedDimsArray) = head_lazy(a) +TermInterface.iscall(a::LazyNamedDimsArray) = iscall_lazy(a) +TermInterface.isexpr(a::LazyNamedDimsArray) = isexpr_lazy(a) +TermInterface.operation(a::LazyNamedDimsArray) = operation_lazy(a) +TermInterface.sorted_arguments(a::LazyNamedDimsArray) = sorted_arguments_lazy(a) +AbstractTrees.children(a::LazyNamedDimsArray) = abstracttrees_children_lazy(a) +TermInterface.sorted_children(a::LazyNamedDimsArray) = sorted_children_lazy(a) +ismul(a::LazyNamedDimsArray) = ismul_lazy(a) +AbstractTrees.nodevalue(a::LazyNamedDimsArray) = nodevalue_lazy(a) +Base.Broadcast.materialize(a::LazyNamedDimsArray) = materialize_lazy(a) +Base.copy(a::LazyNamedDimsArray) = copy_lazy(a) +Base.:(==)(a1::LazyNamedDimsArray, a2::LazyNamedDimsArray) = equals_lazy(a1, a2) +Base.hash(a::LazyNamedDimsArray, h::UInt64) = hash_lazy(a, h) +map_arguments(f, a::LazyNamedDimsArray) = map_arguments_lazy(f, a) +substitute(a::LazyNamedDimsArray, substitutions) = substitute_lazy(a, substitutions) +AbstractTrees.printnode(io::IO, a::LazyNamedDimsArray) = printnode_lazy(io, a) +printnode_nameddims(io::IO, a::LazyNamedDimsArray) = printnode_lazy(io, a) +Base.show(io::IO, a::LazyNamedDimsArray) = show_lazy(io, a) +Base.show(io::IO, mime::MIME"text/plain", a::LazyNamedDimsArray) = show_lazy(io, mime, a) +Base.:*(a::LazyNamedDimsArray) = mul_lazy(a) +Base.:*(a1::LazyNamedDimsArray, a2::LazyNamedDimsArray) = mul_lazy(a1, a2) +Base.:+(a1::LazyNamedDimsArray, a2::LazyNamedDimsArray) = add_lazy(a1, a2) +Base.:-(a1::LazyNamedDimsArray, a2::LazyNamedDimsArray) = sub_lazy(a1, a2) +Base.:*(a1::Number, a2::LazyNamedDimsArray) = mul_lazy(a1, a2) +Base.:*(a1::LazyNamedDimsArray, a2::Number) = mul_lazy(a1, a2) +Base.:/(a1::LazyNamedDimsArray, a2::Number) = div_lazy(a1, a2) +Base.:-(a::LazyNamedDimsArray) = sub_lazy(a) struct SymbolicArray{T, N, Name, Axes <: NTuple{N, AbstractUnitRange{<:Integer}}} <: AbstractArray{T, N} name::Name @@ -280,6 +369,17 @@ Base.size(a::SymbolicArray) = length.(axes(a)) function Base.:(==)(a::SymbolicArray, b::SymbolicArray) return symname(a) == symname(b) && axes(a) == axes(b) end +function Base.hash(a::SymbolicArray, h::UInt64) + h = hash(:SymbolicArray, h) + h = hash(symname(a), h) + return hash(size(a), h) +end +function Base.getindex(a::SymbolicArray{<:Any, N}, I::Vararg{Int, N}) where {N} + return error("Indexing into SymbolicArray not supported.") +end +function Base.setindex!(a::SymbolicArray{<:Any, N}, value, I::Vararg{Int, N}) where {N} + return error("Indexing into SymbolicArray not supported.") +end function Base.show(io::IO, mime::MIME"text/plain", a::SymbolicArray) Base.summary(io, a) println(io, ":") @@ -300,24 +400,19 @@ const SymbolicNamedDimsArray{T, N, Parent <: SymbolicArray{T, N}, DimNames} = function symnameddims(name) return lazy(NamedDimsArray(SymbolicArray(name), ())) end -function printnode(io::IO, a::SymbolicNamedDimsArray) +function AbstractTrees.printnode(io::IO, a::SymbolicNamedDimsArray) print(io, symname(dename(a))) if ndims(a) > 0 print(io, "[", join(dimnames(a), ","), "]") end return nothing end +printnode_nameddims(io::IO, a::SymbolicNamedDimsArray) = AbstractTrees.printnode(io, a) function Base.:(==)(a::SymbolicNamedDimsArray, b::SymbolicNamedDimsArray) return issetequal(inds(a), inds(b)) && dename(a) == dename(b) end -function Base.:*(a::SymbolicNamedDimsArray, b::SymbolicNamedDimsArray) - return lazy(a) * lazy(b) -end -function Base.:*(a::SymbolicNamedDimsArray, b::LazyNamedDimsArray) - return lazy(a) * b -end -function Base.:*(a::LazyNamedDimsArray, b::SymbolicNamedDimsArray) - return a * lazy(b) -end +Base.:*(a::SymbolicNamedDimsArray, b::SymbolicNamedDimsArray) = lazy(a) * lazy(b) +Base.:*(a::SymbolicNamedDimsArray, b::LazyNamedDimsArray) = lazy(a) * b +Base.:*(a::LazyNamedDimsArray, b::SymbolicNamedDimsArray) = a * lazy(b) end diff --git a/test/test_lazynameddimsarrays.jl b/test/test_lazynameddimsarrays.jl index 40735ec..cc86fdc 100644 --- a/test/test_lazynameddimsarrays.jl +++ b/test/test_lazynameddimsarrays.jl @@ -1,27 +1,24 @@ using AbstractTrees: AbstractTrees, print_tree, printnode using Base.Broadcast: materialize -using ITensorNetworksNext.LazyNamedDimsArrays: - LazyNamedDimsArray, Mul, SymbolicArray, ismul, lazy, symnameddims -using NamedDimsArrays: NamedDimsArray, dename, dimnames, inds, nameddims -using TermInterface: - arguments, - arity, - children, - head, - iscall, - isexpr, - maketerm, - operation, - sorted_arguments, - sorted_children +using ITensorNetworksNext.LazyNamedDimsArrays: LazyNamedDimsArrays, LazyNamedDimsArray, + Mul, SymbolicArray, ismul, lazy, substitute, symnameddims +using NamedDimsArrays: NamedDimsArray, @names, dename, dimnames, inds, nameddims, namedoneto +using TermInterface: arguments, arity, children, head, iscall, isexpr, maketerm, operation, + sorted_arguments, sorted_children using Test: @test, @test_throws, @testset using WrappedUnions: unwrap @testset "LazyNamedDimsArrays" begin + function sprint_namespaced(x) + context = (:module => LazyNamedDimsArrays) + module_prefix = "ITensorNetworksNext.LazyNamedDimsArrays." + return replace(sprint(show, MIME"text/plain"(), x; context), module_prefix => "") + end @testset "Basics" begin - a1 = nameddims(randn(2, 2), (:i, :j)) - a2 = nameddims(randn(2, 2), (:j, :k)) - a3 = nameddims(randn(2, 2), (:k, :l)) + i, j, k, l = namedoneto.(2, (:i, :j, :k, :l)) + a1 = randn(i, j) + a2 = randn(j, k) + a3 = randn(k, l) l1, l2, l3 = lazy.((a1, a2, a3)) for li in (l1, l2, l3) @test li isa LazyNamedDimsArray @@ -62,8 +59,8 @@ using WrappedUnions: unwrap @test sprint(show, l1) == sprint(show, a1) # TODO: Fix this test, it is basically correct but the type parameters # print in a different way. - # @test sprint(show, MIME"text/plain"(), l1) == - # replace(sprint(show, MIME"text/plain"(), a1), "NamedDimsArray" => "LazyNamedDimsArray") + # @test sprint_namespaced(l1) == + # replace(sprint_namespaced(a1), "NamedDimsArray" => "LazyNamedDimsArray") @test sprint(printnode, l1) == "[:i, :j]" @test sprint(print_tree, l1) == "[:i, :j]\n" @@ -81,29 +78,45 @@ using WrappedUnions: unwrap @test AbstractTrees.children(l) == [l1 * l2, l3] @test AbstractTrees.nodevalue(l) ≡ * @test sprint(show, l) == "(([:i, :j] * [:j, :k]) * [:k, :l])" - @test sprint(show, MIME"text/plain"(), l) == "(([:i, :j] * [:j, :k]) * [:k, :l])" + @test sprint_namespaced(l) == + "named(Base.OneTo(2), :i)×named(Base.OneTo(2), :l) " * + "LazyNamedDimsArray{Float64, …}:\n(([:i, :j] * [:j, :k]) * [:k, :l])" @test sprint(printnode, l) == "(([:i, :j] * [:j, :k]) * [:k, :l])" @test sprint(print_tree, l) == - "(([:i, :j] * [:j, :k]) * [:k, :l])\n├─ ([:i, :j] * [:j, :k])\n│ ├─ [:i, :j]\n│ └─ [:j, :k]\n└─ [:k, :l]\n" + "(([:i, :j] * [:j, :k]) * [:k, :l])\n" * + "├─ ([:i, :j] * [:j, :k])\n" * + "│ ├─ [:i, :j]\n│ └─ [:j, :k]\n" * + "└─ [:k, :l]\n" end @testset "symnameddims" begin - a = symnameddims(:a) - b = symnameddims(:b) - c = symnameddims(:c) - @test a isa LazyNamedDimsArray - @test unwrap(a) isa NamedDimsArray - @test dename(a) isa SymbolicArray - @test dename(unwrap(a)) isa SymbolicArray - @test dename(unwrap(a)) == SymbolicArray(:a) - @test inds(a) == () - @test dimnames(a) == () + a1, a2, a3 = symnameddims.((:a1, :a2, :a3)) + @test a1 isa LazyNamedDimsArray + @test unwrap(a1) isa NamedDimsArray + @test dename(a1) isa SymbolicArray + @test dename(unwrap(a1)) isa SymbolicArray + @test dename(unwrap(a1)) == SymbolicArray(:a1) + @test inds(a1) == () + @test dimnames(a1) == () - ex = a * b * c + ex = a1 * a2 * a3 @test copy(ex) == ex - @test arguments(ex) == [a * b, c] + @test arguments(ex) == [a1 * a2, a3] @test operation(ex) ≡ * - @test sprint(show, ex) == "((a * b) * c)" - @test sprint(show, MIME"text/plain"(), ex) == "((a * b) * c)" + @test sprint(show, ex) == "((a1 * a2) * a3)" + @test sprint_namespaced(ex) == + "0-dimensional LazyNamedDimsArray{Any, …}:\n((a1 * a2) * a3)" + end + + @testset "substitute" begin + s = symnameddims.((:a1, :a2, :a3)) + i = @names i[1:4] + a = (randn(2, 2)[i[1], i[2]], randn(2, 2)[i[2], i[3]], randn(2, 2)[i[3], i[4]]) + l = lazy.(a) + + seq = s[1] * (s[2] * s[3]) + net = substitute(seq, s .=> l) + @test net == l[1] * (l[2] * l[3]) + @test arguments(net) == [l[1], l[2] * l[3]] end end