Skip to content

Commit a7e368b

Browse files
committed
Define substitute
1 parent 0c19c8b commit a7e368b

File tree

3 files changed

+102
-16
lines changed

3 files changed

+102
-16
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ITensorNetworksNext"
22
uuid = "302f2e75-49f0-4526-aef7-d8ba550cb06c"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.1.8"
4+
version = "0.1.9"
55

66
[deps]
77
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"

src/lazynameddimsarrays.jl

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,16 @@ function Base.show(io::IO, m::Mul)
3939
print(io, "(", join(args, " $(operation(m)) "), ")")
4040
return nothing
4141
end
42+
function Base.hash(m::Mul, h::UInt64)
43+
h = hash(:Mul, h)
44+
for arg in arguments(m)
45+
h = hash(arg, h)
46+
end
47+
return h
48+
end
49+
function map_arguments(f, m::Mul)
50+
return Mul(map(f, arguments(m)))
51+
end
4252

4353
@wrapped struct LazyNamedDimsArray{
4454
T, A <: AbstractNamedDimsArray{T},
@@ -65,6 +75,18 @@ function NamedDimsArrays.dename(a::LazyNamedDimsArray)
6575
end
6676
end
6777

78+
function getindex_lazy(a::AbstractArray, I...)
79+
u = unwrap(a)
80+
if !iscall(u)
81+
return u[I...]
82+
else
83+
return error("Indexing into expression not supported.")
84+
end
85+
end
86+
function Base.getindex(a::LazyNamedDimsArray, I::Int...)
87+
return getindex_lazy(a, I...)
88+
end
89+
6890
function TermInterface.arguments(a::LazyNamedDimsArray)
6991
u = unwrap(a)
7092
if !iscall(u)
@@ -158,6 +180,49 @@ function Base.:(==)(a1::LazyNamedDimsArray, a2::LazyNamedDimsArray)
158180
end
159181
end
160182

183+
# Defined to avoid type piracy.
184+
# TODO: Define a proper hash function
185+
# in NamedDimsArrays.jl, maybe one that is
186+
# independent of the order of dimensions.
187+
function _hash(a::NamedDimsArray, h::UInt64)
188+
h = hash(:NamedDimsArray, h)
189+
h = hash(dename(a), h)
190+
for i in inds(a)
191+
h = hash(i, h)
192+
end
193+
return h
194+
end
195+
function _hash(x, h::UInt64)
196+
return hash(x, h)
197+
end
198+
function Base.hash(a::LazyNamedDimsArray, h::UInt64)
199+
h = hash(:LazyNamedDimsArray, h)
200+
# Use `_hash`, which defines a custom hash for NamedDimsArray.
201+
return _hash(unwrap(a), h)
202+
end
203+
204+
generic_map(f, v) = map(f, v)
205+
generic_map(f, v::AbstractDict) = Dict(eachindex(v) .=> map(f, values(v)))
206+
generic_map(f, v::AbstractSet) = Set([f(x) for x in v])
207+
function map_arguments(f, a::LazyNamedDimsArray)
208+
u = unwrap(a)
209+
if !iscall(u)
210+
return error("No arguments to map.")
211+
elseif ismul(u)
212+
return LazyNamedDimsArray(map_arguments(f, u))
213+
else
214+
return error("Variant not supported.")
215+
end
216+
end
217+
function substitute(a::LazyNamedDimsArray, substitutions::AbstractDict)
218+
haskey(substitutions, a) && return substitutions[a]
219+
!iscall(a) && return a
220+
return map_arguments(arg -> substitute(arg, substitutions), a)
221+
end
222+
function substitute(a::LazyNamedDimsArray, substitutions)
223+
return substitute(a, Dict(substitutions))
224+
end
225+
161226
function printnode(io::IO, a::LazyNamedDimsArray)
162227
return printnode(io, unwrap(a))
163228
end
@@ -280,6 +345,17 @@ Base.size(a::SymbolicArray) = length.(axes(a))
280345
function Base.:(==)(a::SymbolicArray, b::SymbolicArray)
281346
return symname(a) == symname(b) && axes(a) == axes(b)
282347
end
348+
function Base.hash(a::SymbolicArray, h::UInt64)
349+
h = hash(:SymbolicArray, h)
350+
h = hash(symname(a), h)
351+
return hash(size(a), h)
352+
end
353+
function Base.getindex(a::SymbolicArray, I...)
354+
return error("Indexing into SymbolicArray not supported.")
355+
end
356+
function Base.setindex!(a::SymbolicArray, value, I...)
357+
return error("Indexing into SymbolicArray not supported.")
358+
end
283359
function Base.show(io::IO, mime::MIME"text/plain", a::SymbolicArray)
284360
Base.summary(io, a)
285361
println(io, ":")

test/test_lazynameddimsarrays.jl

Lines changed: 25 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
using AbstractTrees: AbstractTrees, print_tree, printnode
22
using Base.Broadcast: materialize
33
using ITensorNetworksNext.LazyNamedDimsArrays:
4-
LazyNamedDimsArray, Mul, SymbolicArray, ismul, lazy, symnameddims
5-
using NamedDimsArrays: NamedDimsArray, dename, dimnames, inds, nameddims
4+
LazyNamedDimsArray, Mul, SymbolicArray, ismul, lazy, substitute, symnameddims
5+
using NamedDimsArrays: NamedDimsArray, @names, dename, dimnames, inds, nameddims
66
using TermInterface:
77
arguments,
88
arity,
@@ -88,22 +88,32 @@ using WrappedUnions: unwrap
8888
end
8989

9090
@testset "symnameddims" begin
91-
a = symnameddims(:a)
92-
b = symnameddims(:b)
93-
c = symnameddims(:c)
94-
@test a isa LazyNamedDimsArray
95-
@test unwrap(a) isa NamedDimsArray
96-
@test dename(a) isa SymbolicArray
97-
@test dename(unwrap(a)) isa SymbolicArray
98-
@test dename(unwrap(a)) == SymbolicArray(:a)
91+
a1, a2, a3 = symnameddims.((:a1, :a2, :a3))
92+
@test a1 isa LazyNamedDimsArray
93+
@test unwrap(a1) isa NamedDimsArray
94+
@test dename(a1) isa SymbolicArray
95+
@test dename(unwrap(a1)) isa SymbolicArray
96+
@test dename(unwrap(a1)) == SymbolicArray(:a1)
9997
@test inds(a) == ()
100-
@test dimnames(a) == ()
98+
@test dimnames(a1) == ()
10199

102-
ex = a * b * c
100+
ex = a1 * a2 * a3
103101
@test copy(ex) == ex
104-
@test arguments(ex) == [a * b, c]
102+
@test arguments(ex) == [a1 * a2, a3]
105103
@test operation(ex) *
106-
@test sprint(show, ex) == "((a * b) * c)"
107-
@test sprint(show, MIME"text/plain"(), ex) == "((a * b) * c)"
104+
@test sprint(show, ex) == "((a1 * a2) * a3)"
105+
@test sprint(show, MIME"text/plain"(), ex) == "((a1 * a2) * a3)"
106+
end
107+
108+
@testset "substitute" begin
109+
s = symnameddims.((:a1, :a2, :a3))
110+
i = @names i[1:4]
111+
a = (randn(2, 2)[i[1], i[2]], randn(2, 2)[i[2], i[3]], randn(2, 2)[i[3], i[4]])
112+
l = lazy.(a)
113+
114+
seq = s[1] * (s[2] * s[3])
115+
net = substitute(seq, s .=> l)
116+
@test net == l[1] * (l[2] * l[3])
117+
@test arguments(net) == [l[1], l[2] * l[3]]
108118
end
109119
end

0 commit comments

Comments
 (0)