Skip to content

Commit 96a9ebf

Browse files
committed
More reorg
1 parent 2716c18 commit 96a9ebf

File tree

2 files changed

+54
-34
lines changed

2 files changed

+54
-34
lines changed

src/lazynameddimsarrays.jl

Lines changed: 45 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,8 @@ end
3737

3838
# Custom version of `AbstractTrees.printnode` to
3939
# avoid type piracy when overloading on `AbstractNamedDimsArray`.
40-
printnode(io::IO, x) = AbstractTrees.printnode(io, x)
41-
function printnode(io::IO, a::AbstractNamedDimsArray)
40+
printnode_nameddims(io::IO, x) = AbstractTrees.printnode(io, x)
41+
function printnode_nameddims(io::IO, a::AbstractNamedDimsArray)
4242
show(io, collect(dimnames(a)))
4343
return nothing
4444
end
@@ -149,6 +149,37 @@ function map_arguments_lazy(f, a)
149149
return error("Variant not supported.")
150150
end
151151
end
152+
function substitute_lazy(a, substitutions::AbstractDict)
153+
haskey(substitutions, a) && return substitutions[a]
154+
!iscall(a) && return a
155+
return map_arguments(arg -> substitute(arg, substitutions), a)
156+
end
157+
function substitute_lazy(a, substitutions)
158+
return substitute(a, Dict(substitutions))
159+
end
160+
function printnode_lazy(io, a)
161+
# Use `printnode_nameddims` to avoid type piracy,
162+
# since it overloads on `AbstractNamedDimsArray`.
163+
return printnode_nameddims(io, unwrap(a))
164+
end
165+
function show_lazy(io::IO, a)
166+
if !iscall(a)
167+
return show(io, unwrap(a))
168+
else
169+
return AbstractTrees.printnode(io, a)
170+
end
171+
end
172+
function show_lazy(io::IO, mime::MIME"text/plain", a)
173+
summary(io, a)
174+
println(io, ":")
175+
if !iscall(a)
176+
show(io, mime, unwrap(a))
177+
return nothing
178+
else
179+
show(io, a)
180+
return nothing
181+
end
182+
end
152183

153184
# Lazy broadcasting.
154185
struct LazyNamedDimsArrayStyle <: AbstractNamedDimsArrayStyle{Any} end
@@ -185,7 +216,7 @@ TermInterface.head(m::Applied) = operation(m)
185216
TermInterface.iscall(m::Applied) = true
186217
TermInterface.isexpr(m::Applied) = iscall(m)
187218
function Base.show(io::IO, m::Applied)
188-
args = map(arg -> sprint(printnode, arg), arguments(m))
219+
args = map(arg -> sprint(AbstractTrees.printnode, arg), arguments(m))
189220
print(io, "(", join(args, " $(operation(m)) "), ")")
190221
return nothing
191222
end
@@ -301,37 +332,20 @@ end
301332
function map_arguments(f, a::LazyNamedDimsArray)
302333
return map_arguments_lazy(f, a)
303334
end
304-
305-
function substitute(a::LazyNamedDimsArray, substitutions::AbstractDict)
306-
haskey(substitutions, a) && return substitutions[a]
307-
!iscall(a) && return a
308-
return map_arguments(arg -> substitute(arg, substitutions), a)
309-
end
310335
function substitute(a::LazyNamedDimsArray, substitutions)
311-
return substitute(a, Dict(substitutions))
312-
end
313-
314-
function printnode(io::IO, a::LazyNamedDimsArray)
315-
return printnode(io, unwrap(a))
336+
return substitute_lazy(a, substitutions)
316337
end
317338
function AbstractTrees.printnode(io::IO, a::LazyNamedDimsArray)
318-
return printnode(io, a)
339+
return printnode_lazy(io, a)
340+
end
341+
function printnode_nameddims(io::IO, a::LazyNamedDimsArray)
342+
return printnode_lazy(io, a)
319343
end
320344
function Base.show(io::IO, a::LazyNamedDimsArray)
321-
if !iscall(a)
322-
return show(io, unwrap(a))
323-
else
324-
return printnode(io, a)
325-
end
345+
return show_lazy(io, a)
326346
end
327347
function Base.show(io::IO, mime::MIME"text/plain", a::LazyNamedDimsArray)
328-
if !iscall(a)
329-
@invoke show(io, mime, a::AbstractNamedDimsArray)
330-
return nothing
331-
else
332-
show(io, a)
333-
return nothing
334-
end
348+
return show_lazy(io, mime, a)
335349
end
336350

337351
function Base.:*(a::LazyNamedDimsArray)
@@ -427,13 +441,16 @@ const SymbolicNamedDimsArray{T, N, Parent <: SymbolicArray{T, N}, DimNames} =
427441
function symnameddims(name)
428442
return lazy(NamedDimsArray(SymbolicArray(name), ()))
429443
end
430-
function printnode(io::IO, a::SymbolicNamedDimsArray)
444+
function AbstractTrees.printnode(io::IO, a::SymbolicNamedDimsArray)
431445
print(io, symname(dename(a)))
432446
if ndims(a) > 0
433447
print(io, "[", join(dimnames(a), ","), "]")
434448
end
435449
return nothing
436450
end
451+
function printnode_nameddims(io::IO, a::SymbolicNamedDimsArray)
452+
return AbstractTrees.printnode(io, a)
453+
end
437454
function Base.:(==)(a::SymbolicNamedDimsArray, b::SymbolicNamedDimsArray)
438455
return issetequal(inds(a), inds(b)) && dename(a) == dename(b)
439456
end

test/test_lazynameddimsarrays.jl

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ using AbstractTrees: AbstractTrees, print_tree, printnode
22
using Base.Broadcast: materialize
33
using ITensorNetworksNext.LazyNamedDimsArrays:
44
LazyNamedDimsArray, Mul, SymbolicArray, ismul, lazy, substitute, symnameddims
5-
using NamedDimsArrays: NamedDimsArray, @names, dename, dimnames, inds, nameddims
5+
using NamedDimsArrays: NamedDimsArray, @names, dename, dimnames, inds, nameddims, namedoneto
66
using TermInterface:
77
arguments,
88
arity,
@@ -19,9 +19,10 @@ using WrappedUnions: unwrap
1919

2020
@testset "LazyNamedDimsArrays" begin
2121
@testset "Basics" begin
22-
a1 = nameddims(randn(2, 2), (:i, :j))
23-
a2 = nameddims(randn(2, 2), (:j, :k))
24-
a3 = nameddims(randn(2, 2), (:k, :l))
22+
i, j, k, l = namedoneto.(2, (:i, :j, :k, :l))
23+
a1 = randn(i, j)
24+
a2 = randn(j, k)
25+
a3 = randn(k, l)
2526
l1, l2, l3 = lazy.((a1, a2, a3))
2627
for li in (l1, l2, l3)
2728
@test li isa LazyNamedDimsArray
@@ -81,7 +82,8 @@ using WrappedUnions: unwrap
8182
@test AbstractTrees.children(l) == [l1 * l2, l3]
8283
@test AbstractTrees.nodevalue(l) *
8384
@test sprint(show, l) == "(([:i, :j] * [:j, :k]) * [:k, :l])"
84-
@test sprint(show, MIME"text/plain"(), l) == "(([:i, :j] * [:j, :k]) * [:k, :l])"
85+
@test sprint(show, MIME"text/plain"(), l) ==
86+
"named(Base.OneTo(2), :i)×named(Base.OneTo(2), :l) LazyNamedDimsArray{Float64, …}:\n(([:i, :j] * [:j, :k]) * [:k, :l])"
8587
@test sprint(printnode, l) == "(([:i, :j] * [:j, :k]) * [:k, :l])"
8688
@test sprint(print_tree, l) ==
8789
"(([:i, :j] * [:j, :k]) * [:k, :l])\n├─ ([:i, :j] * [:j, :k])\n│ ├─ [:i, :j]\n│ └─ [:j, :k]\n└─ [:k, :l]\n"
@@ -102,7 +104,8 @@ using WrappedUnions: unwrap
102104
@test arguments(ex) == [a1 * a2, a3]
103105
@test operation(ex) *
104106
@test sprint(show, ex) == "((a1 * a2) * a3)"
105-
@test sprint(show, MIME"text/plain"(), ex) == "((a1 * a2) * a3)"
107+
@test sprint(show, MIME"text/plain"(), ex) ==
108+
"0-dimensional LazyNamedDimsArray{Any, …}:\n((a1 * a2) * a3)"
106109
end
107110

108111
@testset "substitute" begin

0 commit comments

Comments
 (0)