Skip to content

Commit 514b695

Browse files
authored
feature: add barrier instruction (#1118)
1 parent b98a75e commit 514b695

File tree

11 files changed

+556
-101
lines changed

11 files changed

+556
-101
lines changed

src/braket/circuits/circuit.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -729,6 +729,29 @@ def add_verbatim_box(
729729
self._has_compiler_directives = True
730730
return self
731731

732+
def barrier(self, target: QubitSetInput | None = None) -> Circuit:
733+
"""Add a barrier compiler directive to the circuit.
734+
735+
Args:
736+
target (QubitSetInput | None): Target qubits for the barrier.
737+
If None, applies to all qubits in the circuit.
738+
739+
Returns:
740+
Circuit: self
741+
742+
Examples:
743+
>>> circ = Circuit().h(0).barrier([0, 1]).cnot(0, 1)
744+
>>> circ = Circuit().h(0).h(1).barrier() # barrier on all qubits
745+
"""
746+
target_qubits = self.qubits if target is None else QubitSet(target)
747+
748+
if target_qubits:
749+
self.add_instruction(
750+
Instruction(compiler_directives.Barrier(list(target_qubits)), target=target_qubits)
751+
)
752+
self._has_compiler_directives = True
753+
return self
754+
732755
def _add_measure(self, target_qubits: QubitSetInput) -> None:
733756
"""Adds a measure instruction to the the circuit
734757

src/braket/circuits/compiler_directives.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
import braket.ir.jaqcd as ir
1717

1818
from braket.circuits.compiler_directive import CompilerDirective
19+
from braket.circuits.serialization import IRType, SerializationProperties
20+
from braket.registers.qubit_set import QubitSet
1921

2022

2123
class StartVerbatimBox(CompilerDirective):
@@ -60,3 +62,36 @@ def _to_jaqcd(self, *args, **kwargs) -> Any:
6062

6163
def _to_openqasm(self) -> str:
6264
return "}"
65+
66+
67+
class Barrier(CompilerDirective):
68+
"""Barrier compiler directive."""
69+
70+
def __init__(self, qubit_indices: list[int]):
71+
super().__init__(["||"])
72+
self._qubit_indices = qubit_indices
73+
74+
@property
75+
def qubit_indices(self) -> list[int]:
76+
return self._qubit_indices
77+
78+
@property
79+
def qubit_count(self) -> int:
80+
return len(self._qubit_indices)
81+
82+
def _to_jaqcd(self) -> Any:
83+
raise NotImplementedError("Barrier is not supported in JAQCD")
84+
85+
def to_ir(
86+
self,
87+
target: QubitSet | None,
88+
ir_type: IRType,
89+
serialization_properties: SerializationProperties | None = None,
90+
**kwargs,
91+
) -> Any:
92+
if ir_type.name == "OPENQASM":
93+
if target:
94+
qubits = ", ".join(serialization_properties.format_target(int(q)) for q in target)
95+
return f"barrier {qubits};"
96+
return "barrier;"
97+
return super().to_ir(target, ir_type, serialization_properties, **kwargs)

src/braket/circuits/moments.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -174,10 +174,17 @@ def add(self, instructions: Iterable[Instruction] | Instruction, noise_index: in
174174
def _add(self, instruction: Instruction, noise_index: int = 0) -> None:
175175
operator = instruction.operator
176176
if isinstance(operator, CompilerDirective):
177-
time = self._update_qubit_times(self._qubits)
178-
self._moments[MomentsKey(time, None, MomentType.COMPILER_DIRECTIVE, 0)] = instruction
177+
qubit_range = instruction.target.union(instruction.control or QubitSet())
178+
time = self._handle_compiler_directive(operator, qubit_range)
179+
# For barriers without qubits, use empty qubit set for the key
180+
key_qubits = (
181+
QubitSet() if operator.name == "Barrier" and not qubit_range else qubit_range
182+
)
183+
self._moments[MomentsKey(time, key_qubits, MomentType.COMPILER_DIRECTIVE, 0)] = (
184+
instruction
185+
)
186+
self._qubits.update(qubit_range)
179187
self._depth = time + 1
180-
self._time_all_qubits = time
181188
elif isinstance(operator, Noise):
182189
self.add_noise(instruction)
183190
elif isinstance(operator, Gate) and operator.name == "GPhase":
@@ -282,6 +289,17 @@ def sort_moments(self) -> None:
282289

283290
self._moments = sorted_moment
284291

292+
def _handle_compiler_directive(self, operator: CompilerDirective, qubit_range: QubitSet) -> int:
293+
"""Handle compiler directive and return the time slot."""
294+
if operator.name == "Barrier" and not qubit_range:
295+
time = self._get_qubit_times(self._qubits) + 1
296+
self._time_all_qubits = time
297+
else:
298+
time = self._update_qubit_times(qubit_range or self._qubits)
299+
if operator.name != "Barrier":
300+
self._time_all_qubits = time
301+
return time
302+
285303
def _max_time_for_qubit(self, qubit: Qubit) -> int:
286304
# -1 if qubit is unoccupied because the first instruction will have an index of 0
287305
return self._max_times.get(qubit, -1)
@@ -307,7 +325,7 @@ def values(self) -> ValuesView[Instruction]:
307325
self.sort_moments()
308326
return self._moments.values()
309327

310-
def get(self, key: MomentsKey, default: Any | None = None) -> Instruction:
328+
def get(self, key: MomentsKey, default: Any | None = None) -> Instruction | Any | None:
311329
"""Get the instruction in self by key.
312330
313331
Args:

src/braket/circuits/text_diagram_builders/ascii_circuit_diagram.py

Lines changed: 138 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,125 @@ def _duplicate_time_at_bottom(cls, lines: str) -> None:
7070
# duplicate times after an empty line
7171
lines.append(lines[0])
7272

73+
@classmethod
74+
def _process_item_properties(
75+
cls, item: Instruction | ResultType, circuit_qubits: QubitSet
76+
) -> tuple[QubitSet, QubitSet, QubitSet, QubitSet, list[str], dict | None]:
77+
"""Extract properties from an item, keeping original logic structure."""
78+
if isinstance(item, ResultType) and not item.target:
79+
target_qubits = circuit_qubits
80+
control_qubits = QubitSet()
81+
target_and_control = target_qubits.union(control_qubits)
82+
qubits = circuit_qubits
83+
ascii_symbols = [item.ascii_symbols[0]] * len(circuit_qubits)
84+
map_control_qubit_states = None
85+
elif isinstance(item, Instruction) and isinstance(item.operator, CompilerDirective):
86+
if item.operator.name == "Barrier":
87+
target_qubits = item.target
88+
if not target_qubits:
89+
# Barrier without qubits - single barrier across all qubits
90+
target_qubits = circuit_qubits
91+
qubits = circuit_qubits
92+
ascii_symbols = [item.ascii_symbols[0]] * len(circuit_qubits)
93+
else:
94+
# Barrier with specific qubits
95+
qubits = target_qubits
96+
ascii_symbols = [item.ascii_symbols[0]] * len(target_qubits)
97+
target_and_control = target_qubits
98+
else:
99+
target_qubits = circuit_qubits
100+
control_qubits = QubitSet()
101+
target_and_control = target_qubits.union(control_qubits)
102+
qubits = circuit_qubits
103+
ascii_symbol = item.ascii_symbols[0]
104+
marker = "*" * len(ascii_symbol)
105+
num_after = len(circuit_qubits) - 1
106+
after = ["|"] * (num_after - 1) + ([marker] if num_after else [])
107+
ascii_symbols = [ascii_symbol, *after]
108+
control_qubits = QubitSet()
109+
map_control_qubit_states = None
110+
elif (
111+
isinstance(item, Instruction)
112+
and isinstance(item.operator, Gate)
113+
and item.operator.name == "GPhase"
114+
):
115+
target_qubits = circuit_qubits
116+
control_qubits = QubitSet()
117+
target_and_control = QubitSet()
118+
qubits = circuit_qubits
119+
ascii_symbols = cls._qubit_line_character() * len(circuit_qubits)
120+
map_control_qubit_states = None
121+
else:
122+
if isinstance(item.target, list):
123+
target_qubits = reduce(QubitSet.union, map(QubitSet, item.target), QubitSet())
124+
else:
125+
target_qubits = item.target
126+
control_qubits = getattr(item, "control", QubitSet())
127+
control_state = getattr(item, "control_state", "1" * len(control_qubits))
128+
map_control_qubit_states = dict(zip(control_qubits, control_state, strict=True))
129+
target_and_control = target_qubits.union(control_qubits)
130+
qubits = QubitSet(range(min(target_and_control), max(target_and_control) + 1))
131+
ascii_symbols = item.ascii_symbols
132+
133+
return (
134+
target_qubits,
135+
control_qubits,
136+
target_and_control,
137+
qubits,
138+
ascii_symbols,
139+
map_control_qubit_states,
140+
)
141+
142+
@classmethod
143+
def _update_qubit_symbols_and_connections(
144+
cls,
145+
item: Instruction | ResultType,
146+
qubit: int,
147+
target_qubits: QubitSet,
148+
control_qubits: QubitSet,
149+
target_and_control: QubitSet,
150+
ascii_symbols: list[str],
151+
symbols: dict,
152+
connections: dict,
153+
map_control_qubit_states: dict | None,
154+
) -> None:
155+
"""Update symbols and connections for a qubit, keeping original logic."""
156+
# Determine if the qubit is part of the item or in the middle of a
157+
# multi qubit item.
158+
if qubit in target_qubits:
159+
item_qubit_index = next(index for index, q in enumerate(target_qubits) if q == qubit)
160+
power_string = (
161+
f"^{power}"
162+
if (
163+
(power := getattr(item, "power", 1)) != 1
164+
# this has the limitation of not printing the power
165+
# when a user has a gate genuinely named C, but
166+
# is necessary to enable proper printing of custom
167+
# gates with built-in control qubits
168+
and ascii_symbols[item_qubit_index] != "C"
169+
)
170+
else ""
171+
)
172+
idx = item_qubit_index
173+
symbols[qubit] = (
174+
f"({ascii_symbols[idx]}{power_string})" if power_string else ascii_symbols[idx]
175+
)
176+
elif qubit in control_qubits:
177+
symbols[qubit] = "C" if map_control_qubit_states[qubit] else "N"
178+
else:
179+
symbols[qubit] = "|"
180+
181+
# Set the margin to be a connector if not on the first qubit
182+
if target_and_control and qubit != min(target_and_control):
183+
is_barrier = (
184+
isinstance(item, Instruction)
185+
and isinstance(item.operator, CompilerDirective)
186+
and item.operator.name == "Barrier"
187+
)
188+
# Add vertical lines for non-barriers or global barriers (no target)
189+
if not is_barrier or not item.target:
190+
connections[qubit] = "above"
191+
73192
@classmethod
74193
def _create_diagram_column(
75194
cls,
@@ -91,78 +210,27 @@ def _create_diagram_column(
91210
connections = dict.fromkeys(circuit_qubits, "none")
92211

93212
for item in items:
94-
if isinstance(item, ResultType) and not item.target:
95-
target_qubits = circuit_qubits
96-
control_qubits = QubitSet()
97-
target_and_control = target_qubits.union(control_qubits)
98-
qubits = circuit_qubits
99-
ascii_symbols = [item.ascii_symbols[0]] * len(circuit_qubits)
100-
elif isinstance(item, Instruction) and isinstance(item.operator, CompilerDirective):
101-
target_qubits = circuit_qubits
102-
control_qubits = QubitSet()
103-
target_and_control = target_qubits.union(control_qubits)
104-
qubits = circuit_qubits
105-
ascii_symbol = item.ascii_symbols[0]
106-
marker = "*" * len(ascii_symbol)
107-
num_after = len(circuit_qubits) - 1
108-
after = ["|"] * (num_after - 1) + ([marker] if num_after else [])
109-
ascii_symbols = [ascii_symbol, *after]
110-
elif (
111-
isinstance(item, Instruction)
112-
and isinstance(item.operator, Gate)
113-
and item.operator.name == "GPhase"
114-
):
115-
target_qubits = circuit_qubits
116-
control_qubits = QubitSet()
117-
target_and_control = QubitSet()
118-
qubits = circuit_qubits
119-
ascii_symbols = cls._qubit_line_character() * len(circuit_qubits)
120-
else:
121-
if isinstance(item.target, list):
122-
target_qubits = reduce(QubitSet.union, map(QubitSet, item.target), QubitSet())
123-
else:
124-
target_qubits = item.target
125-
control_qubits = getattr(item, "control", QubitSet())
126-
control_state = getattr(item, "control_state", "1" * len(control_qubits))
127-
map_control_qubit_states = dict(zip(control_qubits, control_state, strict=True))
128-
129-
target_and_control = target_qubits.union(control_qubits)
130-
qubits = QubitSet(range(min(target_and_control), max(target_and_control) + 1))
131-
132-
ascii_symbols = item.ascii_symbols
213+
(
214+
target_qubits,
215+
control_qubits,
216+
target_and_control,
217+
qubits,
218+
ascii_symbols,
219+
map_control_qubit_states,
220+
) = cls._process_item_properties(item, circuit_qubits)
133221

134222
for qubit in qubits:
135-
# Determine if the qubit is part of the item or in the middle of a
136-
# multi qubit item.
137-
if qubit in target_qubits:
138-
item_qubit_index = [ # noqa: RUF015
139-
index for index, q in enumerate(target_qubits) if q == qubit
140-
][0]
141-
power_string = (
142-
f"^{power}"
143-
if (
144-
(power := getattr(item, "power", 1)) != 1
145-
# this has the limitation of not printing the power
146-
# when a user has a gate genuinely named C, but
147-
# is necessary to enable proper printing of custom
148-
# gates with built-in control qubits
149-
and ascii_symbols[item_qubit_index] != "C"
150-
)
151-
else ""
152-
)
153-
symbols[qubit] = (
154-
f"({ascii_symbols[item_qubit_index]}{power_string})"
155-
if power_string
156-
else ascii_symbols[item_qubit_index]
157-
)
158-
elif qubit in control_qubits:
159-
symbols[qubit] = "C" if map_control_qubit_states[qubit] else "N"
160-
else:
161-
symbols[qubit] = "|"
162-
163-
# Set the margin to be a connector if not on the first qubit
164-
if target_and_control and qubit != min(target_and_control):
165-
connections[qubit] = "above"
223+
cls._update_qubit_symbols_and_connections(
224+
item,
225+
qubit,
226+
target_qubits,
227+
control_qubits,
228+
target_and_control,
229+
ascii_symbols,
230+
symbols,
231+
connections,
232+
map_control_qubit_states,
233+
)
166234

167235
return cls._create_output(symbols, connections, circuit_qubits, global_phase)
168236

0 commit comments

Comments
 (0)