Skip to content

Commit 7b14202

Browse files
committed
Stack Tensor Networks using generic_stack
1 parent b040316 commit 7b14202

File tree

2 files changed

+38
-26
lines changed

2 files changed

+38
-26
lines changed

perf/vqe/main.jl

Lines changed: 13 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,11 @@ using Unitful
1414
include("utils.jl")
1515
include("Circuit.jl")
1616

17-
# numbe of qubits
17+
# number of qubits
1818
N = 20
1919

2020
# number of layers
21-
L = 4
21+
L = 6
2222

2323
# generate parametric circuit
2424
ansatz = efficient_su2(N, L)
@@ -69,28 +69,15 @@ function expectation(params, obs, coef)
6969
U_dagger = adjoint(U)
7070
bra = adjoint(ket)
7171

72-
# rename index labels to avoid conflicts
73-
resetinds!(ket)
74-
resetinds!(U)
75-
resetinds!(obs)
76-
resetinds!(U_dagger)
77-
resetinds!(bra)
78-
79-
# align the indices
80-
@align! outputs(ket) => inputs(U)
81-
@align! outputs(U) => inputs(obs)
82-
@align! outputs(obs) => inputs(U_dagger)
83-
@align! outputs(U_dagger) => inputs(bra)
84-
85-
# construct the tensor network for the expectation value
86-
tn = GenericTensorNetwork()
87-
append!(tn, all_tensors(ket))
88-
append!(tn, all_tensors(U))
89-
append!(tn, all_tensors(obs))
90-
append!(tn, all_tensors(U_dagger))
91-
append!(tn, all_tensors(bra))
92-
93-
res = contract(tn; optimizer=LineGraph())
72+
tn = generic_stack(ket, U, obs, U_dagger, bra)
73+
74+
# print path flops and max rank to consistently check that the same contraction path is used
75+
# (exponentially big changes can be seen if not)
76+
path = einexpr(tn; optimizer=Greedy())
77+
@info "Contraction path" max_rank = maximum(ndims, Branches(path)) total_flops = mapreduce(
78+
EinExprs.flops, +, Branches(path)
79+
)
80+
res = contract(tn; path)
9481
return real(coef * res[]) # ⟨ψ|U† O U|ψ⟩
9582
end
9683

@@ -123,15 +110,15 @@ results = Vector{Tuple{String,String,T,T,Float64}}()
123110
f_xla = @compile compile_options = Reactant.DefaultXLACompileOptions(; sync=true) expectation(
124111
params_re, observable_re, coef_re
125112
)
126-
b = @benchmark f_xla($params_re, $observable_re, $coef_re) setup = (GC.gc(true))
113+
b = @benchmark $f_xla($params_re, $observable_re, $coef_re) setup = (GC.gc(true))
127114
baseline = median(b).time
128115
push!(
129116
results, ("Primal", "Only XLA", median(b).time * 1.0u"ns", std(b).time * 1.0u"ns", 1.0)
130117
)
131118

132119
## default
133120
f_default = @compile sync = true expectation(params_re, observable_re, coef_re)
134-
b = @benchmark f_default($params_re, $observable_re, $coef_re) setup = (GC.gc(true))
121+
b = @benchmark $f_default($params_re, $observable_re, $coef_re) setup = (GC.gc(true))
135122
push!(
136123
results,
137124
(

perf/vqe/utils.jl

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using YaoBlocks
2+
using Tangles
23

34
# simplified from `qiskit.circuit.library.efficient_su2`
45
function efficient_su2(nqubits, nlayers)
@@ -20,3 +21,27 @@ function efficient_su2(nqubits, nlayers)
2021

2122
return chain(nqubits, gates...)
2223
end
24+
25+
"""
    StackTag{T}

Index tag pairing an original tag `id` with the position `t` of its owning
network inside a stack. Wrapping every index label this way guarantees that
labels coming from different stacked networks can never collide.
"""
struct StackTag{T}
    id::T
    t::Int
end
29+
30+
"""
    generic_stack(tns...)

Stack the given tensor networks into a single `GenericTensorNetwork`.

Each input network is copied and every one of its indices is re-tagged with a
`StackTag` carrying the network's position in the stack, so index labels from
different networks cannot clash. Consecutive networks are then chained by
aligning each network's outputs to the next network's inputs, and all tensors
are accumulated into the returned network.

Returns the combined `GenericTensorNetwork`.
"""
function generic_stack(tns...)
    stacked = GenericTensorNetwork()
    # Work on copies so the callers' networks are left untouched.
    networks = copy.(tns)

    # Disambiguate index labels: tag each index with its network's position.
    for (position, network) in enumerate(networks)
        for ind in inds(network)
            replace!(network, ind => Index(StackTag(ind.tag, position)))
        end
    end

    # Chain the networks head-to-tail, collecting tensors as we go.
    append!(stacked, all_tensors(networks[1]))
    for k in 2:length(networks)
        @align! outputs(networks[k - 1]) => inputs(networks[k])
        append!(stacked, all_tensors(networks[k]))
    end

    return stacked
end

0 commit comments

Comments
 (0)