Skip to content

Commit 6902e37

Browse files
talgalilifacebook-github-bot
authored andcommitted
Increase balance package test coverage (#281)
Summary: The coverage report showed 98% overall coverage with 7 files having gaps. These tests cover previously untested edge cases including: - CLI exception handling paths for weighting failures - Sample class design effect diagnostics and IPW model parameters - CBPS optimization convergence warnings and constraint violation exceptions - Plotting functions with missing values, default parameters, and various dist_types - Distance metrics with empty numeric columns Differential Revision: D90946146
1 parent 6417fb0 commit 6902e37

File tree

8 files changed

+1533
-88
lines changed

8 files changed

+1533
-88
lines changed

.github/workflows/build-and-test.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,8 @@ jobs:
5353
flake8 .
5454
- name: ufmt (formatting check)
5555
run: |
56+
echo "Checking for formatting issues..."
57+
ufmt diff .
5658
ufmt check .
5759
5860
pyre:

balance/stats_and_plots/weighted_comparisons_plots.py

Lines changed: 7 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -653,10 +653,6 @@ def seaborn_plot_dist(
653653
# With limiting the y axis range to (0,1)
654654
seaborn_plot_dist(dfs1, names=["self", "unadjusted", "target"], dist_type = "kde", ylim = (0,1))
655655
"""
656-
# Provide default names if not specified
657-
if names is None:
658-
names = [f"df_{i}" for i in range(len(dfs))]
659-
660656
# Set default dist_type
661657
dist_type_resolved: Literal["qq", "hist", "kde", "ecdf"]
662658
if dist_type is None:
@@ -671,10 +667,6 @@ def seaborn_plot_dist(
671667
if names is None:
672668
names = [f"df_{i}" for i in range(len(dfs))]
673669

674-
# Type narrowing for names parameter
675-
if names is None:
676-
names = []
677-
678670
# Choose set of variables to plot
679671
variables = choose_variables(*(d["df"] for d in dfs), variables=variables)
680672
logger.debug(f"plotting variables {variables}")
@@ -1348,12 +1340,13 @@ def naming_legend(object_name: str, names_of_dfs: List[str]) -> str:
13481340
naming_legend('self', ['self', 'target']) #'sample'
13491341
naming_legend('other_name', ['self', 'target']) #'other_name'
13501342
"""
1351-
if object_name in names_of_dfs:
1352-
return {
1353-
"unadjusted": "sample",
1354-
"self": "adjusted" if "unadjusted" in names_of_dfs else "sample",
1355-
"target": "population",
1356-
}[object_name]
1343+
name_mapping = {
1344+
"unadjusted": "sample",
1345+
"self": "adjusted" if "unadjusted" in names_of_dfs else "sample",
1346+
"target": "population",
1347+
}
1348+
if object_name in name_mapping:
1349+
return name_mapping[object_name]
13571350
else:
13581351
return object_name
13591352

tests/test_cbps.py

Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
)
1515

1616
import warnings
17+
from typing import Any, Callable, Dict, List, Tuple, Union
18+
from unittest.mock import MagicMock
1719

1820
import balance.testutil
1921
import numpy as np
@@ -1352,6 +1354,8 @@ def test_cbps_over_method_with_extreme_data_logs_warning(self) -> None:
13521354
transformations=None,
13531355
cbps_method="exact",
13541356
)
1357+
1358+
def test_cbps_over_method_logs_warnings(self) -> None:
13551359
"""Test CBPS over method logs warnings when optimization fails (lines 713, 747, 765).
13561360
13571361
Verifies that when optimization algorithms fail to converge, appropriate
@@ -1477,3 +1481,136 @@ def test_cbps_alpha_function_convergence_warning(self) -> None:
14771481
),
14781482
msg=f"Expected exception to contain convergence-related message, got: {e}",
14791483
)
1484+
1485+
1486+
class TestCbpsOptimizationConvergenceWithMocking(balance.testutil.BalanceTestCase):
1487+
"""Test CBPS optimization convergence warning branches using mocking (lines 689, 713, 726, 747, 765, 778).
1488+
1489+
These tests use unittest.mock to directly control scipy.optimize.minimize return values,
1490+
ensuring the specific warning and exception branches in _cbps_optimization are executed.
1491+
"""
1492+
1493+
def _create_simple_test_data(
1494+
self,
1495+
) -> Tuple[pd.DataFrame, pd.Series, pd.DataFrame, pd.Series]:
1496+
"""Create simple test data for CBPS testing."""
1497+
sample_df = pd.DataFrame({"a": [1.0, 2.0, 3.0, 4.0, 5.0]})
1498+
target_df = pd.DataFrame({"a": [2.0, 3.0, 4.0, 5.0, 6.0]})
1499+
sample_weights = pd.Series([1.0] * 5)
1500+
target_weights = pd.Series([1.0] * 5)
1501+
return sample_df, sample_weights, target_df, target_weights
1502+
1503+
def test_exact_method_constraint_violation_exception(self) -> None:
1504+
"""Test line 726: Exception when exact method constraints can't be satisfied.
1505+
1506+
Uses mocking to simulate scipy.optimize.minimize returning success=False
1507+
with a specific constraint violation message.
1508+
"""
1509+
from unittest.mock import patch
1510+
1511+
sample_df, sample_weights, target_df, target_weights = (
1512+
self._create_simple_test_data()
1513+
)
1514+
1515+
def mock_minimize(fun: Callable[..., Any], x0: Any, **kwargs: Any) -> MagicMock:
1516+
result = MagicMock()
1517+
# Simulate constraint violation failure
1518+
result.__getitem__ = lambda _, key: {
1519+
"success": np.bool_(False),
1520+
"message": "Did not converge to a solution satisfying the constraints",
1521+
"x": x0,
1522+
"fun": 100.0,
1523+
}[key]
1524+
return result
1525+
1526+
def mock_minimize_scalar(
1527+
fun: Callable[..., Any], **kwargs: Any
1528+
) -> Dict[str, Union[np.bool_, np.ndarray, str]]:
1529+
return {
1530+
"success": np.bool_(True),
1531+
"message": "Success",
1532+
"x": np.array([1.0]),
1533+
}
1534+
1535+
with (
1536+
patch("scipy.optimize.minimize_scalar", side_effect=mock_minimize_scalar),
1537+
patch("scipy.optimize.minimize", side_effect=mock_minimize),
1538+
):
1539+
with self.assertRaises(Exception) as context:
1540+
balance_cbps.cbps(
1541+
sample_df,
1542+
sample_weights,
1543+
target_df,
1544+
target_weights,
1545+
transformations=None,
1546+
cbps_method="exact",
1547+
)
1548+
1549+
self.assertIn(
1550+
"no solution satisfying the constraints",
1551+
str(context.exception).lower(),
1552+
msg="Expected exception about constraint violation",
1553+
)
1554+
1555+
def test_over_method_both_gmm_constraint_violation_exception(self) -> None:
1556+
"""Test line 778: Exception when over method both GMM optimizations fail with constraint violation.
1557+
1558+
Uses mocking to simulate both gmm_loss optimizations failing with constraint messages.
1559+
"""
1560+
from unittest.mock import patch
1561+
1562+
sample_df, sample_weights, target_df, target_weights = (
1563+
self._create_simple_test_data()
1564+
)
1565+
1566+
call_count: List[int] = [0]
1567+
1568+
def mock_minimize(fun: Callable[..., Any], x0: Any, **kwargs: Any) -> MagicMock:
1569+
call_count[0] += 1
1570+
result = MagicMock()
1571+
if call_count[0] == 1:
1572+
# First call is balance_optimize - succeed
1573+
result.__getitem__ = lambda _, key: {
1574+
"success": np.bool_(True),
1575+
"message": "Success",
1576+
"x": x0,
1577+
"fun": 1.0,
1578+
}[key]
1579+
else:
1580+
# Both GMM optimizations fail with constraint violation
1581+
result.__getitem__ = lambda _, key: {
1582+
"success": np.bool_(False),
1583+
"message": "Did not converge to a solution satisfying the constraints",
1584+
"x": x0,
1585+
"fun": 100.0,
1586+
}[key]
1587+
return result
1588+
1589+
def mock_minimize_scalar(
1590+
fun: Callable[..., Any], **kwargs: Any
1591+
) -> Dict[str, Union[np.bool_, np.ndarray, str]]:
1592+
return {
1593+
"success": np.bool_(True),
1594+
"message": "Success",
1595+
"x": np.array([1.0]),
1596+
}
1597+
1598+
with (
1599+
patch("scipy.optimize.minimize_scalar", side_effect=mock_minimize_scalar),
1600+
patch("scipy.optimize.minimize", side_effect=mock_minimize),
1601+
):
1602+
with self.assertRaises(Exception) as context:
1603+
balance_cbps.cbps(
1604+
sample_df,
1605+
sample_weights,
1606+
target_df,
1607+
target_weights,
1608+
transformations=None,
1609+
cbps_method="over",
1610+
)
1611+
1612+
self.assertIn(
1613+
"no solution satisfying the constraints",
1614+
str(context.exception).lower(),
1615+
msg="Expected exception about constraint violation in over method",
1616+
)

0 commit comments

Comments
 (0)