diff --git a/src/braket/circuits/circuit.py b/src/braket/circuits/circuit.py index 9268abad7..7d024342e 100644 --- a/src/braket/circuits/circuit.py +++ b/src/braket/circuits/circuit.py @@ -729,6 +729,29 @@ def add_verbatim_box( self._has_compiler_directives = True return self + def barrier(self, target: QubitSetInput | None = None) -> Circuit: + """Add a barrier compiler directive to the circuit. + + Args: + target (QubitSetInput | None): Target qubits for the barrier. + If None, applies to all qubits in the circuit. + + Returns: + Circuit: self + + Examples: + >>> circ = Circuit().h(0).barrier([0, 1]).cnot(0, 1) + >>> circ = Circuit().h(0).h(1).barrier() # barrier on all qubits + """ + target_qubits = self.qubits if target is None else QubitSet(target) + + if target_qubits: + self.add_instruction( + Instruction(compiler_directives.Barrier(list(target_qubits)), target=target_qubits) + ) + self._has_compiler_directives = True + return self + def _add_measure(self, target_qubits: QubitSetInput) -> None: """Adds a measure instruction to the the circuit diff --git a/src/braket/circuits/compiler_directives.py b/src/braket/circuits/compiler_directives.py index 0cd325776..329dcc3f6 100644 --- a/src/braket/circuits/compiler_directives.py +++ b/src/braket/circuits/compiler_directives.py @@ -16,6 +16,8 @@ import braket.ir.jaqcd as ir from braket.circuits.compiler_directive import CompilerDirective +from braket.circuits.serialization import IRType, SerializationProperties +from braket.registers.qubit_set import QubitSet class StartVerbatimBox(CompilerDirective): @@ -60,3 +62,36 @@ def _to_jaqcd(self, *args, **kwargs) -> Any: def _to_openqasm(self) -> str: return "}" + + +class Barrier(CompilerDirective): + """Barrier compiler directive.""" + + def __init__(self, qubit_indices: list[int]): + super().__init__(["||"]) + self._qubit_indices = qubit_indices + + @property + def qubit_indices(self) -> list[int]: + return self._qubit_indices + + @property + def qubit_count(self) -> int: + return len(self._qubit_indices) + + def _to_jaqcd(self) -> Any: + raise NotImplementedError("Barrier is not supported in JAQCD") + + def to_ir( + self, + target: QubitSet | None, + ir_type: IRType, + serialization_properties: SerializationProperties | None = None, + **kwargs, + ) -> Any: + if ir_type.name == "OPENQASM": + if target: + qubits = ", ".join(serialization_properties.format_target(int(q)) for q in target) + return f"barrier {qubits};" + return "barrier;" + return super().to_ir(target, ir_type, serialization_properties, **kwargs) diff --git a/src/braket/circuits/moments.py b/src/braket/circuits/moments.py index 10f236366..9121125b7 100644 --- a/src/braket/circuits/moments.py +++ b/src/braket/circuits/moments.py @@ -174,10 +174,17 @@ def add(self, instructions: Iterable[Instruction] | Instruction, noise_index: in def _add(self, instruction: Instruction, noise_index: int = 0) -> None: operator = instruction.operator if isinstance(operator, CompilerDirective): - time = self._update_qubit_times(self._qubits) - self._moments[MomentsKey(time, None, MomentType.COMPILER_DIRECTIVE, 0)] = instruction + qubit_range = instruction.target.union(instruction.control or QubitSet()) + time = self._handle_compiler_directive(operator, qubit_range) + # For barriers without qubits, use empty qubit set for the key + key_qubits = ( + QubitSet() if operator.name == "Barrier" and not qubit_range else qubit_range + ) + self._moments[MomentsKey(time, key_qubits, MomentType.COMPILER_DIRECTIVE, 0)] = ( + instruction + ) + self._qubits.update(qubit_range) self._depth = time + 1 - self._time_all_qubits = time elif isinstance(operator, Noise): self.add_noise(instruction) elif isinstance(operator, Gate) and operator.name == "GPhase": @@ -282,6 +289,17 @@ def sort_moments(self) -> None: self._moments = sorted_moment + def _handle_compiler_directive(self, operator: CompilerDirective, qubit_range: QubitSet) -> int: + """Handle compiler directive and return the time slot.""" + if operator.name == "Barrier" and not qubit_range: + time = self._get_qubit_times(self._qubits) + 1 + self._time_all_qubits = time + else: + time = self._update_qubit_times(qubit_range or self._qubits) + if operator.name != "Barrier": + self._time_all_qubits = time + return time + def _max_time_for_qubit(self, qubit: Qubit) -> int: # -1 if qubit is unoccupied because the first instruction will have an index of 0 return self._max_times.get(qubit, -1) @@ -307,7 +325,7 @@ def values(self) -> ValuesView[Instruction]: self.sort_moments() return self._moments.values() - def get(self, key: MomentsKey, default: Any | None = None) -> Instruction: + def get(self, key: MomentsKey, default: Any | None = None) -> Instruction | Any | None: """Get the instruction in self by key. Args: diff --git a/src/braket/circuits/text_diagram_builders/ascii_circuit_diagram.py b/src/braket/circuits/text_diagram_builders/ascii_circuit_diagram.py index 04bd7f059..91b4c1abf 100644 --- a/src/braket/circuits/text_diagram_builders/ascii_circuit_diagram.py +++ b/src/braket/circuits/text_diagram_builders/ascii_circuit_diagram.py @@ -70,6 +70,125 @@ def _duplicate_time_at_bottom(cls, lines: str) -> None: # duplicate times after an empty line lines.append(lines[0]) + @classmethod + def _process_item_properties( + cls, item: Instruction | ResultType, circuit_qubits: QubitSet + ) -> tuple[QubitSet, QubitSet, QubitSet, QubitSet, list[str], dict | None]: + """Extract properties from an item, keeping original logic structure.""" + if isinstance(item, ResultType) and not item.target: + target_qubits = circuit_qubits + control_qubits = QubitSet() + target_and_control = target_qubits.union(control_qubits) + qubits = circuit_qubits + ascii_symbols = [item.ascii_symbols[0]] * len(circuit_qubits) + map_control_qubit_states = None + elif isinstance(item, Instruction) and isinstance(item.operator, CompilerDirective): + if item.operator.name == "Barrier": + target_qubits = item.target + if not target_qubits: + # Barrier without qubits - single barrier across all qubits + target_qubits = circuit_qubits + qubits = circuit_qubits + ascii_symbols = [item.ascii_symbols[0]] * len(circuit_qubits) + else: + # Barrier with specific qubits + qubits = target_qubits + ascii_symbols = [item.ascii_symbols[0]] * len(target_qubits) + target_and_control = target_qubits + else: + target_qubits = circuit_qubits + control_qubits = QubitSet() + target_and_control = target_qubits.union(control_qubits) + qubits = circuit_qubits + ascii_symbol = item.ascii_symbols[0] + marker = "*" * len(ascii_symbol) + num_after = len(circuit_qubits) - 1 + after = ["|"] * (num_after - 1) + ([marker] if num_after else []) + ascii_symbols = [ascii_symbol, *after] + control_qubits = QubitSet() + map_control_qubit_states = None + elif ( + isinstance(item, Instruction) + and isinstance(item.operator, Gate) + and item.operator.name == "GPhase" + ): + target_qubits = circuit_qubits + control_qubits = QubitSet() + target_and_control = QubitSet() + qubits = circuit_qubits + ascii_symbols = cls._qubit_line_character() * len(circuit_qubits) + map_control_qubit_states = None + else: + if isinstance(item.target, list): + target_qubits = reduce(QubitSet.union, map(QubitSet, item.target), QubitSet()) + else: + target_qubits = item.target + control_qubits = getattr(item, "control", QubitSet()) + control_state = getattr(item, "control_state", "1" * len(control_qubits)) + map_control_qubit_states = dict(zip(control_qubits, control_state, strict=True)) + target_and_control = target_qubits.union(control_qubits) + qubits = QubitSet(range(min(target_and_control), max(target_and_control) + 1)) + ascii_symbols = item.ascii_symbols + + return ( + target_qubits, + control_qubits, + target_and_control, + qubits, + ascii_symbols, + map_control_qubit_states, + ) + + @classmethod + def _update_qubit_symbols_and_connections( + cls, + item: Instruction | ResultType, + qubit: int, + target_qubits: QubitSet, + control_qubits: QubitSet, + target_and_control: QubitSet, + ascii_symbols: list[str], + symbols: dict, + connections: dict, + map_control_qubit_states: dict | None, + ) -> None: + """Update symbols and connections for a qubit, keeping original logic.""" + # Determine if the qubit is part of the item or in the middle of a + # multi qubit item. + if qubit in target_qubits: + item_qubit_index = next(index for index, q in enumerate(target_qubits) if q == qubit) + power_string = ( + f"^{power}" + if ( + (power := getattr(item, "power", 1)) != 1 + # this has the limitation of not printing the power + # when a user has a gate genuinely named C, but + # is necessary to enable proper printing of custom + # gates with built-in control qubits + and ascii_symbols[item_qubit_index] != "C" + ) + else "" + ) + idx = item_qubit_index + symbols[qubit] = ( + f"({ascii_symbols[idx]}{power_string})" if power_string else ascii_symbols[idx] + ) + elif qubit in control_qubits: + symbols[qubit] = "C" if map_control_qubit_states[qubit] else "N" + else: + symbols[qubit] = "|" + + # Set the margin to be a connector if not on the first qubit + if target_and_control and qubit != min(target_and_control): + is_barrier = ( + isinstance(item, Instruction) + and isinstance(item.operator, CompilerDirective) + and item.operator.name == "Barrier" + ) + # Add vertical lines for non-barriers or global barriers (no target) + if not is_barrier or not item.target: + connections[qubit] = "above" + @classmethod def _create_diagram_column( cls, @@ -91,78 +210,27 @@ def _create_diagram_column( connections = dict.fromkeys(circuit_qubits, "none") for item in items: - if isinstance(item, ResultType) and not item.target: - target_qubits = circuit_qubits - control_qubits = QubitSet() - target_and_control = target_qubits.union(control_qubits) - qubits = circuit_qubits - ascii_symbols = [item.ascii_symbols[0]] * len(circuit_qubits) - elif isinstance(item, Instruction) and isinstance(item.operator, CompilerDirective): - target_qubits = circuit_qubits - control_qubits = QubitSet() - target_and_control = target_qubits.union(control_qubits) - qubits = circuit_qubits - ascii_symbol = item.ascii_symbols[0] - marker = "*" * len(ascii_symbol) - num_after = len(circuit_qubits) - 1 - after = ["|"] * (num_after - 1) + ([marker] if num_after else []) - ascii_symbols = [ascii_symbol, *after] - elif ( - isinstance(item, Instruction) - and isinstance(item.operator, Gate) - and item.operator.name == "GPhase" - ): - target_qubits = circuit_qubits - control_qubits = QubitSet() - target_and_control = QubitSet() - qubits = circuit_qubits - ascii_symbols = cls._qubit_line_character() * len(circuit_qubits) - else: - if isinstance(item.target, list): - target_qubits = reduce(QubitSet.union, map(QubitSet, item.target), QubitSet()) - else: - target_qubits = item.target - control_qubits = getattr(item, "control", QubitSet()) - control_state = getattr(item, "control_state", "1" * len(control_qubits)) - map_control_qubit_states = dict(zip(control_qubits, control_state, strict=True)) - - target_and_control = target_qubits.union(control_qubits) - qubits = QubitSet(range(min(target_and_control), max(target_and_control) + 1)) - - ascii_symbols = item.ascii_symbols + ( + target_qubits, + control_qubits, + target_and_control, + qubits, + ascii_symbols, + map_control_qubit_states, + ) = cls._process_item_properties(item, circuit_qubits) for qubit in qubits: - # Determine if the qubit is part of the item or in the middle of a - # multi qubit item. - if qubit in target_qubits: - item_qubit_index = [ # noqa: RUF015 - index for index, q in enumerate(target_qubits) if q == qubit - ][0] - power_string = ( - f"^{power}" - if ( - (power := getattr(item, "power", 1)) != 1 - # this has the limitation of not printing the power - # when a user has a gate genuinely named C, but - # is necessary to enable proper printing of custom - # gates with built-in control qubits - and ascii_symbols[item_qubit_index] != "C" - ) - else "" - ) - symbols[qubit] = ( - f"({ascii_symbols[item_qubit_index]}{power_string})" - if power_string - else ascii_symbols[item_qubit_index] - ) - elif qubit in control_qubits: - symbols[qubit] = "C" if map_control_qubit_states[qubit] else "N" - else: - symbols[qubit] = "|" - - # Set the margin to be a connector if not on the first qubit - if target_and_control and qubit != min(target_and_control): - connections[qubit] = "above" + cls._update_qubit_symbols_and_connections( + item, + qubit, + target_qubits, + control_qubits, + target_and_control, + ascii_symbols, + symbols, + connections, + map_control_qubit_states, + ) return cls._create_output(symbols, connections, circuit_qubits, global_phase) diff --git a/src/braket/circuits/text_diagram_builders/text_circuit_diagram_utils.py b/src/braket/circuits/text_diagram_builders/text_circuit_diagram_utils.py index 1e3dcae6b..2d1872cb9 100644 --- a/src/braket/circuits/text_diagram_builders/text_circuit_diagram_utils.py +++ b/src/braket/circuits/text_diagram_builders/text_circuit_diagram_utils.py @@ -112,6 +112,44 @@ def _compute_moment_global_phase( return global_phase + moment_phase if global_phase is not None else None +def _get_qubit_range_for_item(item: Instruction | ResultType, circuit_qubits: QubitSet) -> QubitSet: + """Get the qubit range for a given item.""" + if ( + isinstance(item, Instruction) + and isinstance(item.operator, Gate) + and item.operator.name == "GPhase" + ): + return QubitSet() + + if isinstance(item, ResultType) and not item.target: + return circuit_qubits + + if isinstance(item, Instruction) and isinstance(item.operator, CompilerDirective): + return _get_compiler_directive_qubit_range(item, circuit_qubits) + + return _get_standard_qubit_range(item) + + +def _get_compiler_directive_qubit_range(item: Instruction, circuit_qubits: QubitSet) -> QubitSet: + """Get qubit range for compiler directive instructions.""" + if item.operator.name == "Barrier": + if not item.target or len(item.target) == 0: + return circuit_qubits + return item.target + return circuit_qubits + + +def _get_standard_qubit_range(item: Instruction | ResultType) -> QubitSet: + """Get qubit range for standard instructions and result types.""" + if isinstance(item.target, list): + target = reduce(QubitSet.union, map(QubitSet, item.target), QubitSet()) + else: + target = item.target + control = getattr(item, "control", QubitSet()) + target_and_control = target.union(control) + return QubitSet(range(min(target_and_control), max(target_and_control) + 1)) + + def _group_items( circuit_qubits: QubitSet, items: list[Instruction | ResultType], @@ -128,33 +166,13 @@ def _group_items( """ groupings = [] for item in items: - # Can only print QuantumOperator and CompilerDirective operators for instructions at - # the moment + # Can only print QuantumOperator and CompilerDirective operators for instructions if isinstance(item, Instruction) and not isinstance( item.operator, CompilerDirective | QuantumOperator ): continue - # As a zero-qubit gate, GPhase can be grouped with anything. We set qubit_range - # to an empty list and we just add it to the first group below. - if ( - isinstance(item, Instruction) - and isinstance(item.operator, Gate) - and item.operator.name == "GPhase" - ): - qubit_range = QubitSet() - elif (isinstance(item, ResultType) and not item.target) or ( - isinstance(item, Instruction) and isinstance(item.operator, CompilerDirective) - ): - qubit_range = circuit_qubits - else: - if isinstance(item.target, list): - target = reduce(QubitSet.union, map(QubitSet, item.target), QubitSet()) - else: - target = item.target - control = getattr(item, "control", QubitSet()) - target_and_control = target.union(control) - qubit_range = QubitSet(range(min(target_and_control), max(target_and_control) + 1)) + qubit_range = _get_qubit_range_for_item(item, circuit_qubits) found_grouping = False for group in groupings: diff --git a/src/braket/circuits/text_diagram_builders/unicode_circuit_diagram.py b/src/braket/circuits/text_diagram_builders/unicode_circuit_diagram.py index 4b49bf90a..e7d00f675 100644 --- a/src/braket/circuits/text_diagram_builders/unicode_circuit_diagram.py +++ b/src/braket/circuits/text_diagram_builders/unicode_circuit_diagram.py @@ -140,14 +140,34 @@ def _build_parameters( ) -> tuple: map_control_qubit_states = {} - if (isinstance(item, ResultType) and not item.target) or ( - isinstance(item, Instruction) and isinstance(item.operator, CompilerDirective) - ): + if isinstance(item, ResultType) and not item.target: target_qubits = circuit_qubits control_qubits = QubitSet() qubits = circuit_qubits ascii_symbols = [item.ascii_symbols[0]] * len(qubits) cls._update_connections(qubits, connections) + elif isinstance(item, Instruction) and isinstance(item.operator, CompilerDirective): + if item.operator.name == "Barrier": + if not item.target: + # Barrier without qubits - single barrier across all qubits WITH connections + target_qubits = circuit_qubits + qubits = circuit_qubits + ascii_symbols = [item.ascii_symbols[0]] * len(circuit_qubits) + cls._update_connections(circuit_qubits, connections) + else: + # Barrier with specific qubits - only add connections for global barriers + target_qubits = item.target + qubits = target_qubits + ascii_symbols = [item.ascii_symbols[0]] * len(target_qubits) + # Specific barriers get no vertical lines + # (Global barriers are handled above with no target) + control_qubits = QubitSet() + else: + target_qubits = circuit_qubits + control_qubits = QubitSet() + qubits = circuit_qubits + ascii_symbols = [item.ascii_symbols[0]] * len(qubits) + cls._update_connections(qubits, connections) elif ( isinstance(item, Instruction) and isinstance(item.operator, Gate) @@ -209,12 +229,12 @@ def _draw_symbol( """ top = "" bottom = "" - if symbol in {"C", "N", "SWAP"}: + if symbol in {"C", "N", "SWAP", "||"}: if connection in {"above", "both"}: top = _fill_symbol(cls._vertical_delimiter(), " ") if connection in {"below", "both"}: bottom = _fill_symbol(cls._vertical_delimiter(), " ") - new_symbol = {"C": "●", "N": "◯", "SWAP": "x"} + new_symbol = {"C": "●", "N": "◯", "SWAP": "x", "||": "▒"} # replace SWAP by x # the size of the moment remains as if there was a box with 4 characters inside symbol = _fill_symbol(new_symbol[symbol], cls._qubit_line_character()) diff --git a/test/unit_tests/braket/circuits/test_ascii_circuit_diagram.py b/test/unit_tests/braket/circuits/test_ascii_circuit_diagram.py index 901bbbb9b..e86f42425 100644 --- a/test/unit_tests/braket/circuits/test_ascii_circuit_diagram.py +++ b/test/unit_tests/braket/circuits/test_ascii_circuit_diagram.py @@ -14,6 +14,7 @@ import numpy as np import pytest +from braket.circuits.compiler_directives import Barrier from braket.circuits import ( AsciiCircuitDiagram, Circuit, @@ -312,6 +313,42 @@ def test_overlapping_qubits(): _assert_correct_diagram(circ, expected) +def test_barrier_single_qubit(): + circ = Circuit().x(0).x(1).barrier(target=[0]).h(2) + expected = ( + "T : |0|1 |", + " ", + "q0 : -X-||-", + " ", + "q1 : -X----", + " ", + "q2 : -H----", + "", + "T : |0|1 |", + ) + _assert_correct_diagram(circ, expected) + + +def test_barrier_global_with_vertical_lines(): + from braket.circuits.compiler_directives import Barrier + + circ = Circuit().x(0).x(1) + circ.add_instruction(Instruction(Barrier([]), [])) + circ.h(2) + expected = ( + "T : |0|1 |2|", + " ", + "q0 : -X-||---", + " | ", + "q1 : -X-||---", + " | ", + "q2 : ---||-H-", + "", + "T : |0|1 |2|", + ) + _assert_correct_diagram(circ, expected) + + def test_overlapping_qubits_angled_gates(): circ = Circuit().zz(0, 2, 0.15).x(control=1, target=3).h(0) expected = ( @@ -956,3 +993,67 @@ def test_measure_with_readout_noise(): "T : |0| 1 |2|", ) _assert_correct_diagram(circ, expected) + + +def test_barrier_circuit_visualization_without_other_gates(): + circ = Circuit().barrier(target=[0, 100]) + expected = ( + "T : |0 |", + " ", + "q0 : -||-", + " ", + "q100 : -||-", + "", + "T : |0 |", + ) + _assert_correct_diagram(circ, expected) + + +def test_barrier_circuit_visualization_with_other_gates(): + circ = Circuit().x(0).barrier(target=[0, 100]).h(3) + expected = ( + "T : |0|1 |", + " ", + "q0 : -X-||-", + " ", + "q3 : -H----", + " ", + "q100 : ---||-", + "", + "T : |0|1 |", + ) + _assert_correct_diagram(circ, expected) + + +def test_barrier_multiple_qubits_with_gates(): + circ = Circuit().x(0).x(1).barrier(target=[0, 1]).h(0).h(2) + expected = ( + "T : |0|1 |2|", + " ", + "q0 : -X-||-H-", + " ", + "q1 : -X-||---", + " ", + "q2 : -H------", + "", + "T : |0|1 |2|", + ) + _assert_correct_diagram(circ, expected) + + +def test_barrier_global_with_vertical_lines(): + circ = Circuit().x(0).x(1) + circ.add_instruction(Instruction(Barrier([]), [])) + circ.h(2) + expected = ( + "T : |0|1 |2|", + " ", + "q0 : -X-||---", + " | ", + "q1 : -X-||---", + " | ", + "q2 : ---||-H-", + "", + "T : |0|1 |2|", + ) + _assert_correct_diagram(circ, expected) diff --git a/test/unit_tests/braket/circuits/test_circuit.py b/test/unit_tests/braket/circuits/test_circuit.py index a76794bec..fd18f2a6e 100644 --- a/test/unit_tests/braket/circuits/test_circuit.py +++ b/test/unit_tests/braket/circuits/test_circuit.py @@ -3589,3 +3589,50 @@ def test_from_ir_round_trip_transformation_with_targeted_measurements(): assert Circuit.from_ir(ir) == Circuit.from_ir(circuit.to_ir("OPENQASM")) assert circuit.to_ir("OPENQASM") == Circuit.from_ir(ir).to_ir("OPENQASM") + + +def test_barrier_specific_qubits(): + circ = Circuit().barrier([0, 1, 2]) + assert len(circ.instructions) == 1 + instr = circ.instructions[0] + assert isinstance(instr.operator, compiler_directives.Barrier) + assert instr.target == QubitSet([0, 1, 2]) + assert instr.operator.qubit_indices == [0, 1, 2] + assert circ.qubits_frozen is True + + +def test_barrier_all_qubits(): + circ = Circuit().h(0).h(1).barrier() + assert len(circ.instructions) == 3 + barrier_instr = circ.instructions[2] + assert isinstance(barrier_instr.operator, compiler_directives.Barrier) + assert barrier_instr.target == QubitSet([0, 1]) + + +def test_barrier_empty_circuit(): + circ = Circuit().barrier() + assert len(circ.instructions) == 0 # No barrier added to empty circuit + + +def test_barrier_none_target(): + circ = Circuit().h(0).h(2).barrier(None) + barrier_instr = circ.instructions[2] + assert barrier_instr.target == QubitSet([0, 2]) + + +def test_barrier_openqasm_export_specific_qubits(): + circ = Circuit().h(0).barrier([0, 1]).cnot(0, 1) + qasm = circ.to_ir(IRType.OPENQASM).source + assert "barrier q[0], q[1];" in qasm + + +def test_barrier_openqasm_export_all_qubits(): + circ = Circuit().h(0).h(1).barrier().cnot(0, 1) + qasm = circ.to_ir(IRType.OPENQASM).source + assert "barrier q[0], q[1];" in qasm + + +def test_barrier_jaqcd_export_fails(): + circ = Circuit().h(0).barrier([0, 1]) + with pytest.raises(NotImplementedError, match="Barrier is not supported in JAQCD"): + circ.to_ir(IRType.JAQCD) diff --git a/test/unit_tests/braket/circuits/test_compiler_directives.py b/test/unit_tests/braket/circuits/test_compiler_directives.py index a05e4c20a..ae1e4905f 100644 --- a/test/unit_tests/braket/circuits/test_compiler_directives.py +++ b/test/unit_tests/braket/circuits/test_compiler_directives.py @@ -17,6 +17,7 @@ from braket.circuits import compiler_directives from braket.circuits.compiler_directive import CompilerDirective from braket.circuits.serialization import IRType +from braket.circuits.serialization import OpenQASMSerializationProperties, QubitReferenceType @pytest.mark.parametrize( @@ -49,3 +50,21 @@ def test_verbatim(testclass, irclass, openqasm_str, counterpart): assert directive is not op assert directive != CompilerDirective(ascii_symbols=["foo"]) assert directive != "not a directive" + + +def test_barrier(): + barrier = compiler_directives.Barrier([0, 1, 2]) + assert barrier.qubit_indices == [0, 1, 2] + assert barrier.qubit_count == 3 + assert barrier.ascii_symbols == ("||",) + assert repr(barrier) == "Barrier" + + with pytest.raises(NotImplementedError, match="Barrier is not supported in JAQCD"): + barrier._to_jaqcd() + + props = OpenQASMSerializationProperties(qubit_reference_type=QubitReferenceType.VIRTUAL) + result = barrier.to_ir([0, 1, 2], IRType.OPENQASM, props) + assert result == "barrier q[0], q[1], q[2];" + + result = barrier.to_ir([], IRType.OPENQASM, props) + assert result == "barrier;" diff --git a/test/unit_tests/braket/circuits/test_moments.py b/test/unit_tests/braket/circuits/test_moments.py index ed45b0aa3..f2930e6ce 100644 --- a/test/unit_tests/braket/circuits/test_moments.py +++ b/test/unit_tests/braket/circuits/test_moments.py @@ -16,6 +16,7 @@ import pytest from braket.circuits import Gate, Instruction, Moments, MomentsKey, QubitSet +from braket.circuits.compiler_directives import Barrier def cnot(q1, q2): @@ -185,3 +186,17 @@ def test_repr(moments): def test_str(moments): assert str(moments) == str(OrderedDict(moments)) + + +def test_barrier_with_qubits(): + """Test barrier with specific qubits.""" + moments = Moments([h(0), h(1)]) + moments.add(Instruction(Barrier([0]), [0])) + assert moments.depth == 2 + + +def test_barrier_without_qubits(): + """Test barrier without qubits (global).""" + moments = Moments([h(0), h(1)]) + moments.add(Instruction(Barrier([]), [])) + assert moments.depth == 2 diff --git a/test/unit_tests/braket/circuits/test_unicode_circuit_diagram.py b/test/unit_tests/braket/circuits/test_unicode_circuit_diagram.py index a0288ddda..0cc31ecc9 100644 --- a/test/unit_tests/braket/circuits/test_unicode_circuit_diagram.py +++ b/test/unit_tests/braket/circuits/test_unicode_circuit_diagram.py @@ -26,6 +26,8 @@ ) from braket.pulse import Frame, Port, PulseSequence +from braket.circuits.compiler_directives import Barrier + def _assert_correct_diagram(circ, expected): assert UnicodeCircuitDiagram.build_diagram(circ) == "\n".join(expected) @@ -1119,3 +1121,92 @@ def test_measure_with_readout_noise(): "T : │ 0 │ 1 │ 2 │", ) _assert_correct_diagram(circ, expected) + + +def test_barrier_circuit_visualization_without_other_gates(): + circ = Circuit().barrier(target=[0, 100]) + expected = ( + "T : │ 0 │", + " ", + "q0 : ───▒────", + " ", + " ", + "q100 : ───▒────", + " ", + "T : │ 0 │", + ) + _assert_correct_diagram(circ, expected) + + +def test_barrier_circuit_visualization_with_other_gates(): + circ = Circuit().x(0).barrier(target=[0, 100]).h(3) + expected = ( + "T : │ 0 │ 1 │", + " ┌───┐ ", + "q0 : ─┤ X ├───▒────", + " └───┘ ", + " ┌───┐ ", + "q3 : ─┤ H ├────────", + " └───┘ ", + " ", + "q100 : ─────────▒────", + " ", + "T : │ 0 │ 1 │", + ) + _assert_correct_diagram(circ, expected) + + +def test_barrier_single_qubit(): + circ = Circuit().x(0).x(1).barrier(target=[0]).h(2) + expected = ( + "T : │ 0 │ 1 │", + " ┌───┐ ", + "q0 : ─┤ X ├───▒────", + " └───┘ ", + " ┌───┐ ", + "q1 : ─┤ X ├────────", + " └───┘ ", + " ┌───┐ ", + "q2 : ─┤ H ├────────", + " └───┘ ", + "T : │ 0 │ 1 │", + ) + _assert_correct_diagram(circ, expected) + + +def test_barrier_multiple_qubits_with_gates(): + circ = Circuit().x(0).x(1).barrier(target=[0, 1]).h(0).h(2) + expected = ( + "T : │ 0 │ 1 │ 2 │", + " ┌───┐ ┌───┐ ", + "q0 : ─┤ X ├───▒────┤ H ├─", + " └───┘ └───┘ ", + " ┌───┐ ", + "q1 : ─┤ X ├───▒──────────", + " └───┘ ", + " ┌───┐ ", + "q2 : ─┤ H ├──────────────", + " └───┘ ", + "T : │ 0 │ 1 │ 2 │", + ) + _assert_correct_diagram(circ, expected) + + +def test_barrier_global_with_vertical_lines(): + circ = Circuit().x(0).x(1) + circ.add_instruction(Instruction(Barrier([]), [])) + circ.h(2) + expected = ( + "T : │ 0 │ 1 │ 2 │", + " ┌───┐ ", + "q0 : ─┤ X ├───▒──────────", + " └───┘ │ ", + " ┌───┐ │ ", + "q1 : ─┤ X ├───▒──────────", + " └───┘ │ ", + " │ ┌───┐ ", + "q2 : ─────────▒────┤ H ├─", + " └───┘ ", + "T : │ 0 │ 1 │ 2 │", + ) + _assert_correct_diagram(circ, expected)