2323
2424import catalyst
2525from catalyst import qjit
26+ from catalyst .from_plxpr import register_transform
2627
2728pytestmark = pytest .mark .usefixtures ("disable_capture" )
2829
@@ -1048,6 +1049,30 @@ def circuit(x: float):
10481049
10491050 assert jnp .allclose (circuit (0.1 ), capture_result )
10501051
1052+ @pytest .mark .usefixtures ("use_capture" )
1053+ def test_pass_with_options (self , backend ):
1054+ """Test the integration for a circuit with a pass that takes in options."""
1055+
1056+ @qml .transform
1057+ def my_pass (_tape , my_option = None , my_other_option = None ): # pylint: disable=unused-argument
1058+ """A dummy qml.transform."""
1059+ return
1060+
1061+ register_transform (my_pass , "my-pass" , False )
1062+
1063+ @qjit (target = "mlir" )
1064+ @partial (my_pass , my_option = "my_option_value" , my_other_option = False )
1065+ @qml .qnode (qml .device (backend , wires = 1 ))
1066+ def captured_circuit ():
1067+ return qml .expval (qml .PauliZ (0 ))
1068+
1069+ capture_mlir = captured_circuit .mlir
1070+ assert 'transform.apply_registered_pass "my-pass"' in capture_mlir
1071+ assert (
1072+ 'with options = {"my-option" = "my_option_value", "my-other-option" = false}'
1073+ in capture_mlir
1074+ )
1075+
10511076 def test_transform_cancel_inverses_workflow (self , backend ):
10521077 """Test the integration for a circuit with a 'cancel_inverses' transform."""
10531078
@@ -1234,11 +1259,9 @@ def captured_circuit(U: ShapedArray([2, 2], float)):
12341259
12351260 # Catalyst 'cancel_inverses' should have been scheduled as a pass
12361261 # whereas PL 'unitary_to_rot' should have been expanded
1237- assert (
1238- 'transform.apply_registered_pass "remove-chained-self-inverse"'
1239- in captured_inverses_unitary .mlir
1240- )
1241- assert is_unitary_rotated (captured_inverses_unitary .mlir )
1262+ capture_mlir = captured_inverses_unitary .mlir
1263+ assert 'transform.apply_registered_pass "remove-chained-self-inverse"' in capture_mlir
1264+ assert is_unitary_rotated (capture_mlir )
12421265
12431266 # Case 2: During plxpr interpretation, first comes the PL transform
12441267 # without Catalyst counterpart, second comes the PL transform with it
@@ -1251,12 +1274,10 @@ def captured_circuit(U: ShapedArray([2, 2], float)):
12511274
12521275 # Both PL transforms should have been expaned and no Catalyst pass should have been
12531276 # scheduled
1254- assert (
1255- 'transform.apply_registered_pass "remove-chained-self-inverse"'
1256- not in captured_unitary_inverses .mlir
1257- )
1258- assert 'quantum.custom "Hadamard"' not in captured_unitary_inverses .mlir
1259- assert is_unitary_rotated (captured_unitary_inverses .mlir )
1277+ capture_mlir = captured_unitary_inverses .mlir
1278+ assert 'transform.apply_registered_pass "remove-chained-self-inverse"' not in capture_mlir
1279+ assert 'quantum.custom "Hadamard"' not in capture_mlir
1280+ assert is_unitary_rotated (capture_mlir )
12601281
12611282 qml .capture .disable ()
12621283
@@ -1446,11 +1467,12 @@ def captured_circuit():
14461467
14471468 capture_result = captured_circuit ()
14481469
1470+ capture_mlir = captured_circuit .mlir
14491471 assert is_controlled_pushed_back (
1450- captured_circuit . mlir , 'quantum.custom "RX"' , 'quantum.custom "CNOT"'
1472+ capture_mlir , 'quantum.custom "RX"' , 'quantum.custom "CNOT"'
14511473 )
14521474 assert is_controlled_pushed_back (
1453- captured_circuit . mlir , 'quantum.custom "PauliX"' , 'quantum.custom "CRX"'
1475+ capture_mlir , 'quantum.custom "PauliX"' , 'quantum.custom "CRX"'
14541476 )
14551477
14561478 qml .capture .disable ()
0 commit comments