Skip to content

Commit c13c1c6

Browse files
authored
fix: Expand controlled operators in to_unitary (#1144)
1 parent 2a395d2 commit c13c1c6

File tree

3 files changed

+112
-15
lines changed

3 files changed

+112
-15
lines changed

src/braket/circuits/unitary_calculation.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from collections.abc import Iterable
1515

1616
import numpy as np
17-
from braket.default_simulator.linalg_utils import multiply_matrix
17+
from braket.default_simulator.linalg_utils import controlled_matrix, multiply_matrix
1818
from scipy.linalg import fractional_matrix_power
1919

2020
from braket.circuits.compiler_directive import CompilerDirective
@@ -63,19 +63,19 @@ def calculate_unitary_big_endian(
6363
raise TypeError("Only Gate operators are supported to build the unitary")
6464

6565
base_gate_matrix = instruction.operator.to_matrix()
66-
if int(instruction.power) == instruction.power:
67-
gate_matrix = np.linalg.matrix_power(base_gate_matrix, int(instruction.power))
68-
else:
69-
gate_matrix = fractional_matrix_power(base_gate_matrix, instruction.power)
7066

71-
gate_matrix = np.asarray(gate_matrix, dtype=complex)
67+
gate_matrix = (
68+
np.linalg.matrix_power(base_gate_matrix, int(instruction.power))
69+
if int(instruction.power) == instruction.power
70+
else fractional_matrix_power(base_gate_matrix, instruction.power)
71+
)
72+
target = tuple(index_substitutions[qubit] for qubit in instruction.target)
73+
control = tuple(index_substitutions[qubit] for qubit in instruction.control)
7274

7375
unitary = multiply_matrix(
7476
unitary,
75-
gate_matrix,
76-
tuple(index_substitutions[qubit] for qubit in instruction.target),
77-
controls=instruction.control,
78-
control_state=instruction.control_state,
77+
controlled_matrix(np.asarray(gate_matrix, dtype=complex), instruction.control_state),
78+
control + target,
7979
)
8080

8181
return unitary.reshape(rank, rank)

test/integ_tests/test_reservation_arn.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ def test_create_task_via_reservation_arn_on_simulator(reservation_arn):
6363
def test_create_job_with_decorator_via_invalid_reservation_arn(reservation_arn):
6464
if AwsDevice(Devices.IQM.Garnet).status == "ONLINE":
6565
with pytest.raises(ClientError, match="Reservation arn is invalid"):
66+
6667
@hybrid_job(
6768
device=Devices.IQM.Garnet,
6869
reservation_arn=reservation_arn,

test/unit_tests/braket/circuits/test_circuit.py

Lines changed: 101 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2594,7 +2594,7 @@ def test_to_unitary_with_global_phase():
25942594
),
25952595
),
25962596
(
2597-
Circuit().x(0, control=1),
2597+
Circuit().x(3, control=7),
25982598
np.array(
25992599
[
26002600
[1.0, 0.0, 0.0, 0.0],
@@ -2630,7 +2630,7 @@ def test_to_unitary_with_global_phase():
26302630
),
26312631
),
26322632
(
2633-
Circuit().ccnot(1, 2, 0),
2633+
Circuit().ccnot(3, 6, 1),
26342634
np.array(
26352635
[
26362636
[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
@@ -2646,7 +2646,7 @@ def test_to_unitary_with_global_phase():
26462646
),
26472647
),
26482648
(
2649-
Circuit().ccnot(2, 1, 0),
2649+
Circuit().ccnot(6, 3, 1),
26502650
np.array(
26512651
[
26522652
[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
@@ -2662,7 +2662,7 @@ def test_to_unitary_with_global_phase():
26622662
),
26632663
),
26642664
(
2665-
Circuit().ccnot(0, 2, 1),
2665+
Circuit().ccnot(1, 6, 3),
26662666
np.array(
26672667
[
26682668
[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
@@ -2678,7 +2678,7 @@ def test_to_unitary_with_global_phase():
26782678
),
26792679
),
26802680
(
2681-
Circuit().ccnot(2, 0, 1),
2681+
Circuit().ccnot(6, 1, 3),
26822682
np.array(
26832683
[
26842684
[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
@@ -2693,6 +2693,102 @@ def test_to_unitary_with_global_phase():
26932693
dtype=complex,
26942694
),
26952695
),
2696+
(
2697+
Circuit().cnot([1, 6], 3),
2698+
np.array(
2699+
[
2700+
[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
2701+
[0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
2702+
[0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0],
2703+
[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0],
2704+
[0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0],
2705+
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0],
2706+
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0],
2707+
[0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0],
2708+
],
2709+
dtype=complex,
2710+
),
2711+
),
2712+
(
2713+
Circuit().cnot([6, 1], 3),
2714+
np.array(
2715+
[
2716+
[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
2717+
[0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
2718+
[0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0],
2719+
[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0],
2720+
[0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0],
2721+
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0],
2722+
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0],
2723+
[0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0],
2724+
],
2725+
dtype=complex,
2726+
),
2727+
),
2728+
(
2729+
Circuit().x(3, control=[6, 1]),
2730+
np.array(
2731+
[
2732+
[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
2733+
[0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
2734+
[0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0],
2735+
[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0],
2736+
[0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0],
2737+
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0],
2738+
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0],
2739+
[0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0],
2740+
],
2741+
dtype=complex,
2742+
),
2743+
),
2744+
(
2745+
Circuit().x(3, control=[1, 6]),
2746+
np.array(
2747+
[
2748+
[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
2749+
[0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
2750+
[0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0],
2751+
[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0],
2752+
[0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0],
2753+
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0],
2754+
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0],
2755+
[0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0],
2756+
],
2757+
dtype=complex,
2758+
),
2759+
),
2760+
(
2761+
Circuit().i(3).cnot(6, 1),
2762+
np.array(
2763+
[
2764+
[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
2765+
[0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0],
2766+
[0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0],
2767+
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0],
2768+
[0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0],
2769+
[0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
2770+
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0],
2771+
[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0],
2772+
],
2773+
dtype=complex,
2774+
),
2775+
),
2776+
(
2777+
Circuit().i(3).x(1, control=[6]),
2778+
np.array(
2779+
[
2780+
[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
2781+
[0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0],
2782+
[0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0],
2783+
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0],
2784+
[0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0],
2785+
[0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
2786+
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0],
2787+
[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0],
2788+
],
2789+
dtype=complex,
2790+
),
2791+
),
26962792
(
26972793
Circuit().s(0).v(1).cnot(0, 1).cnot(2, 1),
26982794
np.dot(

0 commit comments

Comments
 (0)