Skip to content

Commit 399e9e3

Browse files
committed
Remove cirq tags in parallelize and deal with tags when lowering to squin (#489)
Closes #488 and #375 .
1 parent 009a671 commit 399e9e3

File tree

5 files changed

+23
-9
lines changed

5 files changed

+23
-9
lines changed

src/bloqade/cirq_utils/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from .parallelize import (
33
transpile as transpile,
44
parallelize as parallelize,
5-
no_similarity as no_similarity,
5+
remove_tags as remove_tags,
66
auto_similarity as auto_similarity,
77
block_similarity as block_similarity,
88
moment_similarity as moment_similarity,

src/bloqade/cirq_utils/parallelize.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ def auto_similarity(
136136
return cirq.Circuit(flattened_circuit), weights
137137

138138

139-
def no_similarity(circuit: cirq.Circuit) -> cirq.Circuit:
139+
def remove_tags(circuit: cirq.Circuit) -> cirq.Circuit:
140140
"""
141141
Removes all tags from the circuit
142142
@@ -146,10 +146,11 @@ def no_similarity(circuit: cirq.Circuit) -> cirq.Circuit:
146146
Returns:
147147
[0] - cirq.Circuit - the circuit with all tags removed.
148148
"""
149-
new_moments = []
150-
for moment in circuit.moments:
151-
new_moments.append([gate.untagged for gate in moment.operations])
152-
return cirq.Circuit(new_moments)
149+
150+
def remove_tag(op: cirq.Operation, _):
151+
return op.untagged
152+
153+
return cirq.map_operations(circuit, remove_tag)
153154

154155

155156
def to_dag_circuit(circuit: cirq.Circuit, can_reorder=None) -> nx.DiGraph:
@@ -399,4 +400,6 @@ def parallelize(
399400
)
400401
# Convert the epochs to a cirq circuit.
401402
moments = map(cirq.Moment, epochs)
402-
return cirq.Circuit(moments)
403+
circuit = cirq.Circuit(moments)
404+
405+
return remove_tags(circuit)

src/bloqade/squin/cirq/lowering.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,11 @@ def visit_GateOperation(
144144
qbits = self.lower_qubit_getindices(state, node.qubits)
145145
return state.current_frame.push(qubit.Apply(operator=op_, qubits=qbits))
146146

147+
def visit_TaggedOperation(
148+
self, state: lowering.State[CirqNode], node: cirq.TaggedOperation
149+
):
150+
state.lower(node.untagged)
151+
147152
def lower_measurement(
148153
self, state: lowering.State[CirqNode], node: cirq.GateOperation
149154
):

test/cirq_utils/test_parallelize.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from bloqade.cirq_utils import (
88
parallelize,
9-
no_similarity,
9+
remove_tags,
1010
block_similarity,
1111
moment_similarity,
1212
)
@@ -28,7 +28,7 @@ def test1():
2828
circuit_m, _ = moment_similarity(circuit, weight=1.0)
2929
# print(circuit_m)
3030
circuit_b, _ = block_similarity(circuit, weight=1.0, block_id=1)
31-
circuit_m2 = no_similarity(circuit_m)
31+
circuit_m2 = remove_tags(circuit_m)
3232
print(circuit_m2)
3333
circuit2 = parallelize(circuit)
3434
# print(circuit2)

test/squin/cirq/test_cirq_to_squin.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,11 @@ def basic_circuit():
3434
)
3535

3636

37+
def tagged_circuit():
38+
q = cirq.LineQubit.range(2)
39+
return cirq.Circuit(cirq.H(q[0]).with_tags("FOO"), cirq.CX(*q).with_tags("BAR"))
40+
41+
3742
def controlled_gates():
3843
q0 = cirq.NamedQubit("q0")
3944
q1 = cirq.NamedQubit("q1")
@@ -166,6 +171,7 @@ def nested_circuit():
166171
"circuit_f",
167172
[
168173
basic_circuit,
174+
tagged_circuit,
169175
controlled_gates,
170176
parity_gate_circuit,
171177
phased_gates,

0 commit comments

Comments
 (0)