Skip to content

Commit c6e7c9d

Browse files
authored
fix: Use PL GeneralizedAmplitudeDamping convention (#322)
1 parent afe19ed commit c6e7c9d

File tree

5 files changed

+46
-34
lines changed

5 files changed

+46
-34
lines changed

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
install_requires=[
3030
"amazon-braket-sdk>=1.97.0",
3131
"autoray>=0.6.11",
32-
"pennylane>=0.42.0",
32+
"pennylane>=0.44.0",
3333
],
3434
entry_points={
3535
"pennylane.plugins": [

src/braket/pennylane_plugin/translation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -319,7 +319,7 @@ def _(_: qml.AmplitudeDamping, parameters, device=None):
319319
@_translate_operation.register
320320
def _(_: qml.GeneralizedAmplitudeDamping, parameters, device=None):
321321
gamma = parameters[0]
322-
probability = parameters[1]
322+
probability = 1 - parameters[1]
323323
return noises.GeneralizedAmplitudeDamping(probability=probability, gamma=gamma)
324324

325325

test/integ_tests/test_apply.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,8 +154,8 @@ def test_single_qubit_noise(self, init_state, dm_device, op, prob, tol):
154154
state = init_state(1)
155155
TestHardwareApply.assert_noise_op(op, dev, state, [0], tol, [prob])
156156

157-
@pytest.mark.parametrize("gamma", [0.0, 0.42])
158157
@pytest.mark.parametrize("prob", [0.0, 0.42])
158+
@pytest.mark.parametrize("gamma", [0.0, 0.42])
159159
def test_generalized_amplitude_damping(self, init_state, dm_device, gamma, prob, tol):
160160
"""Test parametrized GeneralizedAmplitudeDamping"""
161161
dev = dm_device(1)

test/unit_tests/test_braket_device.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -239,17 +239,17 @@ def test_apply_unique_parameters():
239239
qml.RX(np.pi, wires=0),
240240
qml.RY(np.pi, wires=0),
241241
# note the gamma/p ordering doesn't affect the naming of the parameters below.
242-
qml.GeneralizedAmplitudeDamping(gamma=0.1, p=0.9, wires=0),
243-
qml.GeneralizedAmplitudeDamping(p=0.9, gamma=0.1, wires=0),
242+
qml.GeneralizedAmplitudeDamping(gamma=0.1, p=0.2, wires=0),
243+
qml.GeneralizedAmplitudeDamping(p=0.2, gamma=0.1, wires=0),
244244
],
245245
use_unique_params=True,
246246
)
247247
expected = Circuit().h(0).cnot(0, 1).rx(0, FreeParameter("p_0"))
248248
expected = expected.ry(0, FreeParameter("p_1"))
249249

250250
# Right now, the Braket SDK doesn't keep track of noise parameters
251-
expected = expected.generalized_amplitude_damping(0, gamma=0.1, probability=0.9)
252-
expected = expected.generalized_amplitude_damping(0, gamma=0.1, probability=0.9)
251+
expected = expected.generalized_amplitude_damping(0, gamma=0.1, probability=0.8)
252+
expected = expected.generalized_amplitude_damping(0, gamma=0.1, probability=0.8)
253253
assert circuit == expected
254254

255255

test/unit_tests/test_translation.py

Lines changed: 39 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
import pennylane as qml
2121
import pytest
2222
from braket.aws import AwsDevice, AwsDeviceType
23-
from braket.circuits import FreeParameter, gates, noises, observables
23+
from braket.circuits import FreeParameter, Noise, gates, noises, observables
2424
from braket.circuits.result_types import (
2525
AdjointGradient,
2626
DensityMatrix,
@@ -151,12 +151,6 @@ def _aws_device(
151151
(qml.IsingYY, gates.YY, [0, 1], [np.pi]),
152152
(qml.IsingZZ, gates.ZZ, [0, 1], [np.pi]),
153153
(qml.AmplitudeDamping, noises.AmplitudeDamping, [0], [0.1]),
154-
(
155-
qml.GeneralizedAmplitudeDamping,
156-
noises.GeneralizedAmplitudeDamping,
157-
[0],
158-
[0.1, 0.15],
159-
),
160154
(qml.PhaseDamping, noises.PhaseDamping, [0], [0.1]),
161155
(qml.DepolarizingChannel, noises.Depolarizing, [0], [0.1]),
162156
(qml.BitFlip, noises.BitFlip, [0], [0.1]),
@@ -315,14 +309,6 @@ def _aws_device(
315309
["alpha"],
316310
[FreeParameter("alpha")],
317311
),
318-
(
319-
qml.GeneralizedAmplitudeDamping,
320-
noises.GeneralizedAmplitudeDamping,
321-
[0],
322-
[0.1, 0.15],
323-
["p_000", "p_001"],
324-
[FreeParameter("p_000"), FreeParameter("p_001")],
325-
),
326312
(qml.PhaseDamping, noises.PhaseDamping, [0], [0.1], ["a"], [FreeParameter("a")]),
327313
(
328314
qml.DepolarizingChannel,
@@ -386,7 +372,12 @@ def test_translate_operation(pl_cls, braket_cls, qubits, params):
386372
pl_op = pl_cls(*params, wires=qubits)
387373
braket_gate = braket_cls(*params)
388374
assert translate_operation(pl_op) == braket_gate
389-
if isinstance(pl_op, (GPi, GPi2, MS, AAMS, PRx)):
375+
if isinstance(braket_gate, (Noise, gates.Unitary)):
376+
assert (
377+
_braket_to_pl[braket_gate.to_ir(qubits).__class__.__name__.lower().replace("_", "")]
378+
== pl_op.name
379+
)
380+
else:
390381
translated_back = _braket_to_pl[
391382
re.match("^[a-z0-2]+", braket_gate.to_ir(qubits, ir_type=IRType.OPENQASM)).group(0)
392383
]
@@ -397,11 +388,6 @@ def test_translate_operation(pl_cls, braket_cls, qubits, params):
397388
# Braket MS gets translated to PL AAMS.
398389
else translated_back == "AAMS"
399390
)
400-
else:
401-
assert (
402-
_braket_to_pl[braket_gate.to_ir(qubits).__class__.__name__.lower().replace("_", "")]
403-
== pl_op.name
404-
)
405391

406392

407393
@pytest.mark.parametrize(
@@ -418,7 +404,12 @@ def test_translate_operation_with_unique_params(
418404
translate_operation(pl_op, use_unique_params=True, param_names=pl_param_names)
419405
== braket_gate
420406
)
421-
if isinstance(pl_op, (GPi, GPi2, MS, AAMS)):
407+
if isinstance(braket_gate, (Noise, gates.Unitary)):
408+
assert (
409+
_braket_to_pl[braket_gate.to_ir(qubits).__class__.__name__.lower().replace("_", "")]
410+
== pl_op.name
411+
)
412+
else:
422413
translated_back = _braket_to_pl[
423414
re.match("^[a-z0-2]+", braket_gate.to_ir(qubits, ir_type=IRType.OPENQASM)).group(0)
424415
]
@@ -429,11 +420,32 @@ def test_translate_operation_with_unique_params(
429420
# Braket MS gets translated to PL AAMS.
430421
else translated_back == "AAMS"
431422
)
432-
else:
433-
assert (
434-
_braket_to_pl[braket_gate.to_ir(qubits).__class__.__name__.lower().replace("_", "")]
435-
== pl_op.name
436-
)
423+
424+
425+
def test_generalized_amplitude_damping():
426+
"""Tests that GeneralizedAmplitudeDamping is translated correctly"""
427+
qubits = [0]
428+
pl_param_names = ["p_000", "p_001"]
429+
pl_op = qml.GeneralizedAmplitudeDamping(0.1, 0.15, wires=qubits)
430+
braket_op = noises.GeneralizedAmplitudeDamping(0.1, 0.85)
431+
assert translate_operation(pl_op) == braket_op
432+
assert (
433+
_braket_to_pl[braket_op.to_ir(qubits).__class__.__name__.lower().replace("_", "")]
434+
== pl_op.name
435+
)
436+
braket_op_parametrized = noises.GeneralizedAmplitudeDamping(
437+
FreeParameter("p_000"), 1 - FreeParameter("p_001")
438+
)
439+
assert (
440+
translate_operation(pl_op, use_unique_params=True, param_names=pl_param_names)
441+
== braket_op_parametrized
442+
)
443+
assert (
444+
_braket_to_pl[
445+
braket_op_parametrized.to_ir(qubits).__class__.__name__.lower().replace("_", "")
446+
]
447+
== pl_op.name
448+
)
437449

438450

439451
def amplitude(p, t):

0 commit comments

Comments
 (0)