Skip to content

Commit 5397433

Browse files
talgalili authored and facebook-github-bot committed
Increase test coverage from 98% to near 100%
Summary: Adds comprehensive tests covering previously untested edge cases identified in the coverage report. The 7 files with coverage gaps now have tests for:

- **CLI**: Exception handling paths for weighting failures
- **Sample class**: Design effect diagnostics and IPW model parameters
- **CBPS**: Optimization convergence warnings and constraint violation exceptions
- **Plotting**: Functions with missing values, default parameters, and various dist_types
- **Distance metrics**: Empty numeric columns edge case

This improves the overall test coverage reliability and catches potential regressions in error handling paths.

Differential Revision: D90946146
1 parent 6417fb0 commit 5397433

File tree

9 files changed

+1633
-139
lines changed

9 files changed

+1633
-139
lines changed

.github/workflows/build-and-test.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,8 @@ jobs:
5353
flake8 .
5454
- name: ufmt (formatting check)
5555
run: |
56+
echo "Checking for formatting issues..."
57+
ufmt diff .
5658
ufmt check .
5759
5860
pyre:

balance/stats_and_plots/weighted_comparisons_plots.py

Lines changed: 7 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -653,10 +653,6 @@ def seaborn_plot_dist(
653653
# With limiting the y axis range to (0,1)
654654
seaborn_plot_dist(dfs1, names=["self", "unadjusted", "target"], dist_type = "kde", ylim = (0,1))
655655
"""
656-
# Provide default names if not specified
657-
if names is None:
658-
names = [f"df_{i}" for i in range(len(dfs))]
659-
660656
# Set default dist_type
661657
dist_type_resolved: Literal["qq", "hist", "kde", "ecdf"]
662658
if dist_type is None:
@@ -671,10 +667,6 @@ def seaborn_plot_dist(
671667
if names is None:
672668
names = [f"df_{i}" for i in range(len(dfs))]
673669

674-
# Type narrowing for names parameter
675-
if names is None:
676-
names = []
677-
678670
# Choose set of variables to plot
679671
variables = choose_variables(*(d["df"] for d in dfs), variables=variables)
680672
logger.debug(f"plotting variables {variables}")
@@ -1348,12 +1340,13 @@ def naming_legend(object_name: str, names_of_dfs: List[str]) -> str:
13481340
naming_legend('self', ['self', 'target']) #'sample'
13491341
naming_legend('other_name', ['self', 'target']) #'other_name'
13501342
"""
1351-
if object_name in names_of_dfs:
1352-
return {
1353-
"unadjusted": "sample",
1354-
"self": "adjusted" if "unadjusted" in names_of_dfs else "sample",
1355-
"target": "population",
1356-
}[object_name]
1343+
name_mapping = {
1344+
"unadjusted": "sample",
1345+
"self": "adjusted" if "unadjusted" in names_of_dfs else "sample",
1346+
"target": "population",
1347+
}
1348+
if object_name in name_mapping:
1349+
return name_mapping[object_name]
13571350
else:
13581351
return object_name
13591352

tests/test_cbps.py

Lines changed: 192 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
)
1515

1616
import warnings
17+
from typing import Any, Callable, Dict, List, Tuple, Union
18+
from unittest.mock import MagicMock
1719

1820
import balance.testutil
1921
import numpy as np
@@ -1342,16 +1344,25 @@ def test_cbps_over_method_with_extreme_data_logs_warning(self) -> None:
13421344

