Skip to content

Commit 0c19c8b

Browse files
authored
Better printing, equality, symbolic arrays (#11)
1 parent 8a6158e commit 0c19c8b

File tree

4 files changed

+191
-19
lines changed

4 files changed

+191
-19
lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
name = "ITensorNetworksNext"
22
uuid = "302f2e75-49f0-4526-aef7-d8ba550cb06c"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.1.7"
4+
version = "0.1.8"
55

66
[deps]
7+
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
78
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
89
BackendSelection = "680c2d7c-f67a-4cc9-ae9c-da132b1447a5"
910
DataGraphs = "b5a273c3-7e6c-41f6-98bd-8d7f1525a36a"
@@ -19,6 +20,7 @@ TermInterface = "8ea1fca8-c5ef-4a55-8b96-4e9afe9c9a3c"
1920
WrappedUnions = "325db55a-9c6c-5b90-b1a2-ec87e7a38c44"
2021

2122
[compat]
23+
AbstractTrees = "0.4.5"
2224
Adapt = "4.3"
2325
BackendSelection = "0.1.6"
2426
DataGraphs = "0.2.7"

src/lazynameddimsarrays.jl

Lines changed: 145 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,25 @@
11
module LazyNamedDimsArrays
22

3+
using AbstractTrees: AbstractTrees
34
using WrappedUnions: @wrapped, unwrap
45
using NamedDimsArrays:
56
NamedDimsArrays,
67
AbstractNamedDimsArray,
78
AbstractNamedDimsArrayStyle,
9+
NamedDimsArray,
810
dename,
11+
dimnames,
912
inds
1013
using TermInterface: TermInterface, arguments, iscall, maketerm, operation, sorted_arguments
1114

15+
# Custom version of `AbstractTrees.printnode` to
16+
# avoid type piracy when overloading on `AbstractNamedDimsArray`.
17+
printnode(io::IO, x) = AbstractTrees.printnode(io, x)
18+
function printnode(io::IO, a::AbstractNamedDimsArray)
19+
show(io, collect(dimnames(a)))
20+
return nothing
21+
end
22+
1223
struct Mul{A}
1324
arguments::Vector{A}
1425
end
@@ -21,6 +32,13 @@ TermInterface.maketerm(::Type{Mul}, head::typeof(*), args, metadata) = Mul(args)
2132
TermInterface.operation(m::Mul) = *
2233
TermInterface.sorted_arguments(m::Mul) = arguments(m)
2334
TermInterface.sorted_children(m::Mul) = sorted_arguments(a)
35+
ismul(x) = false
36+
ismul(m::Mul) = true
37+
function Base.show(io::IO, m::Mul)
38+
args = map(arg -> sprint(printnode, arg), arguments(m))
39+
print(io, "(", join(args, " $(operation(m)) "), ")")
40+
return nothing
41+
end
2442

2543
@wrapped struct LazyNamedDimsArray{
2644
T, A <: AbstractNamedDimsArray{T},
@@ -30,30 +48,28 @@ end
3048

3149
function NamedDimsArrays.inds(a::LazyNamedDimsArray)
3250
u = unwrap(a)
33-
if u isa AbstractNamedDimsArray
51+
if !iscall(u)
3452
return inds(u)
35-
elseif u isa Mul
53+
elseif ismul(u)
3654
return mapreduce(inds, symdiff, arguments(u))
3755
else
3856
return error("Variant not supported.")
3957
end
4058
end
4159
function NamedDimsArrays.dename(a::LazyNamedDimsArray)
4260
u = unwrap(a)
43-
if u isa AbstractNamedDimsArray
61+
if !iscall(u)
4462
return dename(u)
45-
elseif u isa Mul
46-
return dename(materialize(a), inds(a))
4763
else
4864
return error("Variant not supported.")
4965
end
5066
end
5167

5268
function TermInterface.arguments(a::LazyNamedDimsArray)
5369
u = unwrap(a)
54-
if u isa AbstractNamedDimsArray
70+
if !iscall(u)
5571
return error("No arguments.")
56-
elseif u isa Mul
72+
elseif ismul(u)
5773
return arguments(u)
5874
else
5975
return error("Variant not supported.")
@@ -75,24 +91,24 @@ function TermInterface.maketerm(::Type{LazyNamedDimsArray}, head, args, metadata
7591
if head *
7692
return LazyNamedDimsArray(maketerm(Mul, head, args, metadata))
7793
else
78-
return error("Only product terms supported right now.")
94+
return error("Only mul supported right now.")
7995
end
8096
end
8197
function TermInterface.operation(a::LazyNamedDimsArray)
8298
u = unwrap(a)
83-
if u isa AbstractNamedDimsArray
99+
if !iscall(u)
84100
return error("No operation.")
85-
elseif u isa Mul
101+
elseif ismul(u)
86102
return operation(u)
87103
else
88104
return error("Variant not supported.")
89105
end
90106
end
91107
function TermInterface.sorted_arguments(a::LazyNamedDimsArray)
92108
u = unwrap(a)
93-
if u isa AbstractNamedDimsArray
109+
if !iscall(u)
94110
return error("No arguments.")
95-
elseif u isa Mul
111+
elseif ismul(u)
96112
return sorted_arguments(u)
97113
else
98114
return error("Variant not supported.")
@@ -101,25 +117,75 @@ end
101117
function TermInterface.sorted_children(a::LazyNamedDimsArray)
102118
return sorted_arguments(a)
103119
end
120+
ismul(a::LazyNamedDimsArray) = ismul(unwrap(a))
121+
122+
function AbstractTrees.children(a::LazyNamedDimsArray)
123+
if !iscall(a)
124+
return ()
125+
else
126+
return arguments(a)
127+
end
128+
end
129+
function AbstractTrees.nodevalue(a::LazyNamedDimsArray)
130+
if !iscall(a)
131+
return unwrap(a)
132+
else
133+
return operation(a)
134+
end
135+
end
104136

105137
using Base.Broadcast: materialize
106138
function Base.Broadcast.materialize(a::LazyNamedDimsArray)
107139
u = unwrap(a)
108-
if u isa AbstractNamedDimsArray
140+
if !iscall(u)
109141
return u
110-
elseif u isa Mul
142+
elseif ismul(u)
111143
return mapfoldl(materialize, operation(u), arguments(u))
112144
else
113145
return error("Variant not supported.")
114146
end
115147
end
116148
Base.copy(a::LazyNamedDimsArray) = materialize(a)
117149

150+
function Base.:(==)(a1::LazyNamedDimsArray, a2::LazyNamedDimsArray)
151+
u1, u2 = unwrap.((a1, a2))
152+
if !iscall(u1) && !iscall(u2)
153+
return u1 == u2
154+
elseif ismul(u1) && ismul(u2)
155+
return arguments(u1) == arguments(u2)
156+
else
157+
return false
158+
end
159+
end
160+
161+
function printnode(io::IO, a::LazyNamedDimsArray)
162+
return printnode(io, unwrap(a))
163+
end
164+
function AbstractTrees.printnode(io::IO, a::LazyNamedDimsArray)
165+
return printnode(io, a)
166+
end
167+
function Base.show(io::IO, a::LazyNamedDimsArray)
168+
if !iscall(a)
169+
return show(io, unwrap(a))
170+
else
171+
return printnode(io, a)
172+
end
173+
end
174+
function Base.show(io::IO, mime::MIME"text/plain", a::LazyNamedDimsArray)
175+
if !iscall(a)
176+
@invoke show(io, mime, a::AbstractNamedDimsArray)
177+
return nothing
178+
else
179+
show(io, a)
180+
return nothing
181+
end
182+
end
183+
118184
function Base.:*(a::LazyNamedDimsArray)
119185
u = unwrap(a)
120-
if u isa AbstractNamedDimsArray
186+
if !iscall(u)
121187
return LazyNamedDimsArray(Mul([lazy(u)]))
122-
elseif u isa Mul
188+
elseif ismul(u)
123189
return a
124190
else
125191
return error("Variant not supported.")
@@ -191,4 +257,67 @@ function Broadcast.broadcasted(::LazyNamedDimsArrayStyle, ::typeof(-), a)
191257
return -a
192258
end
193259

260+
struct SymbolicArray{T, N, Name, Axes <: NTuple{N, AbstractUnitRange{<:Integer}}} <: AbstractArray{T, N}
261+
name::Name
262+
axes::Axes
263+
function SymbolicArray{T}(name, ax::Tuple{Vararg{AbstractUnitRange{<:Integer}}}) where {T}
264+
N = length(ax)
265+
return new{T, N, typeof(name), typeof(ax)}(name, ax)
266+
end
267+
end
268+
function SymbolicArray(name, ax::Tuple{Vararg{AbstractUnitRange{<:Integer}}})
269+
return SymbolicArray{Any}(name, ax)
270+
end
271+
function SymbolicArray{T}(name, ax::AbstractUnitRange...) where {T}
272+
return SymbolicArray{T}(name, ax)
273+
end
274+
function SymbolicArray(name, ax::AbstractUnitRange...)
275+
return SymbolicArray{Any}(name, ax)
276+
end
277+
symname(a::SymbolicArray) = getfield(a, :name)
278+
Base.axes(a::SymbolicArray) = getfield(a, :axes)
279+
Base.size(a::SymbolicArray) = length.(axes(a))
280+
function Base.:(==)(a::SymbolicArray, b::SymbolicArray)
281+
return symname(a) == symname(b) && axes(a) == axes(b)
282+
end
283+
function Base.show(io::IO, mime::MIME"text/plain", a::SymbolicArray)
284+
Base.summary(io, a)
285+
println(io, ":")
286+
print(io, repr(symname(a)))
287+
return nothing
288+
end
289+
function Base.show(io::IO, a::SymbolicArray)
290+
print(io, "SymbolicArray(", symname(a), ", ", size(a), ")")
291+
return nothing
292+
end
293+
using AbstractTrees: AbstractTrees
294+
function AbstractTrees.printnode(io::IO, a::SymbolicArray)
295+
print(io, repr(symname(a)))
296+
return nothing
297+
end
298+
const SymbolicNamedDimsArray{T, N, Parent <: SymbolicArray{T, N}, DimNames} =
299+
NamedDimsArray{T, N, Parent, DimNames}
300+
function symnameddims(name)
301+
return lazy(NamedDimsArray(SymbolicArray(name), ()))
302+
end
303+
function printnode(io::IO, a::SymbolicNamedDimsArray)
304+
print(io, symname(dename(a)))
305+
if ndims(a) > 0
306+
print(io, "[", join(dimnames(a), ","), "]")
307+
end
308+
return nothing
309+
end
310+
function Base.:(==)(a::SymbolicNamedDimsArray, b::SymbolicNamedDimsArray)
311+
return issetequal(inds(a), inds(b)) && dename(a) == dename(b)
312+
end
313+
function Base.:*(a::SymbolicNamedDimsArray, b::SymbolicNamedDimsArray)
314+
return lazy(a) * lazy(b)
315+
end
316+
function Base.:*(a::SymbolicNamedDimsArray, b::LazyNamedDimsArray)
317+
return lazy(a) * b
318+
end
319+
function Base.:*(a::LazyNamedDimsArray, b::SymbolicNamedDimsArray)
320+
return a * lazy(b)
321+
end
322+
194323
end

test/Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
[deps]
2+
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
23
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
34
Dictionaries = "85a47980-9c8c-11e8-2b9f-f7ca1fa99fb4"
45
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
@@ -13,6 +14,7 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1314
WrappedUnions = "325db55a-9c6c-5b90-b1a2-ec87e7a38c44"
1415

1516
[compat]
17+
AbstractTrees = "0.4.5"
1618
Aqua = "0.8.14"
1719
Dictionaries = "0.4.5"
1820
Graphs = "1.13.1"

test/test_lazynameddimsarrays.jl

Lines changed: 41 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1+
using AbstractTrees: AbstractTrees, print_tree, printnode
12
using Base.Broadcast: materialize
2-
using ITensorNetworksNext.LazyNamedDimsArrays: LazyNamedDimsArray, Mul, lazy
3-
using NamedDimsArrays: NamedDimsArray, inds, nameddims
3+
using ITensorNetworksNext.LazyNamedDimsArrays:
4+
LazyNamedDimsArray, Mul, SymbolicArray, ismul, lazy, symnameddims
5+
using NamedDimsArrays: NamedDimsArray, dename, dimnames, inds, nameddims
46
using TermInterface:
57
arguments,
68
arity,
@@ -33,6 +35,7 @@ using WrappedUnions: unwrap
3335
@test materialize(l) a1 * a2 * a3
3436
@test issetequal(inds(l), symdiff(inds.((a1, a2, a3))...))
3537
@test unwrap(l) isa Mul
38+
@test ismul(unwrap(l))
3639
@test unwrap(l).arguments == [l1 * l2, l3]
3740
# TermInterface.jl
3841
@test operation(unwrap(l)) *
@@ -54,6 +57,15 @@ using WrappedUnions: unwrap
5457
@test_throws ErrorException operation(l1)
5558
@test_throws ErrorException sorted_arguments(l1)
5659
@test_throws ErrorException sorted_children(l1)
60+
@test AbstractTrees.children(l1) ()
61+
@test AbstractTrees.nodevalue(l1) a1
62+
@test sprint(show, l1) == sprint(show, a1)
63+
# TODO: Fix this test, it is basically correct but the type parameters
64+
# print in a different way.
65+
# @test sprint(show, MIME"text/plain"(), l1) ==
66+
# replace(sprint(show, MIME"text/plain"(), a1), "NamedDimsArray" => "LazyNamedDimsArray")
67+
@test sprint(printnode, l1) == "[:i, :j]"
68+
@test sprint(print_tree, l1) == "[:i, :j]\n"
5769

5870
l = l1 * l2 * l3
5971
@test arguments(l) == [l1 * l2, l3]
@@ -66,5 +78,32 @@ using WrappedUnions: unwrap
6678
@test operation(l) *
6779
@test sorted_arguments(l) == [l1 * l2, l3]
6880
@test sorted_children(l) == [l1 * l2, l3]
81+
@test AbstractTrees.children(l) == [l1 * l2, l3]
82+
@test AbstractTrees.nodevalue(l) *
83+
@test sprint(show, l) == "(([:i, :j] * [:j, :k]) * [:k, :l])"
84+
@test sprint(show, MIME"text/plain"(), l) == "(([:i, :j] * [:j, :k]) * [:k, :l])"
85+
@test sprint(printnode, l) == "(([:i, :j] * [:j, :k]) * [:k, :l])"
86+
@test sprint(print_tree, l) ==
87+
"(([:i, :j] * [:j, :k]) * [:k, :l])\n├─ ([:i, :j] * [:j, :k])\n│ ├─ [:i, :j]\n│ └─ [:j, :k]\n└─ [:k, :l]\n"
88+
end
89+
90+
@testset "symnameddims" begin
91+
a = symnameddims(:a)
92+
b = symnameddims(:b)
93+
c = symnameddims(:c)
94+
@test a isa LazyNamedDimsArray
95+
@test unwrap(a) isa NamedDimsArray
96+
@test dename(a) isa SymbolicArray
97+
@test dename(unwrap(a)) isa SymbolicArray
98+
@test dename(unwrap(a)) == SymbolicArray(:a)
99+
@test inds(a) == ()
100+
@test dimnames(a) == ()
101+
102+
ex = a * b * c
103+
@test copy(ex) == ex
104+
@test arguments(ex) == [a * b, c]
105+
@test operation(ex) *
106+
@test sprint(show, ex) == "((a * b) * c)"
107+
@test sprint(show, MIME"text/plain"(), ex) == "((a * b) * c)"
69108
end
70109
end

0 commit comments

Comments
 (0)