Skip to content

Commit d59bb0f

Browse files
committed
Remove unused get_samples() function for noiseless simulation
This can be replaced by using the get_samples_noisy() and a noiseless aer simulator.
1 parent 0d5a273 commit d59bb0f

File tree

4 files changed

+6
-60
lines changed

4 files changed

+6
-60
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,5 +178,6 @@ cython_debug/
178178
.pypirc
179179

180180
# Ignore everything under tutorial/forest gallery
181+
/forest\ gallery
181182
/tutorials/forest\ gallery/
182183
CLAUDE.md

hadamard_random_forest/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
from .sample import (
2525
get_circuits,
2626
get_circuits_hardware,
27-
get_samples,
2827
get_samples_noisy,
2928
get_samples_hardware,
3029
get_statevector
@@ -48,7 +47,6 @@
4847
"generate_random_forest",
4948
"get_circuits",
5049
"get_circuits_hardware",
51-
"get_samples",
5250
"get_samples_noisy",
5351
"get_samples_hardware",
5452
"get_statevector",

hadamard_random_forest/sample.py

Lines changed: 0 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
from qiskit.providers import Backend
1515
from qiskit.transpiler import generate_preset_pass_manager
1616
from qiskit_ibm_runtime import SamplerV2 as Sampler
17-
from qiskit_aer.primitives import Sampler as Aer_Sampler
1817
from mthree import M3Mitigation
1918
import mthree.utils as mthree_utils
2019

@@ -47,40 +46,6 @@ def get_circuits(
4746
return circuits
4847

4948

50-
def get_samples(
51-
num_qubits: int,
52-
sampler: Aer_Sampler | Sampler,
53-
circuits: List[qiskit.QuantumCircuit],
54-
parameters: np.ndarray
55-
) -> List[np.ndarray]:
56-
"""
57-
Execute circuits and collect probability distributions using a noiseless sampler.
58-
59-
Args:
60-
num_qubits: Number of qubits (defines statevector size 2**num_qubits).
61-
sampler: Sampler object providing run().result().quasi_dists.
62-
circuits: List of QuantumCircuit to execute.
63-
parameters: 1D array of parameter values to bind to each circuit.
64-
65-
Returns:
66-
List of 1D numpy arrays of length 2**num_qubits representing probabilities.
67-
"""
68-
n = len(circuits)
69-
if isinstance(sampler, Aer_Sampler):
70-
results = sampler.run(circuits, [parameters] * n).result().quasi_dists
71-
elif isinstance(sampler, Sampler):
72-
results = sampler.run([(qc, parameters) for qc in circuits]).result()[0].data.meas.get_counts()
73-
else:
74-
raise ValueError("Sampler must be of type qiskit_aer.primitives.Sampler or qiskit_ibm_runtime.SamplerV2.")
75-
76-
samples: List[np.ndarray] = []
77-
for res in results:
78-
proba = np.zeros(2**num_qubits, dtype=float)
79-
for idx, val in res.items():
80-
proba[idx] = val
81-
samples.append(proba)
82-
return samples
83-
8449

8550
def get_samples_noisy(
8651
num_qubits: int,

tests/test_sample.py

Lines changed: 5 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
import unittest
22
import numpy as np
33
from qiskit.circuit.random import random_circuit
4-
from qiskit_aer.primitives import Sampler as Aer_Sampler
4+
from qiskit_aer import AerSimulator
55
from qiskit.circuit.library import real_amplitudes
66
from hadamard_random_forest.sample import (
77
get_statevector,
88
get_circuits,
9-
get_samples
9+
get_samples_noisy
1010
)
1111

1212
class TestSample(unittest.TestCase):
@@ -26,25 +26,6 @@ def test_get_circuits(self):
2626
self.assertEqual(circuit.num_qubits, num_qubits)
2727
self.assertIsNotNone(circuit)
2828

29-
def test_get_samples(self):
30-
"""Test the get_samples function."""
31-
num_qubits = 3
32-
sampler = Aer_Sampler()
33-
base_circuit = real_amplitudes(num_qubits)
34-
circuits = get_circuits(num_qubits, base_circuit)
35-
parameters = np.random.rand(base_circuit.num_parameters)
36-
samples = get_samples(num_qubits, sampler, circuits, parameters)
37-
38-
# Original tests
39-
self.assertIsInstance(samples, list)
40-
self.assertEqual(len(samples), 4)
41-
42-
# Enhanced tests
43-
for sample in samples:
44-
self.assertIsInstance(sample, np.ndarray)
45-
self.assertEqual(sample.shape, (8,)) # 2^3 = 8
46-
self.assertTrue(np.all(sample >= 0)) # Non-negative probabilities
47-
self.assertAlmostEqual(np.sum(sample), 1.0, places=10) # Normalized
4829

4930
def test_get_statevector(self):
5031
"""Test the get_statevector function."""
@@ -69,13 +50,14 @@ def test_integration(self):
6950
"""Test the complete workflow integration."""
7051
num_qubits = 2 # Smaller for faster testing
7152
num_trees = 3
72-
sampler = Aer_Sampler()
53+
backend_sim = AerSimulator()
7354
base_circuit = real_amplitudes(num_qubits)
7455
parameters = np.random.rand(base_circuit.num_parameters)
56+
shots = 1024
7557

7658
# Complete workflow
7759
circuits = get_circuits(num_qubits, base_circuit)
78-
samples = get_samples(num_qubits, sampler, circuits, parameters)
60+
samples = get_samples_noisy(num_qubits, circuits, shots, parameters, backend_sim, error_mitigation=False)
7961
statevector = get_statevector(num_qubits, num_trees, samples, save_tree=False)
8062

8163
# Validate end-to-end

0 commit comments

Comments
 (0)