|
| 1 | +# This code is part of Qiskit. |
| 2 | +# |
| 3 | +# (C) Copyright IBM 2022. |
| 4 | +# |
| 5 | +# This code is licensed under the Apache License, Version 2.0. You may |
| 6 | +# obtain a copy of this license in the LICENSE.txt file in the root directory |
| 7 | +# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0. |
| 8 | +# |
| 9 | +# Any modifications or derivative works of this code must retain this |
| 10 | +# copyright notice, and modified files need to carry a notice indicating |
| 11 | +# that they have been altered from the originals. |
| 12 | +""" |
| 13 | +Estimator class |
| 14 | +""" |
| 15 | + |
| 16 | +from __future__ import annotations |
| 17 | + |
| 18 | +from collections.abc import Iterable, Sequence |
| 19 | +from typing import cast |
| 20 | + |
| 21 | +import numpy as np |
| 22 | + |
| 23 | +from qiskit.circuit import Parameter, QuantumCircuit |
| 24 | +from qiskit.exceptions import QiskitError |
| 25 | +from qiskit.opflow import PauliSumOp |
| 26 | +from qiskit.quantum_info import Statevector |
| 27 | +from qiskit.quantum_info.operators.base_operator import BaseOperator |
| 28 | + |
| 29 | +from .base_estimator import BaseEstimator |
| 30 | +from .estimator_result import EstimatorResult |
| 31 | +from .utils import init_circuit, init_observable |
| 32 | + |
| 33 | + |
| 34 | +class Estimator(BaseEstimator): |
| 35 | + """ |
| 36 | + Estimator class |
| 37 | + """ |
| 38 | + |
| 39 | + def __init__( |
| 40 | + self, |
| 41 | + circuits: QuantumCircuit | Iterable[QuantumCircuit], |
| 42 | + observables: BaseOperator | PauliSumOp | Iterable[BaseOperator | PauliSumOp], |
| 43 | + parameters: Iterable[Iterable[Parameter]] | None = None, |
| 44 | + ): |
| 45 | + if isinstance(circuits, QuantumCircuit): |
| 46 | + circuits = [circuits] |
| 47 | + circuits = [init_circuit(circuit) for circuit in circuits] |
| 48 | + |
| 49 | + if isinstance(observables, (PauliSumOp, BaseOperator)): |
| 50 | + observables = [observables] |
| 51 | + observables = [init_observable(observable) for observable in observables] |
| 52 | + |
| 53 | + super().__init__( |
| 54 | + circuits=circuits, |
| 55 | + observables=observables, |
| 56 | + parameters=parameters, |
| 57 | + ) |
| 58 | + self._is_closed = False |
| 59 | + |
| 60 | + def __call__( |
| 61 | + self, |
| 62 | + circuits: Sequence[int] | None = None, |
| 63 | + observables: Sequence[int] | None = None, |
| 64 | + parameters: Sequence[Sequence[float]] | Sequence[float] | None = None, |
| 65 | + **run_options, |
| 66 | + ) -> EstimatorResult: |
| 67 | + if self._is_closed: |
| 68 | + raise QiskitError("The primitive has been closed.") |
| 69 | + |
| 70 | + if parameters and not isinstance(parameters[0], Sequence): |
| 71 | + parameters = cast("Sequence[float]", parameters) |
| 72 | + parameters = [parameters] |
| 73 | + if ( |
| 74 | + circuits is None |
| 75 | + and len(self._circuits) == 1 |
| 76 | + and observables is None |
| 77 | + and len(self._observables) == 1 |
| 78 | + and parameters is not None |
| 79 | + ): |
| 80 | + circuits = [0] * len(parameters) |
| 81 | + observables = [0] * len(parameters) |
| 82 | + if circuits is None: |
| 83 | + circuits = list(range(len(self._circuits))) |
| 84 | + if observables is None: |
| 85 | + observables = list(range(len(self._observables))) |
| 86 | + if parameters is None: |
| 87 | + parameters = [[]] * len(circuits) |
| 88 | + if len(circuits) != len(parameters): |
| 89 | + raise QiskitError( |
| 90 | + f"The number of circuits ({len(circuits)}) does not match " |
| 91 | + f"the number of parameter sets ({len(parameters)})." |
| 92 | + ) |
| 93 | + |
| 94 | + bound_circuits = [] |
| 95 | + for i, value in zip(circuits, parameters): |
| 96 | + if len(value) != len(self._parameters[i]): |
| 97 | + raise QiskitError( |
| 98 | + f"The number of values ({len(value)}) does not match " |
| 99 | + f"the number of parameters ({len(self._parameters[i])})." |
| 100 | + ) |
| 101 | + bound_circuits.append( |
| 102 | + self._circuits[i].bind_parameters(dict(zip(self._parameters[i], value))) |
| 103 | + ) |
| 104 | + sorted_observables = [self._observables[i] for i in observables] |
| 105 | + expectation_values = [] |
| 106 | + for circ, obs in zip(bound_circuits, sorted_observables): |
| 107 | + if circ.num_qubits != obs.num_qubits: |
| 108 | + raise QiskitError( |
| 109 | + f"The number of qubits of a circuit ({circ.num_qubits}) does not match " |
| 110 | + f"the number of qubits of a observable ({obs.num_qubits})." |
| 111 | + ) |
| 112 | + expectation_values.append(Statevector(circ).expectation_value(obs)) |
| 113 | + |
| 114 | + return EstimatorResult(np.real_if_close(expectation_values), []) |
| 115 | + |
| 116 | + def close(self): |
| 117 | + self._is_closed = True |
0 commit comments