Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
name = "ITensorNetworksNext"
uuid = "302f2e75-49f0-4526-aef7-d8ba550cb06c"
authors = ["ITensor developers <[email protected]> and contributors"]
version = "0.1.7"
version = "0.1.8"

[deps]
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
BackendSelection = "680c2d7c-f67a-4cc9-ae9c-da132b1447a5"
DataGraphs = "b5a273c3-7e6c-41f6-98bd-8d7f1525a36a"
Expand All @@ -19,6 +20,7 @@ TermInterface = "8ea1fca8-c5ef-4a55-8b96-4e9afe9c9a3c"
WrappedUnions = "325db55a-9c6c-5b90-b1a2-ec87e7a38c44"

[compat]
AbstractTrees = "0.4.5"
Adapt = "4.3"
BackendSelection = "0.1.6"
DataGraphs = "0.2.7"
Expand Down
161 changes: 145 additions & 16 deletions src/lazynameddimsarrays.jl
Original file line number Diff line number Diff line change
@@ -1,14 +1,25 @@
module LazyNamedDimsArrays

using AbstractTrees: AbstractTrees
using WrappedUnions: @wrapped, unwrap
using NamedDimsArrays:
NamedDimsArrays,
AbstractNamedDimsArray,
AbstractNamedDimsArrayStyle,
NamedDimsArray,
dename,
dimnames,
inds
using TermInterface: TermInterface, arguments, iscall, maketerm, operation, sorted_arguments

# Custom version of `AbstractTrees.printnode` to
# avoid type piracy when overloading on `AbstractNamedDimsArray`.
printnode(io::IO, x) = AbstractTrees.printnode(io, x)
function printnode(io::IO, a::AbstractNamedDimsArray)
show(io, collect(dimnames(a)))
return nothing
end

struct Mul{A}
arguments::Vector{A}
end
Expand All @@ -21,6 +32,13 @@ TermInterface.maketerm(::Type{Mul}, head::typeof(*), args, metadata) = Mul(args)
TermInterface.operation(m::Mul) = *
TermInterface.sorted_arguments(m::Mul) = arguments(m)
TermInterface.sorted_children(m::Mul) = sorted_arguments(a)
ismul(x) = false
ismul(m::Mul) = true
function Base.show(io::IO, m::Mul)
args = map(arg -> sprint(printnode, arg), arguments(m))
print(io, "(", join(args, " $(operation(m)) "), ")")
return nothing
end