13431345
# CBPS handles extreme cases by logging a warning about identical weights
13441346
# rather than raising an exception
1345-
self.assertWarnsRegexp(
1346-
"All weights are identical",
1347-
balance_cbps.cbps,
1348-
sample_df,
1349-
sample_weights,
1350-
target_df,
1351-
target_weights,
1352-
transformations=None,
1353-
cbps_method="exact",
1354-
)
1347+
# Suppress PerfectSeparationWarning as it's expected with this extreme test data
1348+
with warnings.catch_warnings():
1349+
warnings.filterwarnings(
1350+
"ignore",
1351+
message="Perfect separation or prediction detected",
1352+
category=PerfectSeparationWarning,
1353+
)
1354+
self.assertWarnsRegexp(
1355+
"All weights are identical",
1356+
balance_cbps.cbps,
1357+
sample_df,
1358+
sample_weights,
1359+
target_df,
1360+
target_weights,
1361+
transformations=None,
1362+
cbps_method="exact",
1363+
)
1364+
1365+
def test_cbps_over_method_logs_warnings(self) -> None:
13551366
"""Test CBPS over method logs warnings when optimization fails (lines 713, 747, 765).
13561367
13571368
Verifies that when optimization algorithms fail to converge, appropriate
@@ -1381,38 +1392,45 @@ def test_cbps_over_method_with_extreme_data_logs_warning(self) -> None:
13811392
# Run with over method to exercise gmm optimization paths
13821393
# Use very tight opt_opts to force convergence failure
13831394
# We expect either warnings to be logged or an exception to be raised
1384-
try:
1385-
with self.assertLogs(level=logging.WARNING) as log_context:
1386-
balance_cbps.cbps(
1387-
sample_df,
1388-
sample_weights,
1389-
target_df,
1390-
target_weights,
1391-
transformations=None,
1392-
cbps_method="over",
1393-
opt_opts={"maxiter": 1}, # Force convergence failure
1394-
)
1395-
# Verify that at least one warning was logged
1396-
self.assertTrue(
1397-
len(log_context.records) > 0,
1398-
msg="Expected warning logs when optimization fails to converge",
1399-
)
1400-
except Exception as e:
1401-
# If an exception is raised, verify it contains relevant error info
1402-
error_msg = str(e).lower()
1403-
self.assertTrue(
1404-
any(
1405-
keyword in error_msg
1406-
for keyword in [
1407-
"converge",
1408-
"constraint",
1409-
"singular",
1410-
"optimization",
1411-
"failed",
1412-
]
1413-
),
1414-
msg=f"Expected exception to contain convergence-related message, got: {e}",
1395+
# Suppress PerfectSeparationWarning as it's expected with this extreme test data
1396+
with warnings.catch_warnings():
1397+
warnings.filterwarnings(
1398+
"ignore",
1399+
message="Perfect separation or prediction detected",
1400+
category=PerfectSeparationWarning,
14151401
)
1402+
try:
1403+
with self.assertLogs(level=logging.WARNING) as log_context:
1404+
balance_cbps.cbps(
1405+
sample_df,
1406+
sample_weights,
1407+
target_df,
1408+
target_weights,
1409+
transformations=None,
1410+
cbps_method="over",
1411+
opt_opts={"maxiter": 1}, # Force convergence failure
1412+
)
1413+
# Verify that at least one warning was logged
1414+
self.assertTrue(
1415+
len(log_context.records) > 0,
1416+
msg="Expected warning logs when optimization fails to converge",
1417+
)
1418+
except Exception as e:
1419+
# If an exception is raised, verify it contains relevant error info
1420+
error_msg = str(e).lower()
1421+
self.assertTrue(
1422+
any(
1423+
keyword in error_msg
1424+
for keyword in [
1425+
"converge",
1426+
"constraint",
1427+
"singular",
1428+
"optimization",
1429+
"failed",
1430+
]
1431+
),
1432+
msg=f"Expected exception to contain convergence-related message, got: {e}",
1433+
)
14161434

14171435
def test_cbps_alpha_function_convergence_warning(self) -> None:
14181436
"""Test CBPS logs warning when alpha_function fails to converge (line 689).
@@ -1477,3 +1495,136 @@ def test_cbps_alpha_function_convergence_warning(self) -> None:
14771495
),
14781496
msg=f"Expected exception to contain convergence-related message, got: {e}",
14791497
)
1498+
1499+
1500+
class TestCbpsOptimizationConvergenceWithMocking(balance.testutil.BalanceTestCase):
1501+
"""Test CBPS optimization convergence warning branches using mocking (lines 689, 713, 726, 747, 765, 778).
1502+
1503+
These tests use unittest.mock to directly control scipy.optimize.minimize return values,
1504+
ensuring the specific warning and exception branches in _cbps_optimization are executed.
1505+
"""
1506+
1507+
def _create_simple_test_data(
1508+
self,
1509+
) -> Tuple[pd.DataFrame, pd.Series, pd.DataFrame, pd.Series]:
1510+
"""Create simple test data for CBPS testing."""
1511+
sample_df = pd.DataFrame({"a": [1.0, 2.0, 3.0, 4.0, 5.0]})
1512+
target_df = pd.DataFrame({"a": [2.0, 3.0, 4.0, 5.0, 6.0]})
1513+
sample_weights = pd.Series([1.0] * 5)
1514+
target_weights = pd.Series([1.0] * 5)
1515+
return sample_df, sample_weights, target_df, target_weights
1516+
1517+
def test_exact_method_constraint_violation_exception(self) -> None:
1518+
"""Test line 726: Exception when exact method constraints can't be satisfied.
1519+
1520+
Uses mocking to simulate scipy.optimize.minimize returning success=False
1521+
with a specific constraint violation message.
1522+
"""
1523+
from unittest.mock import patch
1524+
1525+
sample_df, sample_weights, target_df, target_weights = (
1526+
self._create_simple_test_data()
1527+
)
1528+
1529+
def mock_minimize(fun: Callable[..., Any], x0: Any, **kwargs: Any) -> MagicMock:
1530+
result = MagicMock()
1531+
# Simulate constraint violation failure
1532+
result.__getitem__ = lambda _, key: {
1533+
"success": np.bool_(False),
1534+
"message": "Did not converge to a solution satisfying the constraints",
1535+
"x": x0,
1536+
"fun": 100.0,
1537+
}[key]
1538+
return result
1539+
1540+
def mock_minimize_scalar(
1541+
fun: Callable[..., Any], **kwargs: Any
1542+
) -> Dict[str, Union[np.bool_, np.ndarray, str]]:
1543+
return {
1544+
"success": np.bool_(True),
1545+
"message": "Success",
1546+
"x": np.array([1.0]),
1547+
}
1548+
1549+
with (
1550+
patch("scipy.optimize.minimize_scalar", side_effect=mock_minimize_scalar),
1551+
patch("scipy.optimize.minimize", side_effect=mock_minimize),
1552+
):
1553+
with self.assertRaises(Exception) as context:
1554+
balance_cbps.cbps(
1555+
sample_df,
1556+
sample_weights,
1557+
target_df,
1558+
target_weights,
1559+
transformations=None,
1560+
cbps_method="exact",
1561+
)
1562+
1563+
self.assertIn(
1564+
"no solution satisfying the constraints",
1565+
str(context.exception).lower(),
1566+
msg="Expected exception about constraint violation",
1567+
)
1568+
1569+
def test_over_method_both_gmm_constraint_violation_exception(self) -> None:
1570+
"""Test line 778: Exception when over method both GMM optimizations fail with constraint violation.
1571+
1572+
Uses mocking to simulate both gmm_loss optimizations failing with constraint messages.
1573+
"""
1574+
from unittest.mock import patch
1575+
1576+
sample_df, sample_weights, target_df, target_weights = (
1577+
self._create_simple_test_data()
1578+
)
1579+
1580+
call_count: List[int] = [0]
1581+
1582+
def mock_minimize(fun: Callable[..., Any], x0: Any, **kwargs: Any) -> MagicMock:
1583+
call_count[0] += 1
1584+
result = MagicMock()
1585+
if call_count[0] == 1:
1586+
# First call is balance_optimize - succeed
1587+
result.__getitem__ = lambda _, key: {
1588+
"success": np.bool_(True),
1589+
"message": "Success",
1590+
"x": x0,
1591+
"fun": 1.0,
1592+
}[key]
1593+
else:
1594+
# Both GMM optimizations fail with constraint violation
1595+
result.__getitem__ = lambda _, key: {
1596+
"success": np.bool_(False),
1597+
"message": "Did not converge to a solution satisfying the constraints",
1598+
"x": x0,
1599+
"fun": 100.0,
1600+
}[key]
1601+
return result
1602+
1603+
def mock_minimize_scalar(
1604+
fun: Callable[..., Any], **kwargs: Any
1605+
) -> Dict[str, Union[np.bool_, np.ndarray, str]]:
1606+
return {
1607+
"success": np.bool_(True),
1608+
"message": "Success",
1609+
"x": np.array([1.0]),
1610+
}
1611+
1612+
with (
1613+
patch("scipy.optimize.minimize_scalar", side_effect=mock_minimize_scalar),
1614+
patch("scipy.optimize.minimize", side_effect=mock_minimize),
1615+
):
1616+
with self.assertRaises(Exception) as context:
1617+
balance_cbps.cbps(
1618+
sample_df,
1619+
sample_weights,
1620+
target_df,
1621+
target_weights,
1622+
transformations=None,
1623+
cbps_method="over",
1624+
)
1625+
1626+
self.assertIn(
1627+
"no solution satisfying the constraints",
1628+
str(context.exception).lower(),
1629+
msg="Expected exception about constraint violation in over method",
1630+
)

0 commit comments

Comments (0)