|
14 | 14 | ) |
15 | 15 |
|
16 | 16 | import warnings |
| 17 | +from typing import Any, Callable, Dict, List, Tuple, Union |
| 18 | +from unittest.mock import MagicMock |
17 | 19 |
|
18 | 20 | import balance.testutil |
19 | 21 | import numpy as np |
@@ -1342,16 +1344,25 @@ def test_cbps_over_method_with_extreme_data_logs_warning(self) -> None: |
1342 | 1344 |
|
1343 | 1345 | # CBPS handles extreme cases by logging a warning about identical weights |
1344 | 1346 | # rather than raising an exception |
1345 | | - self.assertWarnsRegexp( |
1346 | | - "All weights are identical", |
1347 | | - balance_cbps.cbps, |
1348 | | - sample_df, |
1349 | | - sample_weights, |
1350 | | - target_df, |
1351 | | - target_weights, |
1352 | | - transformations=None, |
1353 | | - cbps_method="exact", |
1354 | | - ) |
| 1347 | + # Suppress PerfectSeparationWarning as it's expected with this extreme test data |
| 1348 | + with warnings.catch_warnings(): |
| 1349 | + warnings.filterwarnings( |
| 1350 | + "ignore", |
| 1351 | + message="Perfect separation or prediction detected", |
| 1352 | + category=PerfectSeparationWarning, |
| 1353 | + ) |
| 1354 | + self.assertWarnsRegexp( |
| 1355 | + "All weights are identical", |
| 1356 | + balance_cbps.cbps, |
| 1357 | + sample_df, |
| 1358 | + sample_weights, |
| 1359 | + target_df, |
| 1360 | + target_weights, |
| 1361 | + transformations=None, |
| 1362 | + cbps_method="exact", |
| 1363 | + ) |
| 1364 | + |
| 1365 | + def test_cbps_over_method_logs_warnings(self) -> None: |
1355 | 1366 | """Test CBPS over method logs warnings when optimization fails (lines 713, 747, 765). |
1356 | 1367 |
|
1357 | 1368 | Verifies that when optimization algorithms fail to converge, appropriate |
@@ -1381,38 +1392,45 @@ def test_cbps_over_method_with_extreme_data_logs_warning(self) -> None: |
1381 | 1392 | # Run with over method to exercise gmm optimization paths |
1382 | 1393 | # Use very tight opt_opts to force convergence failure |
1383 | 1394 | # We expect either warnings to be logged or an exception to be raised |
1384 | | - try: |
1385 | | - with self.assertLogs(level=logging.WARNING) as log_context: |
1386 | | - balance_cbps.cbps( |
1387 | | - sample_df, |
1388 | | - sample_weights, |
1389 | | - target_df, |
1390 | | - target_weights, |
1391 | | - transformations=None, |
1392 | | - cbps_method="over", |
1393 | | - opt_opts={"maxiter": 1}, # Force convergence failure |
1394 | | - ) |
1395 | | - # Verify that at least one warning was logged |
1396 | | - self.assertTrue( |
1397 | | - len(log_context.records) > 0, |
1398 | | - msg="Expected warning logs when optimization fails to converge", |
1399 | | - ) |
1400 | | - except Exception as e: |
1401 | | - # If an exception is raised, verify it contains relevant error info |
1402 | | - error_msg = str(e).lower() |
1403 | | - self.assertTrue( |
1404 | | - any( |
1405 | | - keyword in error_msg |
1406 | | - for keyword in [ |
1407 | | - "converge", |
1408 | | - "constraint", |
1409 | | - "singular", |
1410 | | - "optimization", |
1411 | | - "failed", |
1412 | | - ] |
1413 | | - ), |
1414 | | - msg=f"Expected exception to contain convergence-related message, got: {e}", |
| 1395 | + # Suppress PerfectSeparationWarning as it's expected with this extreme test data |
| 1396 | + with warnings.catch_warnings(): |
| 1397 | + warnings.filterwarnings( |
| 1398 | + "ignore", |
| 1399 | + message="Perfect separation or prediction detected", |
| 1400 | + category=PerfectSeparationWarning, |
1415 | 1401 | ) |
| 1402 | + try: |
| 1403 | + with self.assertLogs(level=logging.WARNING) as log_context: |
| 1404 | + balance_cbps.cbps( |
| 1405 | + sample_df, |
| 1406 | + sample_weights, |
| 1407 | + target_df, |
| 1408 | + target_weights, |
| 1409 | + transformations=None, |
| 1410 | + cbps_method="over", |
| 1411 | + opt_opts={"maxiter": 1}, # Force convergence failure |
| 1412 | + ) |
| 1413 | + # Verify that at least one warning was logged |
| 1414 | + self.assertTrue( |
| 1415 | + len(log_context.records) > 0, |
| 1416 | + msg="Expected warning logs when optimization fails to converge", |
| 1417 | + ) |
| 1418 | + except Exception as e: |
| 1419 | + # If an exception is raised, verify it contains relevant error info |
| 1420 | + error_msg = str(e).lower() |
| 1421 | + self.assertTrue( |
| 1422 | + any( |
| 1423 | + keyword in error_msg |
| 1424 | + for keyword in [ |
| 1425 | + "converge", |
| 1426 | + "constraint", |
| 1427 | + "singular", |
| 1428 | + "optimization", |
| 1429 | + "failed", |
| 1430 | + ] |
| 1431 | + ), |
| 1432 | + msg=f"Expected exception to contain convergence-related message, got: {e}", |
| 1433 | + ) |
1416 | 1434 |
|
1417 | 1435 | def test_cbps_alpha_function_convergence_warning(self) -> None: |
1418 | 1436 | """Test CBPS logs warning when alpha_function fails to converge (line 689). |
@@ -1477,3 +1495,136 @@ def test_cbps_alpha_function_convergence_warning(self) -> None: |
1477 | 1495 | ), |
1478 | 1496 | msg=f"Expected exception to contain convergence-related message, got: {e}", |
1479 | 1497 | ) |
| 1498 | + |
| 1499 | + |
| 1500 | +class TestCbpsOptimizationConvergenceWithMocking(balance.testutil.BalanceTestCase): |
| 1501 | + """Test CBPS optimization convergence warning branches using mocking (lines 689, 713, 726, 747, 765, 778). |
| 1502 | +
|
| 1503 | + These tests use unittest.mock to directly control scipy.optimize.minimize return values, |
| 1504 | + ensuring the specific warning and exception branches in _cbps_optimization are executed. |
| 1505 | + """ |
| 1506 | + |
| 1507 | + def _create_simple_test_data( |
| 1508 | + self, |
| 1509 | + ) -> Tuple[pd.DataFrame, pd.Series, pd.DataFrame, pd.Series]: |
| 1510 | + """Create simple test data for CBPS testing.""" |
| 1511 | + sample_df = pd.DataFrame({"a": [1.0, 2.0, 3.0, 4.0, 5.0]}) |
| 1512 | + target_df = pd.DataFrame({"a": [2.0, 3.0, 4.0, 5.0, 6.0]}) |
| 1513 | + sample_weights = pd.Series([1.0] * 5) |
| 1514 | + target_weights = pd.Series([1.0] * 5) |
| 1515 | + return sample_df, sample_weights, target_df, target_weights |
| 1516 | + |
| 1517 | + def test_exact_method_constraint_violation_exception(self) -> None: |
| 1518 | + """Test line 726: Exception when exact method constraints can't be satisfied. |
| 1519 | +
|
| 1520 | + Uses mocking to simulate scipy.optimize.minimize returning success=False |
| 1521 | + with a specific constraint violation message. |
| 1522 | + """ |
| 1523 | + from unittest.mock import patch |
| 1524 | + |
| 1525 | + sample_df, sample_weights, target_df, target_weights = ( |
| 1526 | + self._create_simple_test_data() |
| 1527 | + ) |
| 1528 | + |
| 1529 | + def mock_minimize(fun: Callable[..., Any], x0: Any, **kwargs: Any) -> MagicMock: |
| 1530 | + result = MagicMock() |
| 1531 | + # Simulate constraint violation failure |
| 1532 | + result.__getitem__ = lambda _, key: { |
| 1533 | + "success": np.bool_(False), |
| 1534 | + "message": "Did not converge to a solution satisfying the constraints", |
| 1535 | + "x": x0, |
| 1536 | + "fun": 100.0, |
| 1537 | + }[key] |
| 1538 | + return result |
| 1539 | + |
| 1540 | + def mock_minimize_scalar( |
| 1541 | + fun: Callable[..., Any], **kwargs: Any |
| 1542 | + ) -> Dict[str, Union[np.bool_, np.ndarray, str]]: |
| 1543 | + return { |
| 1544 | + "success": np.bool_(True), |
| 1545 | + "message": "Success", |
| 1546 | + "x": np.array([1.0]), |
| 1547 | + } |
| 1548 | + |
| 1549 | + with ( |
| 1550 | + patch("scipy.optimize.minimize_scalar", side_effect=mock_minimize_scalar), |
| 1551 | + patch("scipy.optimize.minimize", side_effect=mock_minimize), |
| 1552 | + ): |
| 1553 | + with self.assertRaises(Exception) as context: |
| 1554 | + balance_cbps.cbps( |
| 1555 | + sample_df, |
| 1556 | + sample_weights, |
| 1557 | + target_df, |
| 1558 | + target_weights, |
| 1559 | + transformations=None, |
| 1560 | + cbps_method="exact", |
| 1561 | + ) |
| 1562 | + |
| 1563 | + self.assertIn( |
| 1564 | + "no solution satisfying the constraints", |
| 1565 | + str(context.exception).lower(), |
| 1566 | + msg="Expected exception about constraint violation", |
| 1567 | + ) |
| 1568 | + |
| 1569 | + def test_over_method_both_gmm_constraint_violation_exception(self) -> None: |
| 1570 | + """Test line 778: Exception when over method both GMM optimizations fail with constraint violation. |
| 1571 | +
|
| 1572 | + Uses mocking to simulate both gmm_loss optimizations failing with constraint messages. |
| 1573 | + """ |
| 1574 | + from unittest.mock import patch |
| 1575 | + |
| 1576 | + sample_df, sample_weights, target_df, target_weights = ( |
| 1577 | + self._create_simple_test_data() |
| 1578 | + ) |
| 1579 | + |
| 1580 | + call_count: List[int] = [0] |
| 1581 | + |
| 1582 | + def mock_minimize(fun: Callable[..., Any], x0: Any, **kwargs: Any) -> MagicMock: |
| 1583 | + call_count[0] += 1 |
| 1584 | + result = MagicMock() |
| 1585 | + if call_count[0] == 1: |
| 1586 | + # First call is balance_optimize - succeed |
| 1587 | + result.__getitem__ = lambda _, key: { |
| 1588 | + "success": np.bool_(True), |
| 1589 | + "message": "Success", |
| 1590 | + "x": x0, |
| 1591 | + "fun": 1.0, |
| 1592 | + }[key] |
| 1593 | + else: |
| 1594 | + # Both GMM optimizations fail with constraint violation |
| 1595 | + result.__getitem__ = lambda _, key: { |
| 1596 | + "success": np.bool_(False), |
| 1597 | + "message": "Did not converge to a solution satisfying the constraints", |
| 1598 | + "x": x0, |
| 1599 | + "fun": 100.0, |
| 1600 | + }[key] |
| 1601 | + return result |
| 1602 | + |
| 1603 | + def mock_minimize_scalar( |
| 1604 | + fun: Callable[..., Any], **kwargs: Any |
| 1605 | + ) -> Dict[str, Union[np.bool_, np.ndarray, str]]: |
| 1606 | + return { |
| 1607 | + "success": np.bool_(True), |
| 1608 | + "message": "Success", |
| 1609 | + "x": np.array([1.0]), |
| 1610 | + } |
| 1611 | + |
| 1612 | + with ( |
| 1613 | + patch("scipy.optimize.minimize_scalar", side_effect=mock_minimize_scalar), |
| 1614 | + patch("scipy.optimize.minimize", side_effect=mock_minimize), |
| 1615 | + ): |
| 1616 | + with self.assertRaises(Exception) as context: |
| 1617 | + balance_cbps.cbps( |
| 1618 | + sample_df, |
| 1619 | + sample_weights, |
| 1620 | + target_df, |
| 1621 | + target_weights, |
| 1622 | + transformations=None, |
| 1623 | + cbps_method="over", |
| 1624 | + ) |
| 1625 | + |
| 1626 | + self.assertIn( |
| 1627 | + "no solution satisfying the constraints", |
| 1628 | + str(context.exception).lower(), |
| 1629 | + msg="Expected exception about constraint violation in over method", |
| 1630 | + ) |
0 commit comments