@wrapped struct LazyNamedDimsArray{
T, A <: AbstractNamedDimsArray{T},
Expand All @@ -30,30 +48,28 @@ end

function NamedDimsArrays.inds(a::LazyNamedDimsArray)
u = unwrap(a)
if u isa AbstractNamedDimsArray
if !iscall(u)
return inds(u)
elseif u isa Mul
elseif ismul(u)
return mapreduce(inds, symdiff, arguments(u))
else
return error("Variant not supported.")
end
end
function NamedDimsArrays.dename(a::LazyNamedDimsArray)
u = unwrap(a)
if u isa AbstractNamedDimsArray
if !iscall(u)
return dename(u)
elseif u isa Mul
return dename(materialize(a), inds(a))
else
return error("Variant not supported.")
end
end

function TermInterface.arguments(a::LazyNamedDimsArray)
u = unwrap(a)
if u isa AbstractNamedDimsArray
if !iscall(u)
return error("No arguments.")
elseif u isa Mul
elseif ismul(u)
return arguments(u)
else
return error("Variant not supported.")
Expand All @@ -75,24 +91,24 @@ function TermInterface.maketerm(::Type{LazyNamedDimsArray}, head, args, metadata
if head ≡ *
return LazyNamedDimsArray(maketerm(Mul, head, args, metadata))
else
return error("Only product terms supported right now.")
return error("Only mul supported right now.")
end
end
function TermInterface.operation(a::LazyNamedDimsArray)
u = unwrap(a)
if u isa AbstractNamedDimsArray
if !iscall(u)
return error("No operation.")
elseif u isa Mul
elseif ismul(u)
return operation(u)
else
return error("Variant not supported.")
end
end
function TermInterface.sorted_arguments(a::LazyNamedDimsArray)
u = unwrap(a)
if u isa AbstractNamedDimsArray
if !iscall(u)
return error("No arguments.")
elseif u isa Mul
elseif ismul(u)
return sorted_arguments(u)
else
return error("Variant not supported.")
Expand All @@ -101,25 +117,75 @@ end
function TermInterface.sorted_children(a::LazyNamedDimsArray)
return sorted_arguments(a)
end
ismul(a::LazyNamedDimsArray) = ismul(unwrap(a))

function AbstractTrees.children(a::LazyNamedDimsArray)
if !iscall(a)
return ()
else
return arguments(a)
end
end
function AbstractTrees.nodevalue(a::LazyNamedDimsArray)
if !iscall(a)
return unwrap(a)
else
return operation(a)
end
end

using Base.Broadcast: materialize
function Base.Broadcast.materialize(a::LazyNamedDimsArray)
u = unwrap(a)
if u isa AbstractNamedDimsArray
if !iscall(u)
return u
elseif u isa Mul
elseif ismul(u)
return mapfoldl(materialize, operation(u), arguments(u))
else
return error("Variant not supported.")
end
end
Base.copy(a::LazyNamedDimsArray) = materialize(a)

function Base.:(==)(a1::LazyNamedDimsArray, a2::LazyNamedDimsArray)
u1, u2 = unwrap.((a1, a2))
if !iscall(u1) && !iscall(u2)
return u1 == u2
elseif ismul(u1) && ismul(u2)
return arguments(u1) == arguments(u2)
else
return false
end
end

function printnode(io::IO, a::LazyNamedDimsArray)
return printnode(io, unwrap(a))
end
function AbstractTrees.printnode(io::IO, a::LazyNamedDimsArray)
return printnode(io, a)
end
function Base.show(io::IO, a::LazyNamedDimsArray)
if !iscall(a)
return show(io, unwrap(a))
else
return printnode(io, a)
end
end
function Base.show(io::IO, mime::MIME"text/plain", a::LazyNamedDimsArray)
if !iscall(a)
@invoke show(io, mime, a::AbstractNamedDimsArray)
return nothing
else
show(io, a)
return nothing
end
end

function Base.:*(a::LazyNamedDimsArray)
u = unwrap(a)
if u isa AbstractNamedDimsArray
if !iscall(u)
return LazyNamedDimsArray(Mul([lazy(u)]))
elseif u isa Mul
elseif ismul(u)
return a
else
return error("Variant not supported.")
Expand Down Expand Up @@ -191,4 +257,67 @@ function Broadcast.broadcasted(::LazyNamedDimsArrayStyle, ::typeof(-), a)
return -a
end

struct SymbolicArray{T, N, Name, Axes <: NTuple{N, AbstractUnitRange{<:Integer}}} <: AbstractArray{T, N}
name::Name
axes::Axes
function SymbolicArray{T}(name, ax::Tuple{Vararg{AbstractUnitRange{<:Integer}}}) where {T}
N = length(ax)
return new{T, N, typeof(name), typeof(ax)}(name, ax)
end
end
function SymbolicArray(name, ax::Tuple{Vararg{AbstractUnitRange{<:Integer}}})
return SymbolicArray{Any}(name, ax)
end
function SymbolicArray{T}(name, ax::AbstractUnitRange...) where {T}
return SymbolicArray{T}(name, ax)
end
function SymbolicArray(name, ax::AbstractUnitRange...)
return SymbolicArray{Any}(name, ax)
end
symname(a::SymbolicArray) = getfield(a, :name)
Base.axes(a::SymbolicArray) = getfield(a, :axes)
Base.size(a::SymbolicArray) = length.(axes(a))
function Base.:(==)(a::SymbolicArray, b::SymbolicArray)
return symname(a) == symname(b) && axes(a) == axes(b)
end
function Base.show(io::IO, mime::MIME"text/plain", a::SymbolicArray)
Base.summary(io, a)
println(io, ":")
print(io, repr(symname(a)))
return nothing
end
function Base.show(io::IO, a::SymbolicArray)
print(io, "SymbolicArray(", symname(a), ", ", size(a), ")")
return nothing
end
using AbstractTrees: AbstractTrees
function AbstractTrees.printnode(io::IO, a::SymbolicArray)
print(io, repr(symname(a)))
return nothing
end
const SymbolicNamedDimsArray{T, N, Parent <: SymbolicArray{T, N}, DimNames} =
NamedDimsArray{T, N, Parent, DimNames}
function symnameddims(name)
return lazy(NamedDimsArray(SymbolicArray(name), ()))
end
function printnode(io::IO, a::SymbolicNamedDimsArray)
print(io, symname(dename(a)))
if ndims(a) > 0
print(io, "[", join(dimnames(a), ","), "]")
end
return nothing
end
function Base.:(==)(a::SymbolicNamedDimsArray, b::SymbolicNamedDimsArray)
return issetequal(inds(a), inds(b)) && dename(a) == dename(b)
end
function Base.:*(a::SymbolicNamedDimsArray, b::SymbolicNamedDimsArray)
return lazy(a) * lazy(b)
end
function Base.:*(a::SymbolicNamedDimsArray, b::LazyNamedDimsArray)
return lazy(a) * b
end
function Base.:*(a::LazyNamedDimsArray, b::SymbolicNamedDimsArray)
return a * lazy(b)
end

end
2 changes: 2 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
[deps]
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
Dictionaries = "85a47980-9c8c-11e8-2b9f-f7ca1fa99fb4"
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
Expand All @@ -13,6 +14,7 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
WrappedUnions = "325db55a-9c6c-5b90-b1a2-ec87e7a38c44"

[compat]
AbstractTrees = "0.4.5"
Aqua = "0.8.14"
Dictionaries = "0.4.5"
Graphs = "1.13.1"
Expand Down
43 changes: 41 additions & 2 deletions test/test_lazynameddimsarrays.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
using AbstractTrees: AbstractTrees, print_tree, printnode
using Base.Broadcast: materialize
using ITensorNetworksNext.LazyNamedDimsArrays: LazyNamedDimsArray, Mul, lazy
using NamedDimsArrays: NamedDimsArray, inds, nameddims
using ITensorNetworksNext.LazyNamedDimsArrays:
LazyNamedDimsArray, Mul, SymbolicArray, ismul, lazy, symnameddims
using NamedDimsArrays: NamedDimsArray, dename, dimnames, inds, nameddims
using TermInterface:
arguments,
arity,
Expand Down Expand Up @@ -33,6 +35,7 @@ using WrappedUnions: unwrap
@test materialize(l) ≈ a1 * a2 * a3
@test issetequal(inds(l), symdiff(inds.((a1, a2, a3))...))
@test unwrap(l) isa Mul
@test ismul(unwrap(l))
@test unwrap(l).arguments == [l1 * l2, l3]
# TermInterface.jl
@test operation(unwrap(l)) ≡ *
Expand All @@ -54,6 +57,15 @@ using WrappedUnions: unwrap
@test_throws ErrorException operation(l1)
@test_throws ErrorException sorted_arguments(l1)
@test_throws ErrorException sorted_children(l1)
@test AbstractTrees.children(l1) ≡ ()
@test AbstractTrees.nodevalue(l1) ≡ a1
@test sprint(show, l1) == sprint(show, a1)
# TODO: Fix this test, it is basically correct but the type parameters
# print in a different way.
# @test sprint(show, MIME"text/plain"(), l1) ==
# replace(sprint(show, MIME"text/plain"(), a1), "NamedDimsArray" => "LazyNamedDimsArray")
@test sprint(printnode, l1) == "[:i, :j]"
@test sprint(print_tree, l1) == "[:i, :j]\n"

l = l1 * l2 * l3
@test arguments(l) == [l1 * l2, l3]
Expand All @@ -66,5 +78,32 @@ using WrappedUnions: unwrap
@test operation(l) ≡ *
@test sorted_arguments(l) == [l1 * l2, l3]
@test sorted_children(l) == [l1 * l2, l3]
@test AbstractTrees.children(l) == [l1 * l2, l3]
@test AbstractTrees.nodevalue(l) ≡ *
@test sprint(show, l) == "(([:i, :j] * [:j, :k]) * [:k, :l])"
@test sprint(show, MIME"text/plain"(), l) == "(([:i, :j] * [:j, :k]) * [:k, :l])"
@test sprint(printnode, l) == "(([:i, :j] * [:j, :k]) * [:k, :l])"
@test sprint(print_tree, l) ==
"(([:i, :j] * [:j, :k]) * [:k, :l])\n├─ ([:i, :j] * [:j, :k])\n│ ├─ [:i, :j]\n│ └─ [:j, :k]\n└─ [:k, :l]\n"
end

@testset "symnameddims" begin
a = symnameddims(:a)
b = symnameddims(:b)
c = symnameddims(:c)
@test a isa LazyNamedDimsArray
@test unwrap(a) isa NamedDimsArray
@test dename(a) isa SymbolicArray
@test dename(unwrap(a)) isa SymbolicArray
@test dename(unwrap(a)) == SymbolicArray(:a)
@test inds(a) == ()
@test dimnames(a) == ()

ex = a * b * c
@test copy(ex) == ex
@test arguments(ex) == [a * b, c]
@test operation(ex) ≡ *
@test sprint(show, ex) == "((a * b) * c)"
@test sprint(show, MIME"text/plain"(), ex) == "((a * b) * c)"
end
end
Loading