Skip to content

Commit 35345b7

Browse files
authored
Some code refactoring (#24)
1 parent 929b62d commit 35345b7

File tree

6 files changed

+98
-64
lines changed

6 files changed

+98
-64
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ITensorNetworksNext"
22
uuid = "302f2e75-49f0-4526-aef7-d8ba550cb06c"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.2.0"
4+
version = "0.2.1"
55

66
[deps]
77
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"

src/LazyNamedDimsArrays/lazyinterface.jl

Lines changed: 52 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,9 @@ function walk(opmap, argmap, ex)
1010
if !iscall(ex)
1111
return argmap(ex)
1212
else
13-
return mapfoldl((args...) -> walk(opmap, argmap, args...), opmap(operation(ex)), arguments(ex))
13+
return mapfoldl(opmap(operation(ex)), arguments(ex)) do (args...)
14+
return walk(opmap, argmap, args...)
15+
end
1416
end
1517
end
1618
# Walk the expression `ex`, modifying the
@@ -21,13 +23,22 @@ opwalk(opmap, a) = walk(opmap, identity, a)
2123
argwalk(argmap, a) = walk(identity, argmap, a)
2224

2325
# Generic lazy functionality.
26+
using DerivableInterfaces: AbstractArrayInterface, InterfaceFunction
27+
struct LazyInterface{N} <: AbstractArrayInterface{N} end
28+
LazyInterface() = LazyInterface{Any}()
29+
LazyInterface(::Val{N}) where {N} = LazyInterface{N}()
30+
LazyInterface{M}(::Val{N}) where {M, N} = LazyInterface{N}()
31+
const lazy_interface = LazyInterface()
32+
33+
const maketerm_lazy = lazy_interface(maketerm)
2434
function maketerm_lazy(type::Type, head, args, metadata)
2535
if head *
2636
return type(maketerm(Mul, head, args, metadata))
2737
else
2838
return error("Only mul supported right now.")
2939
end
3040
end
41+
const getindex_lazy = lazy_interface(getindex)
3142
function getindex_lazy(a::AbstractArray, I...)
3243
u = unwrap(a)
3344
if !iscall(u)
@@ -36,6 +47,7 @@ function getindex_lazy(a::AbstractArray, I...)
3647
return error("Indexing into expression not supported.")
3748
end
3849
end
50+
const arguments_lazy = lazy_interface(arguments)
3951
function arguments_lazy(a)
4052
u = unwrap(a)
4153
if !iscall(u)
@@ -46,18 +58,18 @@ function arguments_lazy(a)
4658
return error("Variant not supported.")
4759
end
4860
end
49-
function children_lazy(a)
50-
return arguments(a)
51-
end
52-
function head_lazy(a)
53-
return operation(a)
54-
end
55-
function iscall_lazy(a)
56-
return iscall(unwrap(a))
57-
end
58-
function isexpr_lazy(a)
59-
return iscall(a)
60-
end
61+
using TermInterface: children
62+
const children_lazy = lazy_interface(children)
63+
children_lazy(a) = arguments(a)
64+
using TermInterface: head
65+
const head_lazy = lazy_interface(head)
66+
head_lazy(a) = operation(a)
67+
const iscall_lazy = lazy_interface(iscall)
68+
iscall_lazy(a) = iscall(unwrap(a))
69+
using TermInterface: isexpr
70+
const isexpr_lazy = lazy_interface(isexpr)
71+
isexpr_lazy(a) = iscall(a)
72+
const operation_lazy = lazy_interface(operation)
6173
function operation_lazy(a)
6274
u = unwrap(a)
6375
if !iscall(u)
@@ -68,6 +80,7 @@ function operation_lazy(a)
6880
return error("Variant not supported.")
6981
end
7082
end
83+
const sorted_arguments_lazy = lazy_interface(sorted_arguments)
7184
function sorted_arguments_lazy(a)
7285
u = unwrap(a)
7386
if !iscall(u)
@@ -78,27 +91,35 @@ function sorted_arguments_lazy(a)
7891
return error("Variant not supported.")
7992
end
8093
end
81-
function sorted_children_lazy(a)
82-
return sorted_arguments(a)
83-
end
94+
using TermInterface: sorted_children
95+
const sorted_children_lazy = lazy_interface(sorted_children)
96+
sorted_children_lazy(a) = sorted_arguments(a)
97+
const ismul_lazy = lazy_interface(ismul)
8498
ismul_lazy(a) = ismul(unwrap(a))
99+
using AbstractTrees: AbstractTrees
100+
const abstracttrees_children_lazy = lazy_interface(AbstractTrees.children)
85101
function abstracttrees_children_lazy(a)
86102
if !iscall(a)
87103
return ()
88104
else
89105
return arguments(a)
90106
end
91107
end
108+
using AbstractTrees: nodevalue
109+
const nodevalue_lazy = lazy_interface(nodevalue)
92110
function nodevalue_lazy(a)
93111
if !iscall(a)
94112
return unwrap(a)
95113
else
96114
return operation(a)
97115
end
98116
end
99-
materialize_lazy(a) = argwalk(unwrap, a)
100117
using Base.Broadcast: materialize
118+
const materialize_lazy = lazy_interface(materialize)
119+
materialize_lazy(a) = argwalk(unwrap, a)
120+
const copy_lazy = lazy_interface(copy)
101121
copy_lazy(a) = materialize(a)
122+
const equals_lazy = lazy_interface(==)
102123
function equals_lazy(a1, a2)
103124
u1, u2 = unwrap.((a1, a2))
104125
if !iscall(u1) && !iscall(u2)
@@ -109,6 +130,7 @@ function equals_lazy(a1, a2)
109130
return false
110131
end
111132
end
133+
const isequal_lazy = lazy_interface(isequal)
112134
function isequal_lazy(a1, a2)
113135
u1, u2 = unwrap.((a1, a2))
114136
if !iscall(u1) && !iscall(u2)
@@ -119,11 +141,13 @@ function isequal_lazy(a1, a2)
119141
return false
120142
end
121143
end
144+
const hash_lazy = lazy_interface(hash)
122145
function hash_lazy(a, h::UInt64)
123146
h = hash(Symbol(unspecify_type_parameters(typeof(a))), h)
124147
# Use `_hash`, which defines a custom hash for NamedDimsArray.
125148
return _hash(unwrap(a), h)
126149
end
150+
const map_arguments_lazy = lazy_interface(map_arguments)
127151
function map_arguments_lazy(f, a)
128152
u = unwrap(a)
129153
if !iscall(u)
@@ -134,19 +158,22 @@ function map_arguments_lazy(f, a)
134158
return error("Variant not supported.")
135159
end
136160
end
161+
function substitute end
162+
const substitute_lazy = lazy_interface(substitute)
137163
function substitute_lazy(a, substitutions::AbstractDict)
138164
haskey(substitutions, a) && return substitutions[a]
139165
!iscall(a) && return a
140166
return map_arguments(arg -> substitute(arg, substitutions), a)
141167
end
142-
function substitute_lazy(a, substitutions)
143-
return substitute(a, Dict(substitutions))
144-
end
168+
substitute_lazy(a, substitutions) = substitute(a, Dict(substitutions))
169+
using AbstractTrees: printnode
170+
const printnode_lazy = lazy_interface(printnode)
145171
function printnode_lazy(io, a)
146172
# Use `printnode_nameddims` to avoid type piracy,
147173
# since it overloads on `AbstractNamedDimsArray`.
148174
return printnode_nameddims(io, unwrap(a))
149175
end
176+
const show_lazy = lazy_interface(show)
150177
function show_lazy(io::IO, a)
151178
if !iscall(a)
152179
return show(io, unwrap(a))
@@ -160,9 +187,12 @@ function show_lazy(io::IO, mime::MIME"text/plain", a)
160187
!iscall(a) ? show(io, mime, unwrap(a)) : show(io, a)
161188
return nothing
162189
end
190+
const add_lazy = lazy_interface(+)
163191
add_lazy(a1, a2) = error("Not implemented.")
192+
const sub_lazy = lazy_interface(-)
164193
sub_lazy(a) = error("Not implemented.")
165194
sub_lazy(a1, a2) = error("Not implemented.")
195+
const mul_lazy = lazy_interface(*)
166196
function mul_lazy(a)
167197
u = unwrap(a)
168198
if !iscall(u)
@@ -186,6 +216,7 @@ mul_lazy(a1::Number, a2::Number) = a1 * a2
186216
div_lazy(a1, a2::Number) = error("Not implemented.")
187217

188218
# NamedDimsArrays.jl interface.
219+
const inds_lazy = lazy_interface(inds)
189220
function inds_lazy(a)
190221
u = unwrap(a)
191222
if !iscall(u)
@@ -196,6 +227,7 @@ function inds_lazy(a)
196227
return error("Variant not supported.")
197228
end
198229
end
230+
const dename_lazy = lazy_interface(dename)
199231
function dename_lazy(a)
200232
u = unwrap(a)
201233
if !iscall(u)

src/TensorNetworkGenerators/ising_network.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using DiagonalArrays: DiagonalArray
22
using Graphs: degree, dst, edges, src
3+
using ..ITensorNetworksNext: @preserve_graph
34
using LinearAlgebra: Diagonal, eigen
45
using NamedDimsArrays: apply, dename, inds, operator, randname
56
using NamedGraphs.GraphsExtensions: vertextype
@@ -42,8 +43,8 @@ function ising_network(
4243
deg2 = degree(tn, v2)
4344
m = sqrt_ising_bond(β; J, h, deg1, deg2)
4445
t = operator(m, ((e),), (f(e),))
45-
tn[v1] = apply(t, tn[v1])
46-
tn[v2] = apply(t, tn[v2])
46+
@preserve_graph tn[v1] = apply(t, tn[v1])
47+
@preserve_graph tn[v2] = apply(t, tn[v2])
4748
end
4849
return tn
4950
end

src/contract_network.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,8 @@ function contraction_order(tn; alg = default_kwargs(contraction_order, tn).alg,
5656
end
5757
# Convert the tensor network to a flat symbolic multiplication expression.
5858
function contraction_order(alg::Algorithm"flat", tn)
59-
syms = [symnameddims(i, Tuple(inds(tn[i]))) for i in keys(tn)]
6059
# Same as: `reduce((a, b) -> *(a, b; flatten = true), syms)`.
60+
syms = vec([symnameddims(i, Tuple(inds(tn[i]))) for i in keys(tn)])
6161
return lazy(Mul(syms))
6262
end
6363
function contraction_order(alg::Algorithm"left_associative", tn)

test/test_tensornetworkgenerators.jl

Lines changed: 1 addition & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -8,46 +8,7 @@ using NamedGraphs.GraphsExtensions: arranged_edges, incident_edges
88
using NamedGraphs.NamedGraphGenerators: named_grid
99
using Test: @test, @testset
1010

11-
module TestUtils
12-
using QuadGK: quadgk
13-
# Exact critical inverse temperature for 2D square lattice Ising model.
14-
βc_2d_ising(elt::Type{<:Number} = Float64) = elt(log(1 + 2) / 2)
15-
# Exact infinite volume free energy density for 2D square lattice Ising model.
16-
function f_2d_ising::Real; J::Real = one(β))
17-
κ = 2sinh(2β * J) / cosh(2β * J)^2
18-
integrand(θ) = log((1 + (abs(1 -* sin(θ))^2))) / 2)
19-
integral, _ = quadgk(integrand, 0, π)
20-
return (-log(2cosh(2β * J)) - (1 / (2π)) * integral) / β
21-
end
22-
function f_1d_ising::Real; J::Real = one(β), h::Real = zero(β))
23-
λ⁺ = exp* J) * (cosh* h) + (sinh* h)^2 + exp(-4β * J)))
24-
return -(log(λ⁺) / β)
25-
end
26-
function f_1d_ising::Real, N::Integer; periodic::Bool = true, kwargs...)
27-
return if periodic
28-
f_1d_ising_periodic(β, N; kwargs...)
29-
else
30-
f_1d_ising_open(β, N; kwargs...)
31-
end
32-
end
33-
function f_1d_ising_periodic::Real, N::Integer; J::Real = one(β), h::Real = zero(β))
34-
r = (sinh* h)^2 + exp(-4β * J))
35-
λ⁺ = exp* J) * (cosh* h) + r)
36-
λ⁻ = exp* J) * (cosh* h) - r)
37-
Z = λ⁺^N + λ⁻^N
38-
return -(log(Z) /* N))
39-
end
40-
function f_1d_ising_open::Real, N::Integer; J::Real = one(β), h::Real = zero(β))
41-
isone(N) && return 2cosh* h)
42-
T = [
43-
exp* (J + h)) exp(-β * J);
44-
exp(-β * J) exp* (J - h));
45-
]
46-
b = [exp* h / 2), exp(-β * h / 2)]
47-
Z = (b' * (T^(N - 1)) * b)[]
48-
return -(log(Z) /* N))
49-
end
50-
end
11+
!@isdefined(TestUtils) && include("utils.jl")
5112

5213
@testset "TensorNetworkGenerators" begin
5314
@testset "Delta Network" begin

test/utils.jl

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
module TestUtils
2+
using QuadGK: quadgk
3+
# Exact critical inverse temperature for 2D square lattice Ising model.
4+
βc_2d_ising(elt::Type{<:Number} = Float64) = elt(log(1 + 2) / 2)
5+
# Exact infinite volume free energy density for 2D square lattice Ising model.
6+
function f_2d_ising::Real; J::Real = one(β))
7+
κ = 2sinh(2β * J) / cosh(2β * J)^2
8+
integrand(θ) = log((1 + (abs(1 -* sin(θ))^2))) / 2)
9+
integral, _ = quadgk(integrand, 0, π)
10+
return (-log(2cosh(2β * J)) - (1 / (2π)) * integral) / β
11+
end
12+
function f_1d_ising::Real; J::Real = one(β), h::Real = zero(β))
13+
λ⁺ = exp* J) * (cosh* h) + (sinh* h)^2 + exp(-4β * J)))
14+
return -(log(λ⁺) / β)
15+
end
16+
function f_1d_ising::Real, N::Integer; periodic::Bool = true, kwargs...)
17+
return if periodic
18+
f_1d_ising_periodic(β, N; kwargs...)
19+
else
20+
f_1d_ising_open(β, N; kwargs...)
21+
end
22+
end
23+
function f_1d_ising_periodic::Real, N::Integer; J::Real = one(β), h::Real = zero(β))
24+
r = (sinh* h)^2 + exp(-4β * J))
25+
λ⁺ = exp* J) * (cosh* h) + r)
26+
λ⁻ = exp* J) * (cosh* h) - r)
27+
Z = λ⁺^N + λ⁻^N
28+
return -(log(Z) /* N))
29+
end
30+
function f_1d_ising_open::Real, N::Integer; J::Real = one(β), h::Real = zero(β))
31+
isone(N) && return 2cosh* h)
32+
T = [
33+
exp* (J + h)) exp(-β * J);
34+
exp(-β * J) exp* (J - h));
35+
]
36+
b = [exp* h / 2), exp(-β * h / 2)]
37+
Z = (b' * (T^(N - 1)) * b)[]
38+
return -(log(Z) /* N))
39+
end
40+
end

0 commit comments

Comments
 (0)