Skip to content

Commit 7747325

Browse files
committed
Enhance test files for the random forest module
1 parent d59bb0f commit 7747325

File tree

1 file changed

+12
-2
lines changed

1 file changed

+12
-2
lines changed

tests/test_random_forest.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import unittest
22
import random
3+
import warnings
34
import numpy as np
45
import networkx as nx
56
import treelib
@@ -270,7 +271,13 @@ def test_majority_voting(self):
270271

271272
# Test tie scenarios (should default to +1)
272273
votes = np.array([[1, -1], [-1, 1]])
273-
result = majority_voting(votes)
274+
with warnings.catch_warnings(record=True) as w:
275+
warnings.simplefilter("always")
276+
result = majority_voting(votes)
277+
# Check that a warning was issued
278+
self.assertEqual(len(w), 1)
279+
self.assertTrue(issubclass(w[0].category, UserWarning))
280+
self.assertIn("Zero elements encountered", str(w[0].message))
274281
expected = np.array([1, 1]) # Ties resolved to +1
275282
np.testing.assert_array_equal(result, expected)
276283

@@ -346,7 +353,10 @@ def test_generate_random_forest_large_qubits_no_save(self):
346353
samples = [np.random.rand(2**num_qubits) for _ in range(num_qubits+1)]
347354

348355
# This should work without trying to create visualizations
349-
result = generate_random_forest(num_qubits, num_trees, samples, save_tree=True)
356+
# May generate warnings from majority voting due to random data
357+
with warnings.catch_warnings():
358+
warnings.simplefilter("ignore", UserWarning)
359+
result = generate_random_forest(num_qubits, num_trees, samples, save_tree=True)
350360

351361
self.assertEqual(result.shape, (2**num_qubits,))
352362
self.assertTrue(np.all(np.isin(result, [-1, 1])))

0 commit comments

Comments
 (0)