Skip to content

Commit 638b94b

Browse files
committed
Better printing, equality, symbolic arrays
1 parent 8a6158e commit 638b94b

File tree

6 files changed

+197
-19
lines changed

6 files changed

+197
-19
lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
name = "ITensorNetworksNext"
22
uuid = "302f2e75-49f0-4526-aef7-d8ba550cb06c"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.1.7"
4+
version = "0.1.8"
55

66
[deps]
7+
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
78
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
89
BackendSelection = "680c2d7c-f67a-4cc9-ae9c-da132b1447a5"
910
DataGraphs = "b5a273c3-7e6c-41f6-98bd-8d7f1525a36a"
@@ -19,6 +20,7 @@ TermInterface = "8ea1fca8-c5ef-4a55-8b96-4e9afe9c9a3c"
1920
WrappedUnions = "325db55a-9c6c-5b90-b1a2-ec87e7a38c44"
2021

2122
[compat]
23+
AbstractTrees = "0.4.5"
2224
Adapt = "4.3"
2325
BackendSelection = "0.1.6"
2426
DataGraphs = "0.2.7"

src/ITensorNetworksNext.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
module ITensorNetworksNext
22

3+
include("symbolicarrays.jl")
34
include("lazynameddimsarrays.jl")
45
include("abstracttensornetwork.jl")
56
include("tensornetwork.jl")

src/lazynameddimsarrays.jl

Lines changed: 106 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,49 @@
11
module LazyNamedDimsArrays
22

3+
using AbstractTrees: AbstractTrees
34
using WrappedUnions: @wrapped, unwrap
45
using NamedDimsArrays:
56
NamedDimsArrays,
67
AbstractNamedDimsArray,
78
AbstractNamedDimsArrayStyle,
9+
NamedDimsArray,
810
dename,
11+
dimnames,
912
inds
13+
using ..SymbolicArrays: SymbolicArrays, SymbolicArray
1014
using TermInterface: TermInterface, arguments, iscall, maketerm, operation, sorted_arguments
1115

16+
const SymbolicNamedDimsArray{T, N, Parent <: SymbolicArray{T, N}, DimNames} =
17+
NamedDimsArray{T, N, Parent, DimNames}
18+
function symnameddims(name)
19+
return lazy(NamedDimsArray(SymbolicArray(name), ()))
20+
end
21+
function printnode(io::IO, a::SymbolicNamedDimsArray)
22+
print(io, SymbolicArrays.name(dename(a)))
23+
print(io, "[", join(dimnames(a), ","), "]")
24+
return nothing
25+
end
26+
function Base.:(==)(a::SymbolicNamedDimsArray, b::SymbolicNamedDimsArray)
27+
return issetequal(inds(a), inds(b)) && dename(a) == dename(b)
28+
end
29+
function Base.:*(a::SymbolicNamedDimsArray, b::SymbolicNamedDimsArray)
30+
return lazy(a) * lazy(b)
31+
end
32+
function Base.:*(a::SymbolicNamedDimsArray, b::LazyNamedDimsArray)
33+
return lazy(a) * b
34+
end
35+
function Base.:*(a::LazyNamedDimsArray, b::SymbolicNamedDimsArray)
36+
return a * lazy(b)
37+
end
38+
39+
# Custom version of `AbstractTrees.printnode` to
40+
# avoid type piracy when overloading on `AbstractNamedDimsArray`.
41+
printnode(io::IO, x) = AbstractTrees.printnode(io, x)
42+
function printnode(io::IO, a::AbstractNamedDimsArray)
43+
show(io, collect(dimnames(a)))
44+
return nothing
45+
end
46+
1247
struct Mul{A}
1348
arguments::Vector{A}
1449
end
@@ -21,6 +56,13 @@ TermInterface.maketerm(::Type{Mul}, head::typeof(*), args, metadata) = Mul(args)
2156
TermInterface.operation(m::Mul) = *
2257
TermInterface.sorted_arguments(m::Mul) = arguments(m)
2358
TermInterface.sorted_children(m::Mul) = sorted_arguments(a)
59+
ismul(x) = false
60+
ismul(m::Mul) = true
61+
function Base.show(io::IO, m::Mul)
62+
args = map(arg -> sprint(printnode, arg), arguments(m))
63+
print(io, "(", join(args, " $(operation(m)) "), ")")
64+
return nothing
65+
end
2466

2567
@wrapped struct LazyNamedDimsArray{
2668
T, A <: AbstractNamedDimsArray{T},
@@ -30,30 +72,28 @@ end
3072

3173
function NamedDimsArrays.inds(a::LazyNamedDimsArray)
3274
u = unwrap(a)
33-
if u isa AbstractNamedDimsArray
75+
if !iscall(u)
3476
return inds(u)
35-
elseif u isa Mul
77+
elseif ismul(u)
3678
return mapreduce(inds, symdiff, arguments(u))
3779
else
3880
return error("Variant not supported.")
3981
end
4082
end
4183
function NamedDimsArrays.dename(a::LazyNamedDimsArray)
4284
u = unwrap(a)
43-
if u isa AbstractNamedDimsArray
85+
if !iscall(u)
4486
return dename(u)
45-
elseif u isa Mul
46-
return dename(materialize(a), inds(a))
4787
else
4888
return error("Variant not supported.")
4989
end
5090
end
5191

5292
function TermInterface.arguments(a::LazyNamedDimsArray)
5393
u = unwrap(a)
54-
if u isa AbstractNamedDimsArray
94+
if !iscall(u)
5595
return error("No arguments.")
56-
elseif u isa Mul
96+
elseif ismul(u)
5797
return arguments(u)
5898
else
5999
return error("Variant not supported.")
@@ -75,24 +115,24 @@ function TermInterface.maketerm(::Type{LazyNamedDimsArray}, head, args, metadata
75115
if head *
76116
return LazyNamedDimsArray(maketerm(Mul, head, args, metadata))
77117
else
78-
return error("Only product terms supported right now.")
118+
return error("Only mul supported right now.")
79119
end
80120
end
81121
function TermInterface.operation(a::LazyNamedDimsArray)
82122
u = unwrap(a)
83-
if u isa AbstractNamedDimsArray
123+
if !iscall(u)
84124
return error("No operation.")
85-
elseif u isa Mul
125+
elseif ismul(u)
86126
return operation(u)
87127
else
88128
return error("Variant not supported.")
89129
end
90130
end
91131
function TermInterface.sorted_arguments(a::LazyNamedDimsArray)
92132
u = unwrap(a)
93-
if u isa AbstractNamedDimsArray
133+
if !iscall(u)
94134
return error("No arguments.")
95-
elseif u isa Mul
135+
elseif ismul(u)
96136
return sorted_arguments(u)
97137
else
98138
return error("Variant not supported.")
@@ -101,25 +141,75 @@ end
101141
function TermInterface.sorted_children(a::LazyNamedDimsArray)
102142
return sorted_arguments(a)
103143
end
144+
ismul(a::LazyNamedDimsArray) = ismul(unwrap(a))
145+
146+
function AbstractTrees.children(a::LazyNamedDimsArray)
147+
if !iscall(a)
148+
return ()
149+
else
150+
return arguments(a)
151+
end
152+
end
153+
function AbstractTrees.nodevalue(a::LazyNamedDimsArray)
154+
if !iscall(a)
155+
return unwrap(a)
156+
else
157+
return operation(a)
158+
end
159+
end
104160

105161
using Base.Broadcast: materialize
106162
function Base.Broadcast.materialize(a::LazyNamedDimsArray)
107163
u = unwrap(a)
108-
if u isa AbstractNamedDimsArray
164+
if !iscall(u)
109165
return u
110-
elseif u isa Mul
166+
elseif ismul(u)
111167
return mapfoldl(materialize, operation(u), arguments(u))
112168
else
113169
return error("Variant not supported.")
114170
end
115171
end
116172
Base.copy(a::LazyNamedDimsArray) = materialize(a)
117173

174+
function Base.:(==)(a1::LazyNamedDimsArray, a2::LazyNamedDimsArray)
175+
u1, u2 = unwrap.((a1, a2))
176+
if !iscall(u1) && !iscall(u2)
177+
return u1 == u2
178+
elseif ismul(u1) && ismul(u2)
179+
return arguments(u1) == arguments(u2)
180+
else
181+
return false
182+
end
183+
end
184+
185+
function printnode(io::IO, a::LazyNamedDimsArray)
186+
return printnode(io, unwrap(a))
187+
end
188+
function AbstractTrees.printnode(io::IO, a::LazyNamedDimsArray)
189+
return printnode(io, a)
190+
end
191+
function Base.show(io::IO, a::LazyNamedDimsArray)
192+
if !iscall(a)
193+
return show(io, unwrap(a))
194+
else
195+
return printnode(io, a)
196+
end
197+
end
198+
function Base.show(io::IO, mime::MIME"text/plain", a::LazyNamedDimsArray)
199+
if !iscall(a)
200+
@invoke show(io, mime, a::AbstractNamedDimsArray)
201+
return nothing
202+
else
203+
show(io, a)
204+
return nothing
205+
end
206+
end
207+
118208
function Base.:*(a::LazyNamedDimsArray)
119209
u = unwrap(a)
120-
if u isa AbstractNamedDimsArray
210+
if !iscall(u)
121211
return LazyNamedDimsArray(Mul([lazy(u)]))
122-
elseif u isa Mul
212+
elseif ismul(u)
123213
return a
124214
else
125215
return error("Variant not supported.")

src/symbolicarrays.jl

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
module SymbolicArrays
2+
3+
using AbstractTrees: AbstractTrees
4+
5+
struct SymbolicArray{T, N, Name, Axes <: NTuple{N, AbstractUnitRange{<:Integer}}} <: AbstractArray{T, N}
6+
name::Name
7+
axes::Axes
8+
function SymbolicArray{T}(name, ax::Tuple{Vararg{AbstractUnitRange{<:Integer}}}) where {T}
9+
N = length(ax)
10+
return new{T, N, typeof(name), typeof(ax)}(name, ax)
11+
end
12+
end
13+
function SymbolicArray(name, ax::Tuple{Vararg{AbstractUnitRange{<:Integer}}})
14+
return SymbolicArray{Any}(name, ax)
15+
end
16+
function SymbolicArray{T}(name, ax::AbstractUnitRange...) where {T}
17+
return SymbolicArray{T}(name, ax)
18+
end
19+
function SymbolicArray(name, ax::AbstractUnitRange...)
20+
return SymbolicArray{Any}(name, ax)
21+
end
22+
name(a::SymbolicArray) = getfield(a, :name)
23+
Base.axes(a::SymbolicArray) = getfield(a, :axes)
24+
Base.size(a::SymbolicArray) = length.(axes(a))
25+
function Base.:(==)(a::SymbolicArray, b::SymbolicArray)
26+
return name(a) == name(b) && axes(a) == axes(b)
27+
end
28+
function Base.show(io::IO, mime::MIME"text/plain", a::SymbolicArray)
29+
Base.summary(io, a)
30+
println(io, ":")
31+
print(io, repr(name(a)))
32+
return nothing
33+
end
34+
function Base.show(io::IO, a::SymbolicArray)
35+
print(io, "SymbolicArray(", name(a), ", ", size(a), ")")
36+
return nothing
37+
end
38+
39+
function AbstractTrees.printnode(io::IO, a::SymbolicArray)
40+
print(io, repr(name(a)))
41+
return nothing
42+
end
43+
44+
end

test/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
[deps]
2+
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
23
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
34
Dictionaries = "85a47980-9c8c-11e8-2b9f-f7ca1fa99fb4"
45
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"

test/test_lazynameddimsarrays.jl

Lines changed: 42 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
1+
using AbstractTrees: AbstractTrees, print_tree, printnode
12
using Base.Broadcast: materialize
2-
using ITensorNetworksNext.LazyNamedDimsArrays: LazyNamedDimsArray, Mul, lazy
3-
using NamedDimsArrays: NamedDimsArray, inds, nameddims
3+
using ITensorNetworksNext.LazyNamedDimsArrays:
4+
LazyNamedDimsArray, Mul, ismul, lazy, symnameddims
5+
using ITensorNetworksNext.SymbolicArrays: SymbolicArray
6+
using NamedDimsArrays: NamedDimsArray, dename, dimnames, inds, nameddims
47
using TermInterface:
58
arguments,
69
arity,
@@ -33,6 +36,7 @@ using WrappedUnions: unwrap
3336
@test materialize(l) a1 * a2 * a3
3437
@test issetequal(inds(l), symdiff(inds.((a1, a2, a3))...))
3538
@test unwrap(l) isa Mul
39+
@test ismul(unwrap(l))
3640
@test unwrap(l).arguments == [l1 * l2, l3]
3741
# TermInterface.jl
3842
@test operation(unwrap(l)) *
@@ -54,6 +58,15 @@ using WrappedUnions: unwrap
5458
@test_throws ErrorException operation(l1)
5559
@test_throws ErrorException sorted_arguments(l1)
5660
@test_throws ErrorException sorted_children(l1)
61+
@test AbstractTrees.children(l1) ()
62+
@test AbstractTrees.nodevalue(l1) a1
63+
@test sprint(show, l1) == sprint(show, a1)
64+
# TODO: Fix this test, it is basically correct but the type parameters
65+
# print in a different way.
66+
# @test sprint(show, MIME"text/plain"(), l1) ==
67+
# replace(sprint(show, MIME"text/plain"(), a1), "NamedDimsArray" => "LazyNamedDimsArray")
68+
@test sprint(printnode, l1) == "[:i, :j]"
69+
@test sprint(print_tree, l1) == "[:i, :j]\n"
5770

5871
l = l1 * l2 * l3
5972
@test arguments(l) == [l1 * l2, l3]
@@ -66,5 +79,32 @@ using WrappedUnions: unwrap
6679
@test operation(l) *
6780
@test sorted_arguments(l) == [l1 * l2, l3]
6881
@test sorted_children(l) == [l1 * l2, l3]
82+
@test AbstractTrees.children(l) == [l1 * l2, l3]
83+
@test AbstractTrees.nodevalue(l) *
84+
@test sprint(show, l) == "(([:i, :j] * [:j, :k]) * [:k, :l])"
85+
@test sprint(show, MIME"text/plain"(), l) == "(([:i, :j] * [:j, :k]) * [:k, :l])"
86+
@test sprint(printnode, l) == "(([:i, :j] * [:j, :k]) * [:k, :l])"
87+
@test sprint(print_tree, l) ==
88+
"(([:i, :j] * [:j, :k]) * [:k, :l])\n├─ ([:i, :j] * [:j, :k])\n│ ├─ [:i, :j]\n│ └─ [:j, :k]\n└─ [:k, :l]\n"
89+
end
90+
91+
@testset "symnameddims" begin
92+
a = symnameddims(:a)
93+
b = symnameddims(:b)
94+
c = symnameddims(:c)
95+
@test a isa LazyNamedDimsArray
96+
@test unwrap(a) isa NamedDimsArray
97+
@test dename(a) isa SymbolicArray
98+
@test dename(unwrap(a)) isa SymbolicArray
99+
@test dename(unwrap(a)) == SymbolicArray(:a)
100+
@test inds(a) == ()
101+
@test dimnames(a) == ()
102+
103+
ex = a * b * c
104+
@test copy(ex) == ex
105+
@test arguments(ex) == [a * b, c]
106+
@test operation(ex) *
107+
@test sprint(show, ex) == "((a[] * b[]) * c[])"
108+
@test sprint(show, MIME"text/plain"(), ex) == "((a[] * b[]) * c[])"
69109
end
70110
end

0 commit comments

Comments
 (0)