Skip to content

Commit b040316

Browse files
committed
Remove random ind names from Circuit
1 parent 051c230 commit b040316

File tree

1 file changed

+44
-13
lines changed

1 file changed

+44
-13
lines changed

perf/vqe/Circuit.jl

Lines changed: 44 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,9 @@ using YaoBlocks
55
using DelegatorTraits
66
import DelegatorTraits: DelegatorTrait
77

8-
struct Circuit <: Tangles.AbstractTensorNetwork
8+
@kwdef struct Circuit <: Tangles.AbstractTensorNetwork
99
tn::Tangles.GenericTensorNetwork
10+
last_t::Dict{Site,Int} = Dict{Site,Int}()
1011
end
1112

1213
DelegatorTrait(::Networks.Network, ::Circuit) = DelegateToField{:tn}()
@@ -15,7 +16,7 @@ DelegatorTrait(::Tangles.TensorNetwork, ::Circuit) = DelegateToField{:tn}()
1516
DelegatorTrait(::Tangles.Pluggable, ::Circuit) = DelegateToField{:tn}()
1617
DelegatorTrait(::Tangles.Lattice, ::Circuit) = DelegateToField{:tn}()
1718

18-
Base.copy(circ::Circuit) = Circuit(copy(circ.tn))
19+
Base.copy(circ::Circuit) = Circuit(copy(circ.tn), Dict{Site,Int}())
1920

2021
function flatten_circuit(x)
2122
if any(i -> i isa ChainBlock, subblocks(x))
@@ -25,12 +26,24 @@ function flatten_circuit(x)
2526
end
2627
end
2728

29+
struct LaneAt{S}
30+
site::S
31+
t::Int
32+
end
33+
34+
function Base.show(io::IO, lane::LaneAt)
35+
print(io, lane.site)
36+
return print(io, "@$(lane.t)")
37+
end
38+
39+
moment(circuit::Circuit, site::Site) = circuit.last_t[site]
40+
2841
"""
2942
Convert a Yao circuit to a Circuit.
3043
"""
3144
function Base.convert(::Type{Circuit}, yaocirc::AbstractBlock)
3245
tn = GenericTensorNetwork()
33-
circuit = Circuit(tn)
46+
circuit = Circuit(tn, Dict{Site,Int}())
3447

3548
for gate in flatten_circuit(yaocirc)
3649
# if gate isa Swap
@@ -45,7 +58,7 @@ function Base.convert(::Type{Circuit}, yaocirc::AbstractBlock)
4558
# NOTE `YaoBlocks.mat` on m-site qubits still returns the operator on the full Hilbert space
4659
m = length(occupied_locs(gate))
4760
operator = if gate isa YaoBlocks.ControlBlock
48-
control((1:(m-1))..., m => content(gate))(m)
61+
control((1:(m - 1))..., m => content(gate))(m)
4962
else
5063
content(gate)
5164
end
@@ -60,27 +73,45 @@ end
6073

6174
function Tangles.addtensor!(circuit::Circuit, tensor::Tensor)
6275
target_plugs = plugs(tensor)
76+
target_plugs_in = filter(isdual, target_plugs)
77+
target_plugs_out = filter(!isdual, target_plugs)
78+
target_sites = unique!(site.(target_plugs))
6379

64-
for plug in filter(isdual, target_plugs) .|> adjoint
80+
# if lane is not present, add an identity gate
81+
for plug in adjoint.(target_plugs_in)
6582
if !hasplug(circuit, plug)
66-
input, out = Index(gensym(:tmp)), Index(gensym(:tmp))
83+
input, out = Index(LaneAt(site(plug), 1)), Index(LaneAt(site(plug), 2))
6784
addtensor!(circuit.tn, Tensor([1 0; 0 1], [input, out]))
6885
setplug!(circuit, input, plug')
6986
setplug!(circuit, out, plug)
87+
circuit.last_t[site(plug)] = 2
7088
end
7189
end
7290

73-
tensor = replace(tensor, [Index(plug"i'") => ind_at(circuit, plug"i") for i in unique(site.(plugs(tensor)))]...)
74-
# for all the normal plugs in the operator
75-
# new_ind = Index(gensym(:tmp)) # Index((; layer=..., site=...))
76-
new_inds = Dict(plug"i" => Index(gensym(:tmp)) for i in unique(site.(plugs(tensor))))
77-
tensor = replace(tensor, [Index(k) => v for (k, v) in new_inds]...)
91+
# align gate tensor with the circuit
92+
tensor = replace(
93+
tensor,
94+
[
95+
Index(plug) => Index(LaneAt(site(plug), moment(circuit, site(plug)))) for
96+
plug in target_plugs_in
97+
]...,
98+
[
99+
Index(plug) => Index(LaneAt(site(plug), moment(circuit, site(plug)) + 1)) for
100+
plug in target_plugs_out
101+
]...,
102+
)
78103

79104
addtensor!(circuit.tn, tensor)
80105

81-
for (plug, new_ind) in new_inds
106+
# update plug tags mapping
107+
for plug in target_plugs_out
82108
unsetplug!(circuit, plug)
83-
setplug!(circuit, new_ind, plug)
109+
setplug!(circuit, Index(LaneAt(site(plug), moment(circuit, site(plug)) + 1)), plug)
110+
end
111+
112+
# update the last_t for each site
113+
for site in target_sites
114+
circuit.last_t[site] = circuit.last_t[site] + 1
84115
end
85116

86117
return circuit

0 commit comments

Comments
 (0)