Skip to content

Commit 0fbf838

Browse files
committed
Reorganize code
1 parent 638b94b commit 0fbf838

File tree

5 files changed

+67
-73
lines changed

5 files changed

+67
-73
lines changed

src/ITensorNetworksNext.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
module ITensorNetworksNext
22

3-
include("symbolicarrays.jl")
43
include("lazynameddimsarrays.jl")
54
include("abstracttensornetwork.jl")
65
include("tensornetwork.jl")

src/lazynameddimsarrays.jl

Lines changed: 63 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -10,32 +10,8 @@ using NamedDimsArrays:
1010
dename,
1111
dimnames,
1212
inds
13-
using ..SymbolicArrays: SymbolicArrays, SymbolicArray
1413
using TermInterface: TermInterface, arguments, iscall, maketerm, operation, sorted_arguments
1514

16-
const SymbolicNamedDimsArray{T, N, Parent <: SymbolicArray{T, N}, DimNames} =
17-
NamedDimsArray{T, N, Parent, DimNames}
18-
function symnameddims(name)
19-
return lazy(NamedDimsArray(SymbolicArray(name), ()))
20-
end
21-
function printnode(io::IO, a::SymbolicNamedDimsArray)
22-
print(io, SymbolicArrays.name(dename(a)))
23-
print(io, "[", join(dimnames(a), ","), "]")
24-
return nothing
25-
end
26-
function Base.:(==)(a::SymbolicNamedDimsArray, b::SymbolicNamedDimsArray)
27-
return issetequal(inds(a), inds(b)) && dename(a) == dename(b)
28-
end
29-
function Base.:*(a::SymbolicNamedDimsArray, b::SymbolicNamedDimsArray)
30-
return lazy(a) * lazy(b)
31-
end
32-
function Base.:*(a::SymbolicNamedDimsArray, b::LazyNamedDimsArray)
33-
return lazy(a) * b
34-
end
35-
function Base.:*(a::LazyNamedDimsArray, b::SymbolicNamedDimsArray)
36-
return a * lazy(b)
37-
end
38-
3915
# Custom version of `AbstractTrees.printnode` to
4016
# avoid type piracy when overloading on `AbstractNamedDimsArray`.
4117
printnode(io::IO, x) = AbstractTrees.printnode(io, x)
@@ -281,4 +257,67 @@ function Broadcast.broadcasted(::LazyNamedDimsArrayStyle, ::typeof(-), a)
281257
return -a
282258
end
283259

260+
struct SymbolicArray{T, N, Name, Axes <: NTuple{N, AbstractUnitRange{<:Integer}}} <: AbstractArray{T, N}
261+
name::Name
262+
axes::Axes
263+
function SymbolicArray{T}(name, ax::Tuple{Vararg{AbstractUnitRange{<:Integer}}}) where {T}
264+
N = length(ax)
265+
return new{T, N, typeof(name), typeof(ax)}(name, ax)
266+
end
267+
end
268+
function SymbolicArray(name, ax::Tuple{Vararg{AbstractUnitRange{<:Integer}}})
269+
return SymbolicArray{Any}(name, ax)
270+
end
271+
function SymbolicArray{T}(name, ax::AbstractUnitRange...) where {T}
272+
return SymbolicArray{T}(name, ax)
273+
end
274+
function SymbolicArray(name, ax::AbstractUnitRange...)
275+
return SymbolicArray{Any}(name, ax)
276+
end
277+
symname(a::SymbolicArray) = getfield(a, :name)
278+
Base.axes(a::SymbolicArray) = getfield(a, :axes)
279+
Base.size(a::SymbolicArray) = length.(axes(a))
280+
function Base.:(==)(a::SymbolicArray, b::SymbolicArray)
281+
return symname(a) == symname(b) && axes(a) == axes(b)
282+
end
283+
function Base.show(io::IO, mime::MIME"text/plain", a::SymbolicArray)
284+
Base.summary(io, a)
285+
println(io, ":")
286+
print(io, repr(symname(a)))
287+
return nothing
288+
end
289+
function Base.show(io::IO, a::SymbolicArray)
290+
print(io, "SymbolicArray(", symname(a), ", ", size(a), ")")
291+
return nothing
292+
end
293+
using AbstractTrees: AbstractTrees
294+
function AbstractTrees.printnode(io::IO, a::SymbolicArray)
295+
print(io, repr(symname(a)))
296+
return nothing
297+
end
298+
const SymbolicNamedDimsArray{T, N, Parent <: SymbolicArray{T, N}, DimNames} =
299+
NamedDimsArray{T, N, Parent, DimNames}
300+
function symnameddims(name)
301+
return lazy(NamedDimsArray(SymbolicArray(name), ()))
302+
end
303+
function printnode(io::IO, a::SymbolicNamedDimsArray)
304+
print(io, symname(dename(a)))
305+
if ndims(a) > 0
306+
print(io, "[", join(dimnames(a), ","), "]")
307+
end
308+
return nothing
309+
end
310+
function Base.:(==)(a::SymbolicNamedDimsArray, b::SymbolicNamedDimsArray)
311+
return issetequal(inds(a), inds(b)) && dename(a) == dename(b)
312+
end
313+
function Base.:*(a::SymbolicNamedDimsArray, b::SymbolicNamedDimsArray)
314+
return lazy(a) * lazy(b)
315+
end
316+
function Base.:*(a::SymbolicNamedDimsArray, b::LazyNamedDimsArray)
317+
return lazy(a) * b
318+
end
319+
function Base.:*(a::LazyNamedDimsArray, b::SymbolicNamedDimsArray)
320+
return a * lazy(b)
321+
end
322+
284323
end

src/symbolicarrays.jl

Lines changed: 0 additions & 44 deletions
This file was deleted.

test/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1414
WrappedUnions = "325db55a-9c6c-5b90-b1a2-ec87e7a38c44"
1515

1616
[compat]
17+
AbstractTrees = "0.4.5"
1718
Aqua = "0.8.14"
1819
Dictionaries = "0.4.5"
1920
Graphs = "1.13.1"

test/test_lazynameddimsarrays.jl

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
using AbstractTrees: AbstractTrees, print_tree, printnode
22
using Base.Broadcast: materialize
33
using ITensorNetworksNext.LazyNamedDimsArrays:
4-
LazyNamedDimsArray, Mul, ismul, lazy, symnameddims
5-
using ITensorNetworksNext.SymbolicArrays: SymbolicArray
4+
LazyNamedDimsArray, Mul, SymbolicArray, ismul, lazy, symnameddims
65
using NamedDimsArrays: NamedDimsArray, dename, dimnames, inds, nameddims
76
using TermInterface:
87
arguments,
@@ -104,7 +103,7 @@ using WrappedUnions: unwrap
104103
@test copy(ex) == ex
105104
@test arguments(ex) == [a * b, c]
106105
@test operation(ex) *
107-
@test sprint(show, ex) == "((a[] * b[]) * c[])"
108-
@test sprint(show, MIME"text/plain"(), ex) == "((a[] * b[]) * c[])"
106+
@test sprint(show, ex) == "((a * b) * c)"
107+
@test sprint(show, MIME"text/plain"(), ex) == "((a * b) * c)"
109108
end
110109
end

0 commit comments

Comments
 (0)