|
1 | 1 | import unittest |
2 | 2 | import random |
| 3 | +import warnings |
3 | 4 | import numpy as np |
4 | 5 | import networkx as nx |
5 | 6 | import treelib |
@@ -270,7 +271,13 @@ def test_majority_voting(self): |
270 | 271 |
|
271 | 272 | # Test tie scenarios (should default to +1) |
272 | 273 | votes = np.array([[1, -1], [-1, 1]]) |
273 | | - result = majority_voting(votes) |
| 274 | + with warnings.catch_warnings(record=True) as w: |
| 275 | + warnings.simplefilter("always") |
| 276 | + result = majority_voting(votes) |
| 277 | + # Check that a warning was issued |
| 278 | + self.assertEqual(len(w), 1) |
| 279 | + self.assertTrue(issubclass(w[0].category, UserWarning)) |
| 280 | + self.assertIn("Zero elements encountered", str(w[0].message)) |
274 | 281 | expected = np.array([1, 1]) # Ties resolved to +1 |
275 | 282 | np.testing.assert_array_equal(result, expected) |
276 | 283 |
|
@@ -346,7 +353,10 @@ def test_generate_random_forest_large_qubits_no_save(self): |
346 | 353 | samples = [np.random.rand(2**num_qubits) for _ in range(num_qubits+1)] |
347 | 354 |
|
348 | 355 | # This should work without trying to create visualizations |
349 | | - result = generate_random_forest(num_qubits, num_trees, samples, save_tree=True) |
| 356 | + # May generate warnings from majority voting due to random data |
| 357 | + with warnings.catch_warnings(): |
| 358 | + warnings.simplefilter("ignore", UserWarning) |
| 359 | + result = generate_random_forest(num_qubits, num_trees, samples, save_tree=True) |
350 | 360 |
|
351 | 361 | self.assertEqual(result.shape, (2**num_qubits,)) |
352 | 362 | self.assertTrue(np.all(np.isin(result, [-1, 1]))) |
|
0 commit comments