Skip to content

Commit 3abae5c

Browse files
Make assign_states public, faster and remove a bug (#7880)
**Context:** In `pennylane.templates.state_preparations.superposition.py::_assign_states`, some computational basis states are mapped to new computational basis states. It is claimed that states that are both in the set of inputs and the set of outputs are fixed points of the produced map, but this is not true. `_assign_states` is used in `qualtran_io.py` but is a private function, so[ it might be better to make it public.](https://github.com/PennyLaneAI/pennylane/pull/7866/files#r2205072082) **Description of the Change:** - Make the mentioned states fixed points of the produced map - Speed up the function by looping over the set of states fewer times - Make `assign_states` public-facing. **Benefits:** Better import hygiene, faster code, one bug less. **Possible Drawbacks:** N/A **Related GitHub Issues:** --------- Co-authored-by: Yushao Chen (Jerry) <[email protected]>
1 parent bb628a2 commit 3abae5c

File tree

5 files changed

+116
-62
lines changed

5 files changed

+116
-62
lines changed

doc/releases/changelog-dev.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,10 @@
6464

6565
<h3>Improvements 🛠</h3>
6666

67+
* Changed how basis states are assigned internally in `qml.Superposition`, improving its
68+
decomposition slightly both regarding classical computing time and gate decomposition.
69+
[(#7880)](https://github.com/PennyLaneAI/pennylane/pull/7880)
70+
6771
* The printing and drawing of :class:`~.TemporaryAND`, also known as ``qml.Elbow``, and its adjoint
6872
have been improved to be more legible and consistent with how it's depicted in circuits in the literature.
6973
[(#8017)](https://github.com/PennyLaneAI/pennylane/pull/8017)

pennylane/io/qualtran_io.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
from pennylane.queuing import AnnotatedQueue, QueuingManager
3333
from pennylane.registers import registers
3434
from pennylane.tape import make_qscript
35-
from pennylane.templates.state_preparations.superposition import _assign_states
35+
from pennylane.templates.state_preparations.superposition import order_states
3636
from pennylane.wires import WiresLike
3737
from pennylane.workflow import construct_tape
3838
from pennylane.workflow.qnode import QNode
@@ -124,7 +124,7 @@ def _(op: qtemps.state_preparations.Superposition):
124124
size_basis_state = len(bases[0]) # assuming they are all the same size
125125

126126
dic_state = dict(zip(bases, coeffs))
127-
perms = _assign_states(bases)
127+
perms = order_states(bases)
128128
new_dic_state = {perms[key]: val for key, val in dic_state.items() if key in perms}
129129

130130
sorted_coefficients = [
@@ -387,7 +387,6 @@ def _(op: qtemps.subroutines.ModExp):
387387
num_work_wires = len(op.hyperparameters["work_wires"])
388388
num_x_wires = len(op.hyperparameters["x_wires"])
389389

390-
mult_resources = {}
391390
if mod == 2**num_x_wires:
392391
num_aux_wires = num_x_wires
393392
num_aux_swap = num_x_wires
@@ -407,12 +406,13 @@ def _(op: qtemps.subroutines.ModExp):
407406

408407
cnot = qt_gates.CNOT()
409408

410-
mult_resources = {}
411-
mult_resources[qft] = 2
412-
mult_resources[qft_dag] = 2
413-
mult_resources[sequence] = 1
414-
mult_resources[sequence_dag] = 1
415-
mult_resources[cnot] = min(num_x_wires, num_aux_swap)
409+
mult_resources = {
410+
qft: 2,
411+
qft_dag: 2,
412+
sequence: 1,
413+
sequence_dag: 1,
414+
cnot: min(num_x_wires, num_aux_swap),
415+
}
416416

417417
gate_types = defaultdict(int, {})
418418
ctrl_spec = CtrlSpec(cvs=[1])

pennylane/templates/state_preparations/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,6 @@
2020
from .basis_qutrit import QutritBasisStatePreparation
2121
from .cosine_window import CosineWindow
2222
from .mottonen import MottonenStatePreparation
23-
from .superposition import Superposition
23+
from .superposition import Superposition, order_states
2424
from .qrom_state_prep import QROMStatePreparation
2525
from .state_prep_mps import MPSPrep, right_canonicalize_mps

pennylane/templates/state_preparations/superposition.py

Lines changed: 48 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -14,81 +14,77 @@
1414
r"""
1515
Contains the Superposition template.
1616
"""
17-
1817
import pennylane as qml
1918
from pennylane.operation import Operation
2019

2120

22-
def _assign_states(basis_list):
21+
def order_states(basis_states: list[list[int]]) -> dict[tuple[int], tuple[int]]:
2322
r"""
24-
This function maps a given list of :math:`m` basis states to the first :math:`m` basis states in the
25-
computational basis.
26-
27-
For instance, a given list of :math:`[s_0, s_1, ..., s_m]` where :math:`s` is a basis
28-
state of length :math:`4` will be mapped as :math:`{s_0: |0000\rangle, s_1: |0001\rangle, s_2: |0010\rangle, \dots}`.
29-
30-
Note that if a state in ``basis_list`` is one of the first :math:`m` basis states,
31-
this state will be mapped to itself.
23+
This function maps a given list of :math:`m` computational basis states to the first
24+
:math:`m` computational basis states, except for input states that are among the first
25+
:math:`m` computational basis states, which are mapped to themselves.
3226
3327
Args:
34-
basis_list (list): list of basis states to be mapped
28+
basis_states (list[list[int]]): sequence of :math:`m` basis states to be mapped.
29+
Each state is a sequence of 0s and 1s.
3530
3631
Returns:
37-
dict: dictionary mapping basis states to the first :math:`m` basis states
32+
dict[tuple[int], tuple[int]]: dictionary mapping basis states to the first :math:`m` basis
33+
states, except for fixed points (states in the input that already were among the
34+
first :math:`m` basis states).
3835
36+
**Example**
3937
40-
** Example **
38+
For instance, a given list of :math:`[s_0, s_1, ..., s_m]` where :math:`s` is a basis
39+
state of length :math:`4` will be mapped as
40+
:math:`\{s_0: |0000\rangle, s_1: |0001\rangle, s_2: |0010\rangle, \dots\}`.
4141
4242
.. code-block:: pycon
4343
44-
>>> basis_list = [[1, 1, 0, 0], [1, 0, 1, 0], [0, 1, 0, 1], [1, 0, 0, 1]]
45-
>>> _assign_states(basis_list)
46-
{
47-
[1, 1, 0, 0]: [0, 0, 0, 0],
48-
[1, 0, 1, 0]: [0, 0, 0, 1],
49-
[0, 1, 0, 1]: [0, 0, 1, 0],
50-
[1, 0, 0, 1]: [0, 0, 1, 1]
51-
}
44+
>>> basis_states = [[1, 1, 0, 0], [1, 0, 1, 0], [0, 1, 0, 1], [1, 0, 0, 1]]
45+
>>> order_states(basis_states)
46+
{(1, 1, 0, 0): (0, 0, 0, 0),
47+
(1, 0, 1, 0): (0, 0, 0, 1),
48+
(0, 1, 0, 1): (0, 0, 1, 0),
49+
(1, 0, 0, 1): (0, 0, 1, 1)}
5250
51+
If a state in ``basis_states`` is one of the first :math:`m` basis states,
52+
this state will be mapped to itself, i.e. it will be a fixed point of the mapping.
5353
5454
.. code-block:: pycon
5555
56-
>>> basis_list = [[1, 1, 0, 0], [0, 1, 0, 1], [0, 0, 0, 1], [1, 0, 0, 1]]
57-
>>> _assign_states(basis_list)
58-
{
59-
[1, 1, 0, 0]: [0, 0, 0, 0],
60-
[0, 1, 0, 1]: [0, 0, 1, 0],
61-
[0, 0, 0, 1]: [0, 0, 0, 1],
62-
[1, 0, 0, 1]: [0, 0, 1, 1]
63-
}
56+
>>> basis_states = [[1, 1, 0, 0], [0, 1, 0, 1], [0, 0, 0, 1], [1, 0, 0, 1]]
57+
>>> order_states(basis_states)
58+
{(0, 0, 0, 1): (0, 0, 0, 1),
59+
(1, 1, 0, 0): (0, 0, 0, 0),
60+
(0, 1, 0, 1): (0, 0, 1, 0),
61+
(1, 0, 0, 1): (0, 0, 1, 1)}
6462
6563
"""
6664

67-
length = len(basis_list[0])
68-
smallest_basis_lists = [tuple(map(int, f"{i:0{length}b}")) for i in range(len(basis_list))]
69-
70-
binary_dict = {}
71-
used_smallest = set()
72-
73-
# Assign keys that can map to themselves
74-
for original in basis_list:
75-
76-
if original in smallest_basis_lists and tuple(original) not in used_smallest:
77-
78-
binary_dict[tuple(original)] = original
79-
used_smallest.add(tuple(original))
65+
m = len(basis_states)
66+
length = len(basis_states[0])
67+
# Create the integers corresponding to the input basis states
68+
basis_ints = [int("".join(map(str, state)), 2) for state in basis_states]
8069

81-
# Assign remaining keys to unused binary lists
82-
remaining_keys = [key for key in basis_list if tuple(key) not in binary_dict]
83-
remaining_values = [
84-
value for value in smallest_basis_lists if tuple(value) not in used_smallest
85-
]
70+
basis_states = [tuple(s) for s in basis_states] # Need hashable objects, so we use tuples
71+
state_map = {} # The map for basis states to be populated
72+
unmapped_states = [] # Will collect non-fixed point states
73+
unmapped_ints = {i: None for i in range(m)} # Will remove fixed point states
74+
# Map fixed-point states to themselves and collect states and target ints still to be paired
75+
for b_int, state in zip(basis_ints, basis_states):
76+
if b_int < m:
77+
state_map[state] = state
78+
unmapped_ints.pop(b_int)
79+
else:
80+
unmapped_states.append(state)
8681

87-
for key, value in zip(remaining_keys, remaining_values):
88-
binary_dict[tuple(key)] = value
89-
used_smallest.add(tuple(value))
82+
# Map non-fixed point states
83+
for state, new_b_int in zip(unmapped_states, unmapped_ints):
84+
# Convert the index of the state to be mapped into a state itself
85+
state_map[state] = tuple(map(int, f"{new_b_int:0{length}b}"))
9086

91-
return binary_dict
87+
return state_map
9288

9389

9490
def _permutation_operator(basis1, basis2, wires, work_wire):
@@ -292,7 +288,7 @@ def compute_decomposition(coeffs, bases, wires, work_wire): # pylint: disable=a
292288
"""
293289

294290
dic_state = dict(zip(bases, coeffs))
295-
perms = _assign_states(bases)
291+
perms = order_states(bases)
296292
new_dic_state = {perms[key]: dic_state[key] for key in dic_state if key in perms}
297293

298294
sorted_coefficients = [

tests/templates/test_state_preparations/test_superposition.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,60 @@
2222

2323
import pennylane as qml
2424
from pennylane import numpy as pnp
25+
from pennylane.templates.state_preparations.superposition import order_states
26+
27+
28+
def int_to_state(i, length):
29+
return tuple(map(int, f"{i:0{length}b}"))
30+
31+
32+
@pytest.mark.parametrize(
33+
"basis_states, exp_map",
34+
(
35+
[ # Examples where all basis states are fixed points
36+
(
37+
[int_to_state(i, L) for i in range(m)],
38+
{int_to_state(i, L): int_to_state(i, L) for i in range(m)},
39+
)
40+
for L, m in [(1, 2), (2, 4), (2, 3), (3, 7), (4, 3), (4, 16)]
41+
]
42+
+ [ # Examples from docstring
43+
(
44+
[[1, 1, 0, 0], [1, 0, 1, 0], [0, 1, 0, 1], [1, 0, 0, 1]],
45+
{
46+
(1, 1, 0, 0): (0, 0, 0, 0),
47+
(1, 0, 1, 0): (0, 0, 0, 1),
48+
(0, 1, 0, 1): (0, 0, 1, 0),
49+
(1, 0, 0, 1): (0, 0, 1, 1),
50+
},
51+
),
52+
(
53+
[[1, 1, 0, 0], [0, 1, 0, 1], [0, 0, 0, 1], [1, 0, 0, 1]],
54+
{
55+
(0, 0, 0, 1): (0, 0, 0, 1),
56+
(1, 1, 0, 0): (0, 0, 0, 0),
57+
(0, 1, 0, 1): (0, 0, 1, 0),
58+
(1, 0, 0, 1): (0, 0, 1, 1),
59+
},
60+
),
61+
]
62+
+ [ # Other examples
63+
(
64+
[[1, 1, 0, 1], [0, 1, 0, 0], [1, 1, 1, 1], [0, 0, 1, 0], [0, 0, 0, 0]],
65+
{
66+
(0, 0, 0, 0): (0, 0, 0, 0),
67+
(0, 0, 1, 0): (0, 0, 1, 0),
68+
(0, 1, 0, 0): (0, 1, 0, 0),
69+
(1, 1, 0, 1): (0, 0, 0, 1),
70+
(1, 1, 1, 1): (0, 0, 1, 1),
71+
},
72+
),
73+
([[1, 1, 0], [0, 0, 1]], {(0, 0, 1): (0, 0, 1), (1, 1, 0): (0, 0, 0)}),
74+
]
75+
),
76+
)
77+
def test_order_states(basis_states, exp_map):
78+
assert order_states(basis_states) == exp_map
2579

2680

2781
def test_standard_validity():

0 commit comments

Comments
 (0)