|
14 | 14 | ) |
15 | 15 |
|
16 | 16 | import warnings |
| 17 | +from typing import Any, Callable, Dict, List, Tuple, Union |
| 18 | +from unittest.mock import MagicMock |
17 | 19 |
|
18 | 20 | import balance.testutil |
19 | 21 | import numpy as np |
@@ -1352,6 +1354,8 @@ def test_cbps_over_method_with_extreme_data_logs_warning(self) -> None: |
1352 | 1354 | transformations=None, |
1353 | 1355 | cbps_method="exact", |
1354 | 1356 | ) |
| 1357 | + |
| 1358 | + def test_cbps_over_method_logs_warnings(self) -> None: |
1355 | 1359 | """Test CBPS over method logs warnings when optimization fails (lines 713, 747, 765). |
1356 | 1360 |
|
1357 | 1361 | Verifies that when optimization algorithms fail to converge, appropriate |
@@ -1477,3 +1481,136 @@ def test_cbps_alpha_function_convergence_warning(self) -> None: |
1477 | 1481 | ), |
1478 | 1482 | msg=f"Expected exception to contain convergence-related message, got: {e}", |
1479 | 1483 | ) |
| 1484 | + |
| 1485 | + |
| 1486 | +class TestCbpsOptimizationConvergenceWithMocking(balance.testutil.BalanceTestCase): |
| 1487 | + """Test CBPS optimization convergence warning branches using mocking (lines 689, 713, 726, 747, 765, 778). |
| 1488 | +
|
| 1489 | + These tests use unittest.mock to directly control scipy.optimize.minimize return values, |
| 1490 | + ensuring the specific warning and exception branches in _cbps_optimization are executed. |
| 1491 | + """ |
| 1492 | + |
| 1493 | + def _create_simple_test_data( |
| 1494 | + self, |
| 1495 | + ) -> Tuple[pd.DataFrame, pd.Series, pd.DataFrame, pd.Series]: |
| 1496 | + """Create simple test data for CBPS testing.""" |
| 1497 | + sample_df = pd.DataFrame({"a": [1.0, 2.0, 3.0, 4.0, 5.0]}) |
| 1498 | + target_df = pd.DataFrame({"a": [2.0, 3.0, 4.0, 5.0, 6.0]}) |
| 1499 | + sample_weights = pd.Series([1.0] * 5) |
| 1500 | + target_weights = pd.Series([1.0] * 5) |
| 1501 | + return sample_df, sample_weights, target_df, target_weights |
| 1502 | + |
| 1503 | + def test_exact_method_constraint_violation_exception(self) -> None: |
| 1504 | + """Test line 726: Exception when exact method constraints can't be satisfied. |
| 1505 | +
|
| 1506 | + Uses mocking to simulate scipy.optimize.minimize returning success=False |
| 1507 | + with a specific constraint violation message. |
| 1508 | + """ |
| 1509 | + from unittest.mock import patch |
| 1510 | + |
| 1511 | + sample_df, sample_weights, target_df, target_weights = ( |
| 1512 | + self._create_simple_test_data() |
| 1513 | + ) |
| 1514 | + |
| 1515 | + def mock_minimize(fun: Callable[..., Any], x0: Any, **kwargs: Any) -> MagicMock: |
| 1516 | + result = MagicMock() |
| 1517 | + # Simulate constraint violation failure |
| 1518 | + result.__getitem__ = lambda _, key: { |
| 1519 | + "success": np.bool_(False), |
| 1520 | + "message": "Did not converge to a solution satisfying the constraints", |
| 1521 | + "x": x0, |
| 1522 | + "fun": 100.0, |
| 1523 | + }[key] |
| 1524 | + return result |
| 1525 | + |
| 1526 | + def mock_minimize_scalar( |
| 1527 | + fun: Callable[..., Any], **kwargs: Any |
| 1528 | + ) -> Dict[str, Union[np.bool_, np.ndarray, str]]: |
| 1529 | + return { |
| 1530 | + "success": np.bool_(True), |
| 1531 | + "message": "Success", |
| 1532 | + "x": np.array([1.0]), |
| 1533 | + } |
| 1534 | + |
| 1535 | + with ( |
| 1536 | + patch("scipy.optimize.minimize_scalar", side_effect=mock_minimize_scalar), |
| 1537 | + patch("scipy.optimize.minimize", side_effect=mock_minimize), |
| 1538 | + ): |
| 1539 | + with self.assertRaises(Exception) as context: |
| 1540 | + balance_cbps.cbps( |
| 1541 | + sample_df, |
| 1542 | + sample_weights, |
| 1543 | + target_df, |
| 1544 | + target_weights, |
| 1545 | + transformations=None, |
| 1546 | + cbps_method="exact", |
| 1547 | + ) |
| 1548 | + |
| 1549 | + self.assertIn( |
| 1550 | + "no solution satisfying the constraints", |
| 1551 | + str(context.exception).lower(), |
| 1552 | + msg="Expected exception about constraint violation", |
| 1553 | + ) |
| 1554 | + |
| 1555 | + def test_over_method_both_gmm_constraint_violation_exception(self) -> None: |
| 1556 | + """Test line 778: Exception when over method both GMM optimizations fail with constraint violation. |
| 1557 | +
|
| 1558 | + Uses mocking to simulate both gmm_loss optimizations failing with constraint messages. |
| 1559 | + """ |
| 1560 | + from unittest.mock import patch |
| 1561 | + |
| 1562 | + sample_df, sample_weights, target_df, target_weights = ( |
| 1563 | + self._create_simple_test_data() |
| 1564 | + ) |
| 1565 | + |
| 1566 | + call_count: List[int] = [0] |
| 1567 | + |
| 1568 | + def mock_minimize(fun: Callable[..., Any], x0: Any, **kwargs: Any) -> MagicMock: |
| 1569 | + call_count[0] += 1 |
| 1570 | + result = MagicMock() |
| 1571 | + if call_count[0] == 1: |
| 1572 | + # First call is balance_optimize - succeed |
| 1573 | + result.__getitem__ = lambda _, key: { |
| 1574 | + "success": np.bool_(True), |
| 1575 | + "message": "Success", |
| 1576 | + "x": x0, |
| 1577 | + "fun": 1.0, |
| 1578 | + }[key] |
| 1579 | + else: |
| 1580 | + # Both GMM optimizations fail with constraint violation |
| 1581 | + result.__getitem__ = lambda _, key: { |
| 1582 | + "success": np.bool_(False), |
| 1583 | + "message": "Did not converge to a solution satisfying the constraints", |
| 1584 | + "x": x0, |
| 1585 | + "fun": 100.0, |
| 1586 | + }[key] |
| 1587 | + return result |
| 1588 | + |
| 1589 | + def mock_minimize_scalar( |
| 1590 | + fun: Callable[..., Any], **kwargs: Any |
| 1591 | + ) -> Dict[str, Union[np.bool_, np.ndarray, str]]: |
| 1592 | + return { |
| 1593 | + "success": np.bool_(True), |
| 1594 | + "message": "Success", |
| 1595 | + "x": np.array([1.0]), |
| 1596 | + } |
| 1597 | + |
| 1598 | + with ( |
| 1599 | + patch("scipy.optimize.minimize_scalar", side_effect=mock_minimize_scalar), |
| 1600 | + patch("scipy.optimize.minimize", side_effect=mock_minimize), |
| 1601 | + ): |
| 1602 | + with self.assertRaises(Exception) as context: |
| 1603 | + balance_cbps.cbps( |
| 1604 | + sample_df, |
| 1605 | + sample_weights, |
| 1606 | + target_df, |
| 1607 | + target_weights, |
| 1608 | + transformations=None, |
| 1609 | + cbps_method="over", |
| 1610 | + ) |
| 1611 | + |
| 1612 | + self.assertIn( |
| 1613 | + "no solution satisfying the constraints", |
| 1614 | + str(context.exception).lower(), |
| 1615 | + msg="Expected exception about constraint violation in over method", |
| 1616 | + ) |
0 commit comments