Skip to content

Commit c7fcb79

Browse files
Copilotmeta-codesync[bot]
authored andcommitted
Extract common pattern in comparison stat functions to reduce CCN (#267)
Summary: - [x] Create a generic helper function `_apply_comparison_stat_to_BalanceDF` to reduce code duplication - [x] Update `_asmd_BalanceDF` to use the new helper function - [x] Update `_kld_BalanceDF` to use the new helper function - [x] Update `_emd_BalanceDF` to use the new helper function - [x] Update `_cvmd_BalanceDF` to use the new helper function - [x] Update `_ks_BalanceDF` to use the new helper function - [x] Run tests to ensure all functionality remains unchanged (all 83 tests in test_balancedf.py pass) - [x] Run code review and security checks (no issues found) - [x] Add direct test coverage for _kld_BalanceDF, _emd_BalanceDF, _cvmd_BalanceDF, and _ks_BalanceDF (5 new tests added) - [x] Fix flake8 linting errors (removed trailing whitespace from blank lines) - [x] Fix ufmt formatting errors (formatted test file with ufmt) Successfully extracted the common pattern from five similar functions (`_asmd_BalanceDF`, `_kld_BalanceDF`, `_emd_BalanceDF`, `_cvmd_BalanceDF`, `_ks_BalanceDF`) into a single generic helper function `_apply_comparison_stat_to_BalanceDF`. ### Changes Made - Added `Callable` to the imports from `typing` module - Created `_apply_comparison_stat_to_BalanceDF` helper function that: 1. Validates inputs are BalanceDF objects 2. Extracts df and weights from both objects 3. Calls the comparison function with the extracted data - Refactored all five comparison functions to use the helper (reduced from ~30 lines to ~5 lines each) - Maintains special handling for `_asmd_BalanceDF` which passes `std_type="target"` via kwargs - **Added comprehensive test coverage** for all comparison methods: - `test_BalanceDF__kld_BalanceDF`: Direct test of _kld_BalanceDF method - `test_BalanceDF__emd_BalanceDF`: Direct test of _emd_BalanceDF method - `test_BalanceDF__cvmd_BalanceDF`: Direct test of _cvmd_BalanceDF method - `test_BalanceDF__ks_BalanceDF`: Direct test of _ks_BalanceDF method - `test_BalanceDF_comparison_functions_invalid_input`: Tests input validation for all methods - **Fixed flake8 linting errors**: Removed trailing whitespace from blank lines in test file - **Fixed ufmt formatting**: Formatted test file according to project standards (black + usort) ### Test Coverage - All tests now directly exercise the helper function through the four comparison methods - Tests verify correct Series output with expected keys (a, b, mean(metric)) - Tests verify mathematical properties (non-negativity, bounded ranges) - Tests verify aggregate_by_main_covar parameter works - Tests verify proper input validation with clear error messages - Total tests: 88 (83 original + 5 new), all passing - **Code quality compliance**: All linting (flake8) and formatting (ufmt) checks pass ### Benefits - **Reduces Cyclomatic Complexity Number (CCN)** - the original goal of the issue - **Eliminates code duplication** - DRY principle applied - **Easier maintenance** - future changes only need to be made in one place - **Type safety** - added proper type hint for the callable parameter - **No behavioral changes** - all 83 existing tests pass without modification - **Comprehensive test coverage** - direct tests for all comparison methods and edge cases - **Code quality** - passes all linting and formatting checks <details> <summary>Original prompt</summary> > > ---- > > *This section details on the original issue you should resolve* > > <issue_title>[FEATURE] generalize functions in ‎balance/balancedf_class.py‎</issue_title> > <issue_description>a bunch of functions in > ‎balance/balancedf_class.py‎ > > Follow the exact same pattern (other than one word change: > staticmethod > def _emd_BalanceDF( > sample_BalanceDF: "BalanceDF", > target_BalanceDF: "BalanceDF", > aggregate_by_main_covar: bool = False, > ) -> pd.Series: > """Run EMD on two BalanceDF objects. > > Prepares the BalanceDF objects by passing them through :func:`_get_df_and_weights`, and > then passes the df and weights into :func:`weighted_comparisons_stats.emd`. > > Args: > sample_BalanceDF (BalanceDF): Object. > target_BalanceDF (BalanceDF): Object. > aggregate_by_main_covar (bool, optional): See :func:`weighted_comparisons_stats.emd`. Defaults to False. > > Returns: > pd.Series: See :func:`weighted_comparisons_stats.emd`. > """ > BalanceDF._check_if_not_BalanceDF(sample_BalanceDF, "sample_BalanceDF") > BalanceDF._check_if_not_BalanceDF(target_BalanceDF, "target_BalanceDF") > > sample_df_values, sample_weights = sample_BalanceDF._get_df_and_weights() > target_df_values, target_weights = target_BalanceDF._get_df_and_weights() > > return weighted_comparisons_stats.emd( > sample_df_values, > target_df_values, > sample_weights, > target_weights, > aggregate_by_main_covar=aggregate_by_main_covar, > ) > > > > Extract this pattern to a helper function to reduce CCN. > > </issue_description> > > ## Comments on the Issue (you are copilot in this section) > > <comments> > </comments> > </details> - Fixes #266 --- 💬 We'd love your input! Share your thoughts on Copilot coding agent in our [2 minute survey](https://gh.io/copilot-coding-agent-survey). Pull Request resolved: #267 Differential Revision: D90870323 Pulled By: talgalili fbshipit-source-id: 064e3878506e1f65726ce1302fc692e9ec794676
1 parent c17914a commit c7fcb79

File tree

2 files changed

+200
-61
lines changed

2 files changed

+200
-61
lines changed

balance/balancedf_class.py

Lines changed: 70 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from __future__ import annotations
99

1010
import logging
11-
from typing import Any, Dict, Literal, Tuple
11+
from typing import Any, Callable, Dict, Literal, Tuple
1212

1313
import numpy as np
1414
import numpy.typing as npt
@@ -1110,6 +1110,50 @@ def _get_df_and_weights(
11101110
weights = self._weights.values if (self._weights is not None) else None
11111111
return df_model_matrix, weights
11121112

1113+
@staticmethod
1114+
def _apply_comparison_stat_to_BalanceDF(
1115+
comparison_func: Callable[..., pd.Series],
1116+
sample_BalanceDF: "BalanceDF",
1117+
target_BalanceDF: "BalanceDF",
1118+
aggregate_by_main_covar: bool = False,
1119+
**kwargs: Any,
1120+
) -> pd.Series:
1121+
"""Generic helper to apply a weighted comparison statistic function to two BalanceDF objects.
1122+
1123+
This helper function reduces code duplication across multiple comparison methods
1124+
(asmd, kld, emd, cvmd, ks) by extracting the common pattern of:
1125+
1. Validating inputs are BalanceDF objects
1126+
2. Extracting df and weights from both objects
1127+
3. Calling the comparison function with the extracted data
1128+
1129+
Args:
1130+
comparison_func (Callable[..., pd.Series]): The comparison function from
1131+
weighted_comparisons_stats to apply (e.g., asmd, kld, emd, cvmd, ks).
1132+
sample_BalanceDF (BalanceDF): Sample object.
1133+
target_BalanceDF (BalanceDF): Target object.
1134+
aggregate_by_main_covar (bool, optional): Whether to aggregate by main covariate.
1135+
Defaults to False. Passed to the comparison function.
1136+
**kwargs: Additional keyword arguments to pass to the comparison function
1137+
(e.g., std_type for asmd).
1138+
1139+
Returns:
1140+
pd.Series: The result from the comparison function.
1141+
"""
1142+
BalanceDF._check_if_not_BalanceDF(sample_BalanceDF, "sample_BalanceDF")
1143+
BalanceDF._check_if_not_BalanceDF(target_BalanceDF, "target_BalanceDF")
1144+
1145+
sample_df_values, sample_weights = sample_BalanceDF._get_df_and_weights()
1146+
target_df_values, target_weights = target_BalanceDF._get_df_and_weights()
1147+
1148+
return comparison_func(
1149+
sample_df_values,
1150+
target_df_values,
1151+
sample_weights,
1152+
target_weights,
1153+
aggregate_by_main_covar=aggregate_by_main_covar,
1154+
**kwargs,
1155+
)
1156+
11131157
@staticmethod
11141158
def _asmd_BalanceDF(
11151159
sample_BalanceDF: "BalanceDF",
@@ -1156,19 +1200,12 @@ def _asmd_BalanceDF(
11561200
# mean(asmd) 1.756543
11571201
# dtype: float64
11581202
"""
1159-
BalanceDF._check_if_not_BalanceDF(sample_BalanceDF, "sample_BalanceDF")
1160-
BalanceDF._check_if_not_BalanceDF(sample_BalanceDF, "target_BalanceDF")
1161-
1162-
sample_df_values, sample_weights = sample_BalanceDF._get_df_and_weights()
1163-
target_df_values, target_weights = target_BalanceDF._get_df_and_weights()
1164-
1165-
return weighted_comparisons_stats.asmd(
1166-
sample_df_values,
1167-
target_df_values,
1168-
sample_weights,
1169-
target_weights,
1203+
return BalanceDF._apply_comparison_stat_to_BalanceDF(
1204+
weighted_comparisons_stats.asmd,
1205+
sample_BalanceDF,
1206+
target_BalanceDF,
1207+
aggregate_by_main_covar,
11701208
std_type="target",
1171-
aggregate_by_main_covar=aggregate_by_main_covar,
11721209
)
11731210

11741211
@staticmethod
@@ -1190,18 +1227,11 @@ def _kld_BalanceDF(
11901227
Returns:
11911228
pd.Series: See :func:`weighted_comparisons_stats.kld`.
11921229
"""
1193-
BalanceDF._check_if_not_BalanceDF(sample_BalanceDF, "sample_BalanceDF")
1194-
BalanceDF._check_if_not_BalanceDF(target_BalanceDF, "target_BalanceDF")
1195-
1196-
sample_df_values, sample_weights = sample_BalanceDF._get_df_and_weights()
1197-
target_df_values, target_weights = target_BalanceDF._get_df_and_weights()
1198-
1199-
return weighted_comparisons_stats.kld(
1200-
sample_df_values,
1201-
target_df_values,
1202-
sample_weights,
1203-
target_weights,
1204-
aggregate_by_main_covar=aggregate_by_main_covar,
1230+
return BalanceDF._apply_comparison_stat_to_BalanceDF(
1231+
weighted_comparisons_stats.kld,
1232+
sample_BalanceDF,
1233+
target_BalanceDF,
1234+
aggregate_by_main_covar,
12051235
)
12061236

12071237
@staticmethod
@@ -1223,18 +1253,11 @@ def _emd_BalanceDF(
12231253
Returns:
12241254
pd.Series: See :func:`weighted_comparisons_stats.emd`.
12251255
"""
1226-
BalanceDF._check_if_not_BalanceDF(sample_BalanceDF, "sample_BalanceDF")
1227-
BalanceDF._check_if_not_BalanceDF(target_BalanceDF, "target_BalanceDF")
1228-
1229-
sample_df_values, sample_weights = sample_BalanceDF._get_df_and_weights()
1230-
target_df_values, target_weights = target_BalanceDF._get_df_and_weights()
1231-
1232-
return weighted_comparisons_stats.emd(
1233-
sample_df_values,
1234-
target_df_values,
1235-
sample_weights,
1236-
target_weights,
1237-
aggregate_by_main_covar=aggregate_by_main_covar,
1256+
return BalanceDF._apply_comparison_stat_to_BalanceDF(
1257+
weighted_comparisons_stats.emd,
1258+
sample_BalanceDF,
1259+
target_BalanceDF,
1260+
aggregate_by_main_covar,
12381261
)
12391262

12401263
@staticmethod
@@ -1256,18 +1279,11 @@ def _cvmd_BalanceDF(
12561279
Returns:
12571280
pd.Series: See :func:`weighted_comparisons_stats.cvmd`.
12581281
"""
1259-
BalanceDF._check_if_not_BalanceDF(sample_BalanceDF, "sample_BalanceDF")
1260-
BalanceDF._check_if_not_BalanceDF(target_BalanceDF, "target_BalanceDF")
1261-
1262-
sample_df_values, sample_weights = sample_BalanceDF._get_df_and_weights()
1263-
target_df_values, target_weights = target_BalanceDF._get_df_and_weights()
1264-
1265-
return weighted_comparisons_stats.cvmd(
1266-
sample_df_values,
1267-
target_df_values,
1268-
sample_weights,
1269-
target_weights,
1270-
aggregate_by_main_covar=aggregate_by_main_covar,
1282+
return BalanceDF._apply_comparison_stat_to_BalanceDF(
1283+
weighted_comparisons_stats.cvmd,
1284+
sample_BalanceDF,
1285+
target_BalanceDF,
1286+
aggregate_by_main_covar,
12711287
)
12721288

12731289
@staticmethod
@@ -1289,18 +1305,11 @@ def _ks_BalanceDF(
12891305
Returns:
12901306
pd.Series: See :func:`weighted_comparisons_stats.ks`.
12911307
"""
1292-
BalanceDF._check_if_not_BalanceDF(sample_BalanceDF, "sample_BalanceDF")
1293-
BalanceDF._check_if_not_BalanceDF(target_BalanceDF, "target_BalanceDF")
1294-
1295-
sample_df_values, sample_weights = sample_BalanceDF._get_df_and_weights()
1296-
target_df_values, target_weights = target_BalanceDF._get_df_and_weights()
1297-
1298-
return weighted_comparisons_stats.ks(
1299-
sample_df_values,
1300-
target_df_values,
1301-
sample_weights,
1302-
target_weights,
1303-
aggregate_by_main_covar=aggregate_by_main_covar,
1308+
return BalanceDF._apply_comparison_stat_to_BalanceDF(
1309+
weighted_comparisons_stats.ks,
1310+
sample_BalanceDF,
1311+
target_BalanceDF,
1312+
aggregate_by_main_covar,
13041313
)
13051314

13061315
def asmd(

tests/test_balancedf.py

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1301,6 +1301,136 @@ def test_BalanceDF_asmd_aggregate_by_main_covar(self) -> None:
13011301
self.assertEqual(outcome_default, expected_default)
13021302
self.assertEqual(outcome_main_covar, expected_main_covar)
13031303

1304+
def test_BalanceDF__kld_BalanceDF(self) -> None:
1305+
"""Test _kld_BalanceDF static method directly."""
1306+
sample = Sample.from_frame(
1307+
pd.DataFrame({"id": (1, 2), "a": (1, 2), "b": (-1, 12), "weight": (1, 2)})
1308+
).covars()
1309+
1310+
target = Sample.from_frame(
1311+
pd.DataFrame({"id": (1, 2), "a": (3, 4), "b": (0, 42), "weight": (1, 2)})
1312+
).covars()
1313+
1314+
result = BalanceDF._kld_BalanceDF(sample, target)
1315+
1316+
# Verify result is a Series with expected keys
1317+
self.assertIsInstance(result, pd.Series)
1318+
self.assertIn("a", result.index)
1319+
self.assertIn("b", result.index)
1320+
self.assertIn("mean(kld)", result.index)
1321+
1322+
# Verify all values are non-negative (KLD property)
1323+
self.assertTrue((result >= 0).all())
1324+
1325+
# Test with aggregate_by_main_covar
1326+
result_agg = BalanceDF._kld_BalanceDF(
1327+
sample, target, aggregate_by_main_covar=True
1328+
)
1329+
self.assertIsInstance(result_agg, pd.Series)
1330+
1331+
def test_BalanceDF__emd_BalanceDF(self) -> None:
1332+
"""Test _emd_BalanceDF static method directly."""
1333+
sample = Sample.from_frame(
1334+
pd.DataFrame({"id": (1, 2), "a": (1, 2), "b": (-1, 12), "weight": (1, 2)})
1335+
).covars()
1336+
1337+
target = Sample.from_frame(
1338+
pd.DataFrame({"id": (1, 2), "a": (3, 4), "b": (0, 42), "weight": (1, 2)})
1339+
).covars()
1340+
1341+
result = BalanceDF._emd_BalanceDF(sample, target)
1342+
1343+
# Verify result is a Series with expected keys
1344+
self.assertIsInstance(result, pd.Series)
1345+
self.assertIn("a", result.index)
1346+
self.assertIn("b", result.index)
1347+
self.assertIn("mean(emd)", result.index)
1348+
1349+
# Verify all values are non-negative (EMD property)
1350+
self.assertTrue((result >= 0).all())
1351+
1352+
# Test with aggregate_by_main_covar
1353+
result_agg = BalanceDF._emd_BalanceDF(
1354+
sample, target, aggregate_by_main_covar=True
1355+
)
1356+
self.assertIsInstance(result_agg, pd.Series)
1357+
1358+
def test_BalanceDF__cvmd_BalanceDF(self) -> None:
1359+
"""Test _cvmd_BalanceDF static method directly."""
1360+
sample = Sample.from_frame(
1361+
pd.DataFrame({"id": (1, 2), "a": (1, 2), "b": (-1, 12), "weight": (1, 2)})
1362+
).covars()
1363+
1364+
target = Sample.from_frame(
1365+
pd.DataFrame({"id": (1, 2), "a": (3, 4), "b": (0, 42), "weight": (1, 2)})
1366+
).covars()
1367+
1368+
result = BalanceDF._cvmd_BalanceDF(sample, target)
1369+
1370+
# Verify result is a Series with expected keys
1371+
self.assertIsInstance(result, pd.Series)
1372+
self.assertIn("a", result.index)
1373+
self.assertIn("b", result.index)
1374+
self.assertIn("mean(cvmd)", result.index)
1375+
1376+
# Verify all values are non-negative (CVMD property)
1377+
self.assertTrue((result >= 0).all())
1378+
1379+
# Test with aggregate_by_main_covar
1380+
result_agg = BalanceDF._cvmd_BalanceDF(
1381+
sample, target, aggregate_by_main_covar=True
1382+
)
1383+
self.assertIsInstance(result_agg, pd.Series)
1384+
1385+
def test_BalanceDF__ks_BalanceDF(self) -> None:
1386+
"""Test _ks_BalanceDF static method directly."""
1387+
sample = Sample.from_frame(
1388+
pd.DataFrame({"id": (1, 2), "a": (1, 2), "b": (-1, 12), "weight": (1, 2)})
1389+
).covars()
1390+
1391+
target = Sample.from_frame(
1392+
pd.DataFrame({"id": (1, 2), "a": (3, 4), "b": (0, 42), "weight": (1, 2)})
1393+
).covars()
1394+
1395+
result = BalanceDF._ks_BalanceDF(sample, target)
1396+
1397+
# Verify result is a Series with expected keys
1398+
self.assertIsInstance(result, pd.Series)
1399+
self.assertIn("a", result.index)
1400+
self.assertIn("b", result.index)
1401+
self.assertIn("mean(ks)", result.index)
1402+
1403+
# Verify all values are in [0, 1] (KS property)
1404+
self.assertTrue((result >= 0).all())
1405+
self.assertTrue((result <= 1).all())
1406+
1407+
# Test with aggregate_by_main_covar
1408+
result_agg = BalanceDF._ks_BalanceDF(
1409+
sample, target, aggregate_by_main_covar=True
1410+
)
1411+
self.assertIsInstance(result_agg, pd.Series)
1412+
1413+
def test_BalanceDF_comparison_functions_invalid_input(self) -> None:
1414+
"""Test that all comparison functions properly validate inputs."""
1415+
sample = Sample.from_frame(
1416+
pd.DataFrame({"id": (1, 2), "a": (1, 2), "weight": (1, 2)})
1417+
).covars()
1418+
1419+
# Test with non-BalanceDF inputs
1420+
invalid_input = "not a BalanceDF"
1421+
1422+
with self.assertRaisesRegex(ValueError, "must be balancedf_class.BalanceDF"):
1423+
BalanceDF._kld_BalanceDF(invalid_input, sample) # type: ignore
1424+
1425+
with self.assertRaisesRegex(ValueError, "must be balancedf_class.BalanceDF"):
1426+
BalanceDF._emd_BalanceDF(sample, invalid_input) # type: ignore
1427+
1428+
with self.assertRaisesRegex(ValueError, "must be balancedf_class.BalanceDF"):
1429+
BalanceDF._cvmd_BalanceDF(invalid_input, sample) # type: ignore
1430+
1431+
with self.assertRaisesRegex(ValueError, "must be balancedf_class.BalanceDF"):
1432+
BalanceDF._ks_BalanceDF(sample, invalid_input) # type: ignore
1433+
13041434

13051435
class TestBalanceDF_to_download(BalanceTestCase):
13061436
def test_BalanceDF_to_download(self) -> None:

0 commit comments

Comments
 (0)