Skip to content

Commit ceceb49

Browse files
ikkohamajavadiat-imamichi
authored
Reference implementation of Estimator and Sampler Primitives (Qiskit#7780)
* Merge pull request Qiskit#7651 from ikkoham/primitives/sampler-and-estimator * rename releasenote * add releasenote * move to qiskit/primitives * fix lint (cyclic import) * Update releasenotes/notes/primitives-fb4515ec0f4cbd8e.yaml Co-authored-by: Ali Javadi-Abhari <[email protected]> * fix sampler with empty parameters Co-authored-by: Ali Javadi-Abhari <[email protected]> Co-authored-by: Takashi Imamichi <[email protected]>
1 parent df2a6eb commit ceceb49

File tree

9 files changed

+947
-5
lines changed

9 files changed

+947
-5
lines changed

qiskit/primitives/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
:toctree: ../stubs/
3030
3131
BaseEstimator
32+
Estimator
3233
3334
Sampler
3435
=======
@@ -37,6 +38,7 @@
3738
:toctree: ../stubs/
3839
3940
BaseSampler
41+
Sampler
4042
4143
Results
4244
=======
@@ -50,5 +52,7 @@
5052

5153
from .base_estimator import BaseEstimator
5254
from .base_sampler import BaseSampler
55+
from .estimator import Estimator
5356
from .estimator_result import EstimatorResult
57+
from .sampler import Sampler
5458
from .sampler_result import SamplerResult

qiskit/primitives/estimator.py

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
# This code is part of Qiskit.
2+
#
3+
# (C) Copyright IBM 2022.
4+
#
5+
# This code is licensed under the Apache License, Version 2.0. You may
6+
# obtain a copy of this license in the LICENSE.txt file in the root directory
7+
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
8+
#
9+
# Any modifications or derivative works of this code must retain this
10+
# copyright notice, and modified files need to carry a notice indicating
11+
# that they have been altered from the originals.
12+
"""
13+
Estimator class
14+
"""
15+
16+
from __future__ import annotations
17+
18+
from collections.abc import Iterable, Sequence
19+
from typing import cast
20+
21+
import numpy as np
22+
23+
from qiskit.circuit import Parameter, QuantumCircuit
24+
from qiskit.exceptions import QiskitError
25+
from qiskit.opflow import PauliSumOp
26+
from qiskit.quantum_info import Statevector
27+
from qiskit.quantum_info.operators.base_operator import BaseOperator
28+
29+
from .base_estimator import BaseEstimator
30+
from .estimator_result import EstimatorResult
31+
from .utils import init_circuit, init_observable
32+
33+
34+
class Estimator(BaseEstimator):
35+
"""
36+
Estimator class
37+
"""
38+
39+
def __init__(
40+
self,
41+
circuits: QuantumCircuit | Iterable[QuantumCircuit],
42+
observables: BaseOperator | PauliSumOp | Iterable[BaseOperator | PauliSumOp],
43+
parameters: Iterable[Iterable[Parameter]] | None = None,
44+
):
45+
if isinstance(circuits, QuantumCircuit):
46+
circuits = [circuits]
47+
circuits = [init_circuit(circuit) for circuit in circuits]
48+
49+
if isinstance(observables, (PauliSumOp, BaseOperator)):
50+
observables = [observables]
51+
observables = [init_observable(observable) for observable in observables]
52+
53+
super().__init__(
54+
circuits=circuits,
55+
observables=observables,
56+
parameters=parameters,
57+
)
58+
self._is_closed = False
59+
60+
def __call__(
61+
self,
62+
circuits: Sequence[int] | None = None,
63+
observables: Sequence[int] | None = None,
64+
parameters: Sequence[Sequence[float]] | Sequence[float] | None = None,
65+
**run_options,
66+
) -> EstimatorResult:
67+
if self._is_closed:
68+
raise QiskitError("The primitive has been closed.")
69+
70+
if parameters and not isinstance(parameters[0], Sequence):
71+
parameters = cast("Sequence[float]", parameters)
72+
parameters = [parameters]
73+
if (
74+
circuits is None
75+
and len(self._circuits) == 1
76+
and observables is None
77+
and len(self._observables) == 1
78+
and parameters is not None
79+
):
80+
circuits = [0] * len(parameters)
81+
observables = [0] * len(parameters)
82+
if circuits is None:
83+
circuits = list(range(len(self._circuits)))
84+
if observables is None:
85+
observables = list(range(len(self._observables)))
86+
if parameters is None:
87+
parameters = [[]] * len(circuits)
88+
if len(circuits) != len(parameters):
89+
raise QiskitError(
90+
f"The number of circuits ({len(circuits)}) does not match "
91+
f"the number of parameter sets ({len(parameters)})."
92+
)
93+
94+
bound_circuits = []
95+
for i, value in zip(circuits, parameters):
96+
if len(value) != len(self._parameters[i]):
97+
raise QiskitError(
98+
f"The number of values ({len(value)}) does not match "
99+
f"the number of parameters ({len(self._parameters[i])})."
100+
)
101+
bound_circuits.append(
102+
self._circuits[i].bind_parameters(dict(zip(self._parameters[i], value)))
103+
)
104+
sorted_observables = [self._observables[i] for i in observables]
105+
expectation_values = []
106+
for circ, obs in zip(bound_circuits, sorted_observables):
107+
if circ.num_qubits != obs.num_qubits:
108+
raise QiskitError(
109+
f"The number of qubits of a circuit ({circ.num_qubits}) does not match "
110+
f"the number of qubits of a observable ({obs.num_qubits})."
111+
)
112+
expectation_values.append(Statevector(circ).expectation_value(obs))
113+
114+
return EstimatorResult(np.real_if_close(expectation_values), [])
115+
116+
def close(self):
117+
self._is_closed = True

qiskit/primitives/sampler.py

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
# This code is part of Qiskit.
2+
#
3+
# (C) Copyright IBM 2022.
4+
#
5+
# This code is licensed under the Apache License, Version 2.0. You may
6+
# obtain a copy of this license in the LICENSE.txt file in the root directory
7+
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
8+
#
9+
# Any modifications or derivative works of this code must retain this
10+
# copyright notice, and modified files need to carry a notice indicating
11+
# that they have been altered from the originals.
12+
"""
13+
Sampler class
14+
"""
15+
16+
from __future__ import annotations
17+
18+
from collections.abc import Iterable, Sequence
19+
20+
from qiskit.circuit import Parameter, QuantumCircuit
21+
from qiskit.exceptions import QiskitError
22+
from qiskit.quantum_info import Statevector
23+
from qiskit.result import QuasiDistribution
24+
25+
from .base_sampler import BaseSampler
26+
from .sampler_result import SamplerResult
27+
from .utils import final_measurement_mapping, init_circuit
28+
29+
30+
class Sampler(BaseSampler):
31+
"""
32+
Sampler class
33+
"""
34+
35+
def __init__(
36+
self,
37+
circuits: QuantumCircuit | Iterable[QuantumCircuit],
38+
parameters: Iterable[Iterable[Parameter]] | None = None,
39+
):
40+
"""
41+
Args:
42+
circuits: circuits to be executed
43+
44+
Raises:
45+
QiskitError: if some classical bits are not used for measurements.
46+
"""
47+
if isinstance(circuits, QuantumCircuit):
48+
circuits = [circuits]
49+
circuits = [init_circuit(circuit) for circuit in circuits]
50+
q_c_mappings = [final_measurement_mapping(circuit) for circuit in circuits]
51+
self._qargs_list = []
52+
for circuit, q_c_mapping in zip(circuits, q_c_mappings):
53+
if set(range(circuit.num_clbits)) != set(q_c_mapping.values()):
54+
raise QiskitError(
55+
"some classical bits are not used for measurements."
56+
f" the number of classical bits {circuit.num_clbits},"
57+
f" the used classical bits {set(q_c_mapping.values())}."
58+
)
59+
c_q_mapping = sorted((c, q) for q, c in q_c_mapping.items())
60+
self._qargs_list.append([q for _, q in c_q_mapping])
61+
circuits = [circuit.remove_final_measurements(inplace=False) for circuit in circuits]
62+
super().__init__(circuits, parameters)
63+
self._is_closed = False
64+
65+
def __call__(
66+
self,
67+
circuits: Sequence[int] | None = None,
68+
parameters: Sequence[Sequence[float]] | None = None,
69+
**run_options,
70+
) -> SamplerResult:
71+
if self._is_closed:
72+
raise QiskitError("The primitive has been closed.")
73+
74+
if circuits is None and parameters is not None and len(self._circuits) == 1:
75+
circuits = [0] * len(parameters)
76+
if circuits is None:
77+
circuits = list(range(len(self._circuits)))
78+
if parameters is None:
79+
parameters = [[]] * len(circuits)
80+
if len(circuits) != len(parameters):
81+
raise QiskitError(
82+
f"The number of circuits ({len(circuits)}) does not match "
83+
f"the number of parameter sets ({len(parameters)})."
84+
)
85+
86+
bound_circuits_qargs = []
87+
for i, value in zip(circuits, parameters):
88+
if len(value) != len(self._parameters[i]):
89+
raise QiskitError(
90+
f"The number of values ({len(value)}) does not match "
91+
f"the number of parameters ({len(self._parameters[i])})."
92+
)
93+
bound_circuits_qargs.append(
94+
(
95+
self._circuits[i].bind_parameters(dict(zip(self._parameters[i], value))),
96+
self._qargs_list[i],
97+
)
98+
)
99+
probabilities = [
100+
Statevector(circ).probabilities(qargs=qargs) for circ, qargs in bound_circuits_qargs
101+
]
102+
quasis = [QuasiDistribution(dict(enumerate(p))) for p in probabilities]
103+
104+
return SamplerResult(quasis, [{}] * len(circuits))
105+
106+
def close(self):
107+
self._is_closed = True

qiskit/primitives/utils.py

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
# This code is part of Qiskit.
2+
#
3+
# (C) Copyright IBM 2022.
4+
#
5+
# This code is licensed under the Apache License, Version 2.0. You may
6+
# obtain a copy of this license in the LICENSE.txt file in the root directory
7+
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
8+
#
9+
# Any modifications or derivative works of this code must retain this
10+
# copyright notice, and modified files need to carry a notice indicating
11+
# that they have been altered from the originals.
12+
"""
13+
Utility functions for primitives
14+
"""
15+
16+
from __future__ import annotations
17+
18+
from qiskit.circuit import ParameterExpression, QuantumCircuit
19+
from qiskit.extensions.quantum_initializer.initializer import Initialize
20+
from qiskit.opflow import PauliSumOp
21+
from qiskit.quantum_info import SparsePauliOp, Statevector
22+
from qiskit.quantum_info.operators.base_operator import BaseOperator
23+
24+
25+
def init_circuit(state: QuantumCircuit | Statevector) -> QuantumCircuit:
26+
"""Initialize state."""
27+
if isinstance(state, QuantumCircuit):
28+
return state
29+
if not isinstance(state, Statevector):
30+
state = Statevector(state)
31+
qc = QuantumCircuit(state.num_qubits)
32+
qc.append(Initialize(state), qargs=range(state.num_qubits))
33+
return qc
34+
35+
36+
def init_observable(observable: BaseOperator | PauliSumOp) -> SparsePauliOp:
37+
"""Initialize observable"""
38+
if isinstance(observable, SparsePauliOp):
39+
return observable
40+
if isinstance(observable, PauliSumOp):
41+
if isinstance(observable.coeff, ParameterExpression):
42+
raise TypeError(
43+
f"observable must have numerical coefficient, not {type(observable.coeff)}"
44+
)
45+
return observable.coeff * observable.primitive
46+
if isinstance(observable, BaseOperator):
47+
return SparsePauliOp.from_operator(observable)
48+
return SparsePauliOp(observable)
49+
50+
51+
def final_measurement_mapping(circuit: QuantumCircuit) -> dict[int, int]:
52+
"""Return the final measurement mapping for the circuit.
53+
54+
Dict keys label measured qubits, whereas the values indicate the
55+
classical bit onto which that qubits measurement result is stored.
56+
57+
Note: this function is a slightly simplified version of a utility function
58+
``_final_measurement_mapping`` of
59+
`mthree <https://github.com/Qiskit-Partners/mthree>`_.
60+
61+
Parameters:
62+
circuit: Input Qiskit QuantumCircuit.
63+
64+
Returns:
65+
Mapping of qubits to classical bits for final measurements.
66+
"""
67+
active_qubits = list(range(circuit.num_qubits))
68+
active_cbits = list(range(circuit.num_clbits))
69+
70+
# Find final measurements starting in back
71+
mapping = {}
72+
for item in circuit._data[::-1]:
73+
if item[0].name == "measure":
74+
cbit = circuit.find_bit(item[2][0]).index
75+
qbit = circuit.find_bit(item[1][0]).index
76+
if cbit in active_cbits and qbit in active_qubits:
77+
mapping[qbit] = cbit
78+
active_cbits.remove(cbit)
79+
active_qubits.remove(qbit)
80+
elif item[0].name != "barrier":
81+
for qq in item[1]:
82+
_temp_qubit = circuit.find_bit(qq).index
83+
if _temp_qubit in active_qubits:
84+
active_qubits.remove(_temp_qubit)
85+
86+
if not active_cbits or not active_qubits:
87+
break
88+
89+
# Sort so that classical bits are in numeric order low->high.
90+
mapping = dict(sorted(mapping.items(), key=lambda item: item[1]))
91+
return mapping

releasenotes/notes/base_primitives-fb4515ec0f4cbd8e.yaml

Lines changed: 0 additions & 5 deletions
This file was deleted.
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
---
2+
features:
3+
- |
4+
Added new abstract base primitive types under :mod:`qiskit.primitives`.
5+
There are two types of primitives, Sampler and Estimator.
6+
The reference implementation can be found in :mod:`qiskit.primitives`.
7+
Other concrete implementations will come from providers.

test/python/primitives/__init__.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# This code is part of Qiskit.
2+
#
3+
# (C) Copyright IBM 2022.
4+
#
5+
# This code is licensed under the Apache License, Version 2.0. You may
6+
# obtain a copy of this license in the LICENSE.txt file in the root directory
7+
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
8+
#
9+
# Any modifications or derivative works of this code must retain this
10+
# copyright notice, and modified files need to carry a notice indicating
11+
# that they have been altered from the originals.
12+
13+
"""Tests for the primitives."""

0 commit comments

Comments
 (0)