2020import pennylane as qml
2121import pytest
2222from braket .aws import AwsDevice , AwsDeviceType
23- from braket .circuits import FreeParameter , gates , noises , observables
23+ from braket .circuits import FreeParameter , Noise , gates , noises , observables
2424from braket .circuits .result_types import (
2525 AdjointGradient ,
2626 DensityMatrix ,
@@ -151,12 +151,6 @@ def _aws_device(
151151 (qml .IsingYY , gates .YY , [0 , 1 ], [np .pi ]),
152152 (qml .IsingZZ , gates .ZZ , [0 , 1 ], [np .pi ]),
153153 (qml .AmplitudeDamping , noises .AmplitudeDamping , [0 ], [0.1 ]),
154- (
155- qml .GeneralizedAmplitudeDamping ,
156- noises .GeneralizedAmplitudeDamping ,
157- [0 ],
158- [0.1 , 0.15 ],
159- ),
160154 (qml .PhaseDamping , noises .PhaseDamping , [0 ], [0.1 ]),
161155 (qml .DepolarizingChannel , noises .Depolarizing , [0 ], [0.1 ]),
162156 (qml .BitFlip , noises .BitFlip , [0 ], [0.1 ]),
@@ -315,14 +309,6 @@ def _aws_device(
315309 ["alpha" ],
316310 [FreeParameter ("alpha" )],
317311 ),
318- (
319- qml .GeneralizedAmplitudeDamping ,
320- noises .GeneralizedAmplitudeDamping ,
321- [0 ],
322- [0.1 , 0.15 ],
323- ["p_000" , "p_001" ],
324- [FreeParameter ("p_000" ), FreeParameter ("p_001" )],
325- ),
326312 (qml .PhaseDamping , noises .PhaseDamping , [0 ], [0.1 ], ["a" ], [FreeParameter ("a" )]),
327313 (
328314 qml .DepolarizingChannel ,
@@ -386,7 +372,12 @@ def test_translate_operation(pl_cls, braket_cls, qubits, params):
386372 pl_op = pl_cls (* params , wires = qubits )
387373 braket_gate = braket_cls (* params )
388374 assert translate_operation (pl_op ) == braket_gate
389- if isinstance (pl_op , (GPi , GPi2 , MS , AAMS , PRx )):
375+ if isinstance (braket_gate , (Noise , gates .Unitary )):
376+ assert (
377+ _braket_to_pl [braket_gate .to_ir (qubits ).__class__ .__name__ .lower ().replace ("_" , "" )]
378+ == pl_op .name
379+ )
380+ else :
390381 translated_back = _braket_to_pl [
391382 re .match ("^[a-z0-2]+" , braket_gate .to_ir (qubits , ir_type = IRType .OPENQASM )).group (0 )
392383 ]
@@ -397,11 +388,6 @@ def test_translate_operation(pl_cls, braket_cls, qubits, params):
397388 # Braket MS gets translated to PL AAMS.
398389 else translated_back == "AAMS"
399390 )
400- else :
401- assert (
402- _braket_to_pl [braket_gate .to_ir (qubits ).__class__ .__name__ .lower ().replace ("_" , "" )]
403- == pl_op .name
404- )
405391
406392
407393@pytest .mark .parametrize (
@@ -418,7 +404,12 @@ def test_translate_operation_with_unique_params(
418404 translate_operation (pl_op , use_unique_params = True , param_names = pl_param_names )
419405 == braket_gate
420406 )
421- if isinstance (pl_op , (GPi , GPi2 , MS , AAMS )):
407+ if isinstance (braket_gate , (Noise , gates .Unitary )):
408+ assert (
409+ _braket_to_pl [braket_gate .to_ir (qubits ).__class__ .__name__ .lower ().replace ("_" , "" )]
410+ == pl_op .name
411+ )
412+ else :
422413 translated_back = _braket_to_pl [
423414 re .match ("^[a-z0-2]+" , braket_gate .to_ir (qubits , ir_type = IRType .OPENQASM )).group (0 )
424415 ]
@@ -429,11 +420,32 @@ def test_translate_operation_with_unique_params(
429420 # Braket MS gets translated to PL AAMS.
430421 else translated_back == "AAMS"
431422 )
432- else :
433- assert (
434- _braket_to_pl [braket_gate .to_ir (qubits ).__class__ .__name__ .lower ().replace ("_" , "" )]
435- == pl_op .name
436- )
423+
424+
425+ def test_generalized_amplitude_damping ():
426+ """Tests that GeneralizedAmplitudeDamping is translated correctly"""
427+ qubits = [0 ]
428+ pl_param_names = ["p_000" , "p_001" ]
429+ pl_op = qml .GeneralizedAmplitudeDamping (0.1 , 0.15 , wires = qubits )
430+ braket_op = noises .GeneralizedAmplitudeDamping (0.1 , 0.85 )
431+ assert translate_operation (pl_op ) == braket_op
432+ assert (
433+ _braket_to_pl [braket_op .to_ir (qubits ).__class__ .__name__ .lower ().replace ("_" , "" )]
434+ == pl_op .name
435+ )
436+ braket_op_parametrized = noises .GeneralizedAmplitudeDamping (
437+ FreeParameter ("p_000" ), 1 - FreeParameter ("p_001" )
438+ )
439+ assert (
440+ translate_operation (pl_op , use_unique_params = True , param_names = pl_param_names )
441+ == braket_op_parametrized
442+ )
443+ assert (
444+ _braket_to_pl [
445+ braket_op_parametrized .to_ir (qubits ).__class__ .__name__ .lower ().replace ("_" , "" )
446+ ]
447+ == pl_op .name
448+ )
437449
438450
439451def amplitude (p , t ):
0 commit comments