diff --git a/.github/workflows/build-and-test.yml b/.github/workflows/build-and-test.yml index 67be72b45..325898076 100644 --- a/.github/workflows/build-and-test.yml +++ b/.github/workflows/build-and-test.yml @@ -53,6 +53,8 @@ jobs: flake8 . - name: ufmt (formatting check) run: | + echo "Checking for formatting issues..." + ufmt diff . ufmt check . pyre: diff --git a/balance/sample_class.py b/balance/sample_class.py index d27d17100..c1f483106 100644 --- a/balance/sample_class.py +++ b/balance/sample_class.py @@ -823,14 +823,21 @@ def model( self: "Sample", ) -> Dict[str, Any] | None: """ - Returns the name of the model used to adjust Sample if adjusted. + Returns the adjustment model dictionary if Sample has been adjusted. Otherwise returns None. + The ``_adjustment_model`` attribute is initialized as ``None`` at the class + level and is set to a dictionary containing model details when + :meth:`adjust` is called. This method simply returns that attribute, + which will be ``None`` for unadjusted samples. + Args: self (Sample): Sample object. Returns: - str or None: name of model used for adjusting Sample + Dict[str, Any] or None: Dictionary containing adjustment model details + (e.g., method name, fitted model, performance metrics) if the sample + has been adjusted, otherwise None. Examples: .. 
code-block:: python @@ -848,10 +855,7 @@ def model( sample.model() is None # True """ - if hasattr(self, "_adjustment_model"): - return self._adjustment_model - else: - return None + return self._adjustment_model def model_matrix(self: "Sample") -> pd.DataFrame: """ diff --git a/balance/stats_and_plots/weighted_comparisons_plots.py b/balance/stats_and_plots/weighted_comparisons_plots.py index 7e1fc9330..3ffaf60ba 100644 --- a/balance/stats_and_plots/weighted_comparisons_plots.py +++ b/balance/stats_and_plots/weighted_comparisons_plots.py @@ -653,10 +653,6 @@ def seaborn_plot_dist( # With limiting the y axis range to (0,1) seaborn_plot_dist(dfs1, names=["self", "unadjusted", "target"], dist_type = "kde", ylim = (0,1)) """ - # Provide default names if not specified - if names is None: - names = [f"df_{i}" for i in range(len(dfs))] - # Set default dist_type dist_type_resolved: Literal["qq", "hist", "kde", "ecdf"] if dist_type is None: @@ -671,10 +667,6 @@ def seaborn_plot_dist( if names is None: names = [f"df_{i}" for i in range(len(dfs))] - # Type narrowing for names parameter - if names is None: - names = [] - # Choose set of variables to plot variables = choose_variables(*(d["df"] for d in dfs), variables=variables) logger.debug(f"plotting variables {variables}") @@ -1348,12 +1340,13 @@ def naming_legend(object_name: str, names_of_dfs: List[str]) -> str: naming_legend('self', ['self', 'target']) #'sample' naming_legend('other_name', ['self', 'target']) #'other_name' """ - if object_name in names_of_dfs: - return { - "unadjusted": "sample", - "self": "adjusted" if "unadjusted" in names_of_dfs else "sample", - "target": "population", - }[object_name] + name_mapping = { + "unadjusted": "sample", + "self": "adjusted" if "unadjusted" in names_of_dfs else "sample", + "target": "population", + } + if object_name in name_mapping: + return name_mapping[object_name] else: return object_name diff --git a/balance/stats_and_plots/weighted_comparisons_stats.py 
b/balance/stats_and_plots/weighted_comparisons_stats.py index 324d8d877..3d40c35b9 100644 --- a/balance/stats_and_plots/weighted_comparisons_stats.py +++ b/balance/stats_and_plots/weighted_comparisons_stats.py @@ -23,6 +23,7 @@ from balance.stats_and_plots.weights_stats import _check_weights_are_valid from balance.util import _safe_groupby_apply, _safe_replace_and_infer from balance.utils.input_validation import ( + _coerce_to_numeric_and_validate, _extract_series_and_weights, _is_discrete_series, ) @@ -826,16 +827,16 @@ def emd( ) ) else: - sample_vals = pd.to_numeric(sample_series, errors="coerce").dropna() - target_vals = pd.to_numeric(target_series, errors="coerce").dropna() - if sample_vals.empty or target_vals.empty: - raise ValueError("Numeric columns must contain at least one value.") - sample_w_numeric = sample_w[sample_series.index.isin(sample_vals.index)] - target_w_numeric = target_w[target_series.index.isin(target_vals.index)] + sample_vals, sample_w_numeric = _coerce_to_numeric_and_validate( + sample_series, sample_w, "Sample numeric column" + ) + target_vals, target_w_numeric = _coerce_to_numeric_and_validate( + target_series, target_w, "Target numeric column" + ) out_dict[col] = float( wasserstein_distance( - sample_vals.to_numpy(), - target_vals.to_numpy(), + sample_vals, + target_vals, u_weights=sample_w_numeric, v_weights=target_w_numeric, ) @@ -962,21 +963,17 @@ def cvmd( np.sum((sample_cdf - target_cdf) ** 2 * combined_pmf.to_numpy()) ) else: - sample_vals = pd.to_numeric(sample_series, errors="coerce").dropna() - target_vals = pd.to_numeric(target_series, errors="coerce").dropna() - if sample_vals.empty or target_vals.empty: - raise ValueError("Numeric columns must contain at least one value.") - sample_w_numeric = sample_w[sample_series.index.isin(sample_vals.index)] - target_w_numeric = target_w[target_series.index.isin(target_vals.index)] - - sample_sorted, sample_cdf = _weighted_ecdf( - sample_vals.to_numpy(), sample_w_numeric + 
sample_vals, sample_w_numeric = _coerce_to_numeric_and_validate( + sample_series, sample_w, "Sample numeric column" ) - target_sorted, target_cdf = _weighted_ecdf( - target_vals.to_numpy(), target_w_numeric + target_vals, target_w_numeric = _coerce_to_numeric_and_validate( + target_series, target_w, "Target numeric column" ) + + sample_sorted, sample_cdf = _weighted_ecdf(sample_vals, sample_w_numeric) + target_sorted, target_cdf = _weighted_ecdf(target_vals, target_w_numeric) combined_values, combined_weights = _combined_weights( - np.concatenate((sample_vals.to_numpy(), target_vals.to_numpy())), + np.concatenate((sample_vals, target_vals)), np.concatenate((sample_w_numeric, target_w_numeric)), ) sample_eval = _evaluate_ecdf(sample_sorted, sample_cdf, combined_values) @@ -1101,22 +1098,16 @@ def ks( ) out_dict[col] = float(np.max(np.abs(sample_cdf - target_cdf))) else: - sample_vals = pd.to_numeric(sample_series, errors="coerce").dropna() - target_vals = pd.to_numeric(target_series, errors="coerce").dropna() - if sample_vals.empty or target_vals.empty: - raise ValueError("Numeric columns must contain at least one value.") - sample_w_numeric = sample_w[sample_series.index.isin(sample_vals.index)] - target_w_numeric = target_w[target_series.index.isin(target_vals.index)] - - sample_sorted, sample_cdf = _weighted_ecdf( - sample_vals.to_numpy(), sample_w_numeric + sample_vals, sample_w_numeric = _coerce_to_numeric_and_validate( + sample_series, sample_w, "Sample numeric column" ) - target_sorted, target_cdf = _weighted_ecdf( - target_vals.to_numpy(), target_w_numeric - ) - combined_values = np.unique( - np.concatenate((sample_vals.to_numpy(), target_vals.to_numpy())) + target_vals, target_w_numeric = _coerce_to_numeric_and_validate( + target_series, target_w, "Target numeric column" ) + + sample_sorted, sample_cdf = _weighted_ecdf(sample_vals, sample_w_numeric) + target_sorted, target_cdf = _weighted_ecdf(target_vals, target_w_numeric) + combined_values = 
np.unique(np.concatenate((sample_vals, target_vals))) sample_eval = _evaluate_ecdf(sample_sorted, sample_cdf, combined_values) target_eval = _evaluate_ecdf(target_sorted, target_cdf, combined_values) out_dict[col] = float(np.max(np.abs(sample_eval - target_eval))) diff --git a/balance/stats_and_plots/weighted_stats.py b/balance/stats_and_plots/weighted_stats.py index 1ebc88684..1114446e8 100644 --- a/balance/stats_and_plots/weighted_stats.py +++ b/balance/stats_and_plots/weighted_stats.py @@ -279,12 +279,8 @@ def ci_of_weighted_mean( var_weighed_mean_of_v = var_of_weighted_mean(v, w, inf_rm) z_value = norm.ppf((1 + conf_level) / 2) - if isinstance(v, pd.Series): - ci_index = v.index - elif isinstance(v, pd.DataFrame): - ci_index = v.columns - else: - ci_index = None + # After _prepare_weighted_stat_args, v is always a DataFrame + ci_index = v.columns ci = pd.Series( [ diff --git a/balance/utils/input_validation.py b/balance/utils/input_validation.py index 5ba00ea9d..575017e1b 100644 --- a/balance/utils/input_validation.py +++ b/balance/utils/input_validation.py @@ -114,6 +114,56 @@ def _extract_series_and_weights( return filtered_series, filtered_weights +def _coerce_to_numeric_and_validate( + series: pd.Series, + weights: np.ndarray, + label: str, +) -> tuple[np.ndarray, np.ndarray]: + """ + Convert series to numeric, drop NaN values, and validate non-empty. + + This function handles series that may contain values that cannot be + converted to numeric (e.g., non-numeric strings in an object dtype series). + It coerces such values to NaN and drops them, then validates that at least + one valid numeric value remains. + + Args: + series (pd.Series): Input series to convert to numeric. + weights (np.ndarray): Weights aligned to the series. + label (str): Label for error messages. + + Returns: + Tuple[np.ndarray, np.ndarray]: Numeric values and corresponding weights. + + Raises: + ValueError: If no valid numeric values remain after conversion. + + Examples: + .. 
code-block:: python + + import numpy as np + import pandas as pd + from balance.utils.input_validation import _coerce_to_numeric_and_validate + + vals, w = _coerce_to_numeric_and_validate( + pd.Series([1.0, 2.0, 3.0]), + np.array([1.0, 1.0, 2.0]), + "example", + ) + vals.tolist() + # [1.0, 2.0, 3.0] + w.tolist() + # [1.0, 1.0, 2.0] + """ + numeric_series = pd.to_numeric(series, errors="coerce").dropna() + if numeric_series.empty: + raise ValueError( + f"{label} must contain at least one valid numeric value after conversion." + ) + numeric_weights = weights[series.index.isin(numeric_series.index)] + return numeric_series.to_numpy(), numeric_weights + + def _is_discrete_series(series: pd.Series) -> bool: """ Determine whether a series should be treated as discrete for comparisons. @@ -183,10 +233,19 @@ def _check_weighting_methods_input( # This is so to avoid various cyclic imports (since various files call sample_class, and then sample_class also calls these files) # TODO: (p2) move away from this method once we restructure Sample and BalanceDF objects... def _isinstance_sample(obj: Any) -> bool: - try: - from balance import sample_class - except ImportError: - return False + """Check if an object is an instance of Sample. + + The import is done inside the function to avoid circular import issues at + module load time. Since this module is part of the balance package, the + import will always succeed when the function is called. + + Args: + obj: The object to check. + + Returns: + bool: True if obj is a Sample instance, False otherwise. 
+ """ + from balance import sample_class return isinstance(obj, sample_class.Sample) diff --git a/tests/test_cbps.py b/tests/test_cbps.py index 79081a792..e4c43e395 100644 --- a/tests/test_cbps.py +++ b/tests/test_cbps.py @@ -14,6 +14,8 @@ ) import warnings +from typing import Any, Callable, Dict, List, Tuple, Union +from unittest.mock import MagicMock import balance.testutil import numpy as np @@ -1342,16 +1344,25 @@ def test_cbps_over_method_with_extreme_data_logs_warning(self) -> None: # CBPS handles extreme cases by logging a warning about identical weights # rather than raising an exception - self.assertWarnsRegexp( - "All weights are identical", - balance_cbps.cbps, - sample_df, - sample_weights, - target_df, - target_weights, - transformations=None, - cbps_method="exact", - ) + # Suppress PerfectSeparationWarning as it's expected with this extreme test data + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", + message="Perfect separation or prediction detected", + category=PerfectSeparationWarning, + ) + self.assertWarnsRegexp( + "All weights are identical", + balance_cbps.cbps, + sample_df, + sample_weights, + target_df, + target_weights, + transformations=None, + cbps_method="exact", + ) + + def test_cbps_over_method_logs_warnings(self) -> None: """Test CBPS over method logs warnings when optimization fails (lines 713, 747, 765). 
Verifies that when optimization algorithms fail to converge, appropriate @@ -1381,38 +1392,53 @@ def test_cbps_over_method_with_extreme_data_logs_warning(self) -> None: # Run with over method to exercise gmm optimization paths # Use very tight opt_opts to force convergence failure # We expect either warnings to be logged or an exception to be raised - try: - with self.assertLogs(level=logging.WARNING) as log_context: - balance_cbps.cbps( - sample_df, - sample_weights, - target_df, - target_weights, - transformations=None, - cbps_method="over", - opt_opts={"maxiter": 1}, # Force convergence failure - ) - # Verify that at least one warning was logged - self.assertTrue( - len(log_context.records) > 0, - msg="Expected warning logs when optimization fails to converge", + # Suppress PerfectSeparationWarning and scipy COBYLA warnings as they're expected + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", + message="Perfect separation or prediction detected", + category=PerfectSeparationWarning, ) - except Exception as e: - # If an exception is raised, verify it contains relevant error info - error_msg = str(e).lower() - self.assertTrue( - any( - keyword in error_msg - for keyword in [ - "converge", - "constraint", - "singular", - "optimization", - "failed", - ] - ), - msg=f"Expected exception to contain convergence-related message, got: {e}", + # Suppress scipy COBYLA warning about MAXFUN being too small + # This is expected when we intentionally set maxiter=1 to force failure + warnings.filterwarnings( + "ignore", + message="COBYLA: Invalid MAXFUN", + category=UserWarning, ) + try: + with self.assertLogs(level=logging.WARNING) as log_context: + balance_cbps.cbps( + sample_df, + sample_weights, + target_df, + target_weights, + transformations=None, + cbps_method="over", + opt_opts={"maxiter": 1}, # Force convergence failure + ) + # Verify that at least one warning was logged + self.assertGreater( + len(log_context.records), + 0, + msg="Expected 
warning logs when optimization fails to converge", + ) + except Exception as e: + # If an exception is raised, verify it contains relevant error info + error_msg = str(e).lower() + self.assertTrue( + any( + keyword in error_msg + for keyword in [ + "converge", + "constraint", + "singular", + "optimization", + "failed", + ] + ), + msg=f"Expected exception to contain convergence-related message, got: {e}", + ) def test_cbps_alpha_function_convergence_warning(self) -> None: """Test CBPS logs warning when alpha_function fails to converge (line 689). @@ -1477,3 +1503,136 @@ def test_cbps_alpha_function_convergence_warning(self) -> None: ), msg=f"Expected exception to contain convergence-related message, got: {e}", ) + + +class TestCbpsOptimizationConvergenceWithMocking(balance.testutil.BalanceTestCase): + """Test CBPS optimization convergence warning branches using mocking (lines 689, 713, 726, 747, 765, 778). + + These tests use unittest.mock to directly control scipy.optimize.minimize return values, + ensuring the specific warning and exception branches in _cbps_optimization are executed. + """ + + def _create_simple_test_data( + self, + ) -> Tuple[pd.DataFrame, pd.Series, pd.DataFrame, pd.Series]: + """Create simple test data for CBPS testing.""" + sample_df = pd.DataFrame({"a": [1.0, 2.0, 3.0, 4.0, 5.0]}) + target_df = pd.DataFrame({"a": [2.0, 3.0, 4.0, 5.0, 6.0]}) + sample_weights = pd.Series([1.0] * 5) + target_weights = pd.Series([1.0] * 5) + return sample_df, sample_weights, target_df, target_weights + + def test_exact_method_constraint_violation_exception(self) -> None: + """Test line 726: Exception when exact method constraints can't be satisfied. + + Uses mocking to simulate scipy.optimize.minimize returning success=False + with a specific constraint violation message. 
+ """ + from unittest.mock import patch + + sample_df, sample_weights, target_df, target_weights = ( + self._create_simple_test_data() + ) + + def mock_minimize(fun: Callable[..., Any], x0: Any, **kwargs: Any) -> MagicMock: + result = MagicMock() + # Simulate constraint violation failure + result.__getitem__ = lambda _, key: { + "success": np.bool_(False), + "message": "Did not converge to a solution satisfying the constraints", + "x": x0, + "fun": 100.0, + }[key] + return result + + def mock_minimize_scalar( + fun: Callable[..., Any], **kwargs: Any + ) -> Dict[str, Union[np.bool_, np.ndarray, str]]: + return { + "success": np.bool_(True), + "message": "Success", + "x": np.array([1.0]), + } + + with ( + patch("scipy.optimize.minimize_scalar", side_effect=mock_minimize_scalar), + patch("scipy.optimize.minimize", side_effect=mock_minimize), + ): + with self.assertRaises(Exception) as context: + balance_cbps.cbps( + sample_df, + sample_weights, + target_df, + target_weights, + transformations=None, + cbps_method="exact", + ) + + self.assertIn( + "no solution satisfying the constraints", + str(context.exception).lower(), + msg="Expected exception about constraint violation", + ) + + def test_over_method_both_gmm_constraint_violation_exception(self) -> None: + """Test line 778: Exception when over method both GMM optimizations fail with constraint violation. + + Uses mocking to simulate both gmm_loss optimizations failing with constraint messages. 
+ """ + from unittest.mock import patch + + sample_df, sample_weights, target_df, target_weights = ( + self._create_simple_test_data() + ) + + call_count: List[int] = [0] + + def mock_minimize(fun: Callable[..., Any], x0: Any, **kwargs: Any) -> MagicMock: + call_count[0] += 1 + result = MagicMock() + if call_count[0] == 1: + # First call is balance_optimize - succeed + result.__getitem__ = lambda _, key: { + "success": np.bool_(True), + "message": "Success", + "x": x0, + "fun": 1.0, + }[key] + else: + # Both GMM optimizations fail with constraint violation + result.__getitem__ = lambda _, key: { + "success": np.bool_(False), + "message": "Did not converge to a solution satisfying the constraints", + "x": x0, + "fun": 100.0, + }[key] + return result + + def mock_minimize_scalar( + fun: Callable[..., Any], **kwargs: Any + ) -> Dict[str, Union[np.bool_, np.ndarray, str]]: + return { + "success": np.bool_(True), + "message": "Success", + "x": np.array([1.0]), + } + + with ( + patch("scipy.optimize.minimize_scalar", side_effect=mock_minimize_scalar), + patch("scipy.optimize.minimize", side_effect=mock_minimize), + ): + with self.assertRaises(Exception) as context: + balance_cbps.cbps( + sample_df, + sample_weights, + target_df, + target_weights, + transformations=None, + cbps_method="over", + ) + + self.assertIn( + "no solution satisfying the constraints", + str(context.exception).lower(), + msg="Expected exception about constraint violation in over method", + ) diff --git a/tests/test_cli.py b/tests/test_cli.py index d05ea287d..d600f2286 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -9,6 +9,7 @@ import os.path import tempfile +import warnings from argparse import Namespace import balance.testutil @@ -1377,3 +1378,362 @@ def test_adapt_output_empty_df_returns_empty(self) -> None: df = pd.DataFrame() result = cli.adapt_output(df) self.assertTrue(result.empty) + + def test_cli_succeed_on_weighting_failure_with_return_df_with_original_dtypes( + self, + ) -> 
None: + """Test succeed_on_weighting_failure flag with return_df_with_original_dtypes. + + Verifies lines 757-794 in cli.py - the exception handling path when + weighting fails and return_df_with_original_dtypes is True. + """ + with ( + tempfile.TemporaryDirectory() as temp_dir, + tempfile.NamedTemporaryFile("w", suffix=".csv", delete=False) as in_file, + ): + in_contents = "x,y,is_respondent,id,weight\na,b,1,1,1\na,b,0,1,1" + in_file.write(in_contents) + in_file.close() + out_file = os.path.join(temp_dir, "out.csv") + diagnostics_out_file = os.path.join(temp_dir, "diagnostics_out.csv") + + parser = make_parser() + + args = parser.parse_args( + [ + "--input_file", + in_file.name, + "--output_file", + out_file, + "--diagnostics_output_file", + diagnostics_out_file, + "--covariate_columns", + "x,y", + "--succeed_on_weighting_failure", + "--return_df_with_original_dtypes", + ] + ) + cli = BalanceCLI(args) + cli.update_attributes_for_main_used_by_adjust() + cli.main() + + self.assertTrue(os.path.isfile(out_file)) + self.assertTrue(os.path.isfile(diagnostics_out_file)) + + diagnostics_df = pd.read_csv(diagnostics_out_file) + self.assertIn("adjustment_failure", diagnostics_df["metric"].values) + + def test_cli_ipw_method_with_model_in_adjusted_kwargs(self) -> None: + """Test CLI with IPW method to verify model is passed to adjust. + + Verifies line 719 in cli.py where model is added to adjusted_kwargs. 
+ """ + input_dataset = _create_sample_and_target_data() + + with ( + tempfile.TemporaryDirectory() as temp_dir, + tempfile.NamedTemporaryFile("w", suffix=".csv", delete=False) as input_file, + ): + input_dataset.to_csv(path_or_buf=input_file) + input_file.close() + output_file = os.path.join(temp_dir, "weights_out.csv") + diagnostics_output_file = os.path.join(temp_dir, "diagnostics_out.csv") + features = "age,gender" + + parser = make_parser() + args = parser.parse_args( + [ + "--input_file", + input_file.name, + "--output_file", + output_file, + "--diagnostics_output_file", + diagnostics_output_file, + "--covariate_columns", + features, + "--method=ipw", + "--ipw_logistic_regression_kwargs", + '{"solver": "lbfgs", "max_iter": 200}', + ] + ) + cli = BalanceCLI(args) + cli.update_attributes_for_main_used_by_adjust() + cli.main() + + self.assertTrue(os.path.isfile(output_file)) + self.assertTrue(os.path.isfile(diagnostics_output_file)) + + def test_cli_batch_columns_empty_batches(self) -> None: + """Test CLI batch processing with empty batches. + + Verifies lines 1082-1099, 1101-1106 in cli.py - batch processing + path including the empty results case. 
+ """ + with ( + tempfile.TemporaryDirectory() as temp_dir, + tempfile.NamedTemporaryFile("w", suffix=".csv", delete=False) as in_file, + ): + in_contents = ( + "x,y,is_respondent,id,weight,batch\n" + + ("1.0,50,1,1,1,A\n" * 50) + + ("2.0,60,0,1,1,A\n" * 50) + + ("1.0,50,1,2,1,B\n" * 50) + + ("2.0,60,0,2,1,B\n" * 50) + ) + in_file.write(in_contents) + in_file.close() + out_file = os.path.join(temp_dir, "out.csv") + diagnostics_out_file = os.path.join(temp_dir, "diagnostics_out.csv") + + parser = make_parser() + args = parser.parse_args( + [ + "--input_file", + in_file.name, + "--output_file", + out_file, + "--diagnostics_output_file", + diagnostics_out_file, + "--covariate_columns", + "x,y", + "--batch_columns", + "batch", + ] + ) + cli = BalanceCLI(args) + cli.update_attributes_for_main_used_by_adjust() + cli.main() + + self.assertTrue(os.path.isfile(out_file)) + self.assertTrue(os.path.isfile(diagnostics_out_file)) + + output_df = pd.read_csv(out_file) + self.assertTrue(len(output_df) > 0) + + +class TestCliMainFunction(balance.testutil.BalanceTestCase): + """Test cases for CLI main() entry point function (lines 1421-1425).""" + + def test_main_is_callable(self) -> None: + """Test that main function is callable. + + Verifies lines 1421-1425 in cli.py. + """ + from balance.cli import main + + self.assertTrue(callable(main)) + + def test_main_runs_with_valid_args(self) -> None: + """Test that main() function executes successfully with valid arguments. + + Verifies lines 1424-1428 in cli.py - the full main() entry point. 
+ """ + import sys + from unittest.mock import patch + + with ( + tempfile.TemporaryDirectory() as temp_dir, + tempfile.NamedTemporaryFile("w", suffix=".csv", delete=False) as in_file, + ): + in_contents = ( + "x,y,is_respondent,id,weight\n" + + ("1.0,50,1,1,1\n" * 100) + + ("2.0,60,0,1,1\n" * 100) + ) + in_file.write(in_contents) + in_file.close() + out_file = os.path.join(temp_dir, "out.csv") + + test_args = [ + "balance_cli", + "--input_file", + in_file.name, + "--output_file", + out_file, + "--covariate_columns", + "x,y", + ] + + from balance.cli import main + + with patch.object(sys, "argv", test_args): + main() + + self.assertTrue(os.path.isfile(out_file)) + + +class TestCliExceptionHandling(balance.testutil.BalanceTestCase): + """Test cases for exception handling in process_batch (lines 760-797).""" + + def test_process_batch_raises_without_succeed_on_weighting_failure(self) -> None: + """Test that exception is re-raised when succeed_on_weighting_failure is False. + + Verifies lines 796-797 in cli.py (else: raise e). 
+ """ + from unittest.mock import patch + + with ( + tempfile.TemporaryDirectory() as temp_dir, + tempfile.NamedTemporaryFile("w", suffix=".csv", delete=False) as in_file, + ): + in_contents = ( + "x,y,is_respondent,id,weight\n" + + ("1.0,50,1,1,1\n" * 50) + + ("2.0,60,0,1,1\n" * 50) + ) + in_file.write(in_contents) + in_file.close() + out_file = os.path.join(temp_dir, "out.csv") + + parser = make_parser() + args = parser.parse_args( + [ + "--input_file", + in_file.name, + "--output_file", + out_file, + "--covariate_columns", + "x,y", + ] + ) + cli = BalanceCLI(args) + cli.update_attributes_for_main_used_by_adjust() + + # Mock the adjust method to raise an exception to test the error handling path + with patch( + "balance.sample_class.Sample.adjust", + side_effect=ValueError("Simulated weighting failure"), + ): + with self.assertRaisesRegex(ValueError, r"Simulated weighting failure"): + cli.main() + + def test_succeed_on_weighting_failure_exception_path(self) -> None: + """Test exception handling when succeed_on_weighting_failure is True. + + Verifies lines 762-783 in cli.py - the exception handling path when + adjustment fails and succeed_on_weighting_failure is True. + This tests the else branch (lines 781-782) where return_df_with_original_dtypes is False. 
+ """ + from unittest.mock import patch + + with ( + tempfile.TemporaryDirectory() as temp_dir, + tempfile.NamedTemporaryFile("w", suffix=".csv", delete=False) as in_file, + ): + in_contents = ( + "x,y,is_respondent,id,weight\n" + + ("1.0,50,1,1,1\n" * 50) + + ("2.0,60,0,1,1\n" * 50) + ) + in_file.write(in_contents) + in_file.close() + out_file = os.path.join(temp_dir, "out.csv") + diagnostics_out_file = os.path.join(temp_dir, "diagnostics_out.csv") + + parser = make_parser() + args = parser.parse_args( + [ + "--input_file", + in_file.name, + "--output_file", + out_file, + "--diagnostics_output_file", + diagnostics_out_file, + "--covariate_columns", + "x,y", + "--succeed_on_weighting_failure", + ] + ) + cli = BalanceCLI(args) + cli.update_attributes_for_main_used_by_adjust() + + # Mock the adjust method to raise an exception + # Suppress FutureWarning from pandas about setting incompatible dtype + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", + message="Setting an item of incompatible dtype", + category=FutureWarning, + ) + with patch( + "balance.sample_class.Sample.adjust", + side_effect=ValueError("Simulated weighting failure for testing"), + ): + cli.main() + + self.assertTrue(os.path.isfile(out_file)) + self.assertTrue(os.path.isfile(diagnostics_out_file)) + + diagnostics_df = pd.read_csv(diagnostics_out_file) + self.assertIn("adjustment_failure", diagnostics_df["metric"].values) + self.assertIn("adjustment_failure_reason", diagnostics_df["metric"].values) + # Check that the error message is captured + failure_reason = diagnostics_df[ + diagnostics_df["metric"] == "adjustment_failure_reason" + ]["val"].values[0] + self.assertIn("Simulated weighting failure", failure_reason) + + def test_succeed_on_weighting_failure_with_return_original_dtypes(self) -> None: + """Test exception handling with succeed_on_weighting_failure and return_df_with_original_dtypes. 
+ + Verifies lines 771-780 in cli.py - the return_df_with_original_dtypes branch + inside the exception handler when weighting fails. + """ + from unittest.mock import patch + + with ( + tempfile.TemporaryDirectory() as temp_dir, + tempfile.NamedTemporaryFile("w", suffix=".csv", delete=False) as in_file, + ): + # Use float weights (1.0) so that when set_weights(None) is called, + # the dtype conversion back to float64 can handle None/NaN values + in_contents = ( + "x,y,is_respondent,id,weight\n" + + ("1.0,50.0,1,1,1.0\n" * 50) + + ("2.0,60.0,0,1,1.0\n" * 50) + ) + in_file.write(in_contents) + in_file.close() + out_file = os.path.join(temp_dir, "out.csv") + diagnostics_out_file = os.path.join(temp_dir, "diagnostics_out.csv") + + parser = make_parser() + args = parser.parse_args( + [ + "--input_file", + in_file.name, + "--output_file", + out_file, + "--diagnostics_output_file", + diagnostics_out_file, + "--covariate_columns", + "x,y", + "--succeed_on_weighting_failure", + "--return_df_with_original_dtypes", + ] + ) + cli = BalanceCLI(args) + cli.update_attributes_for_main_used_by_adjust() + + # Mock the adjust method to raise an exception + # Suppress FutureWarning from pandas about setting incompatible dtype + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", + message="Setting an item of incompatible dtype", + category=FutureWarning, + ) + with patch( + "balance.sample_class.Sample.adjust", + side_effect=ValueError( + "Simulated weighting failure for dtype test" + ), + ): + cli.main() + + self.assertTrue(os.path.isfile(out_file)) + self.assertTrue(os.path.isfile(diagnostics_out_file)) + + diagnostics_df = pd.read_csv(diagnostics_out_file) + self.assertIn("adjustment_failure", diagnostics_df["metric"].values) + self.assertIn("adjustment_failure_reason", diagnostics_df["metric"].values) diff --git a/tests/test_rake.py b/tests/test_rake.py index 36b6e1ee4..c60566c64 100644 --- a/tests/test_rake.py +++ b/tests/test_rake.py @@ -13,6 +13,8 @@ 
unicode_literals, ) +import warnings + import balance.testutil import numpy as np import pandas as pd @@ -1184,13 +1186,20 @@ def test_run_ipf_numpy_nan_conv_handling(self) -> None: margins = [np.array([0.0, 0.0]), np.array([0.0, 0.0])] # Execute: Run IPF with zero margins that cause nan values - result_table, converged, iterations_df = _run_ipf_numpy( - table, - margins, - convergence_rate=1e-6, - max_iteration=10, - rate_tolerance=0.0, - ) + # Suppress RuntimeWarning about "All-NaN slice encountered" as it's expected + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", + message="All-NaN slice encountered", + category=RuntimeWarning, + ) + result_table, converged, iterations_df = _run_ipf_numpy( + table, + margins, + convergence_rate=1e-6, + max_iteration=10, + rate_tolerance=0.0, + ) # Assert: Should handle nan gracefully and converge (or at least not crash) # When all margins are 0, convergence should be detected diff --git a/tests/test_sample.py b/tests/test_sample.py index 16157a641..413644033 100644 --- a/tests/test_sample.py +++ b/tests/test_sample.py @@ -2341,3 +2341,441 @@ def test_summary_includes_model_performance_for_ipw(self) -> None: adjusted = sample.set_target(target).adjust(method="ipw", max_de=1.5) summary = adjusted.summary() self.assertIn("Model performance", summary) + + +class TestSampleStrWeightTrimmingPercentile(balance.testutil.BalanceTestCase): + """Test cases for weight_trimming_percentile in __str__ (lines 301-305).""" + + def test_str_shows_weight_trimming_percentile_when_in_model(self) -> None: + """Test that __str__ shows weight_trimming_percentile when present in model. + + Verifies lines 301-305 in sample_class.py. + Note: The IPW method currently does not store weight_trimming_percentile + in the model dictionary, so we manually inject it to test the display logic. 
+ """ + np.random.seed(42) + sample = Sample.from_frame( + pd.DataFrame( + { + "a": np.random.randn(100), + "id": range(100), + "w": [1.0] * 100, + } + ), + weight_column="w", + ) + target = Sample.from_frame( + pd.DataFrame( + { + "a": np.random.randn(100) + 1, + "id": range(100, 200), + "w": [1.0] * 100, + } + ), + weight_column="w", + ) + adjusted = sample.set_target(target).adjust( + method="ipw", + weight_trimming_percentile=0.98, + weight_trimming_mean_ratio=None, + ) + # Manually inject weight_trimming_percentile into the model to test display logic + # This is needed because the IPW implementation does not store this value in the model + if adjusted._adjustment_model is not None: + adjusted._adjustment_model["weight_trimming_percentile"] = 0.98 + output_str = adjusted.__str__() + self.assertIn("weight trimming percentile", output_str) + + +class TestSampleDesignEffectDiagnosticsExtended(balance.testutil.BalanceTestCase): + """Test cases for _design_effect_diagnostics edge cases (lines 307-308, 349-351).""" + + def test_design_effect_diagnostics_when_n_rows_is_none(self) -> None: + """Test _design_effect_diagnostics with n_rows=None uses df shape. + + Verifies lines 307-308 in sample_class.py. + """ + sample = Sample.from_frame( + pd.DataFrame({"a": [1, 2, 3], "id": [1, 2, 3], "w": [1.0, 2.0, 3.0]}), + weight_column="w", + ) + design_effect, effective_n, effective_prop = sample._design_effect_diagnostics( + n_rows=None + ) + self.assertIsNotNone(design_effect) + self.assertIsNotNone(effective_n) + self.assertIsNotNone(effective_prop) + + def test_design_effect_diagnostics_exception_handling(self) -> None: + """Test _design_effect_diagnostics returns None on exception. + + Verifies lines 349-351 in sample_class.py. 
+ """ + sample = Sample.from_frame( + pd.DataFrame({"a": [1, 2, 3], "id": [1, 2, 3], "w": [0.0, 0.0, 0.0]}), + weight_column="w", + ) + design_effect, effective_n, effective_prop = sample._design_effect_diagnostics() + self.assertIsNone(design_effect) + self.assertIsNone(effective_n) + self.assertIsNone(effective_prop) + + +class TestSampleDiagnosticsIPWModelParams(balance.testutil.BalanceTestCase): + """Test cases for IPW model parameters in diagnostics (lines 1838, 1878-1879).""" + + def test_diagnostics_includes_n_iter_intercept(self) -> None: + """Test diagnostics includes n_iter_ and intercept_ from IPW fit. + + Verifies lines 1835-1849 in sample_class.py. + """ + np.random.seed(42) + sample = Sample.from_frame( + pd.DataFrame( + { + "a": np.random.randn(100), + "id": range(100), + "w": [1.0] * 100, + } + ), + weight_column="w", + ) + target = Sample.from_frame( + pd.DataFrame( + { + "a": np.random.randn(100) + 1, + "id": range(100, 200), + "w": [1.0] * 100, + } + ), + weight_column="w", + ) + adjusted = sample.set_target(target).adjust(method="ipw", max_de=1.5) + diagnostics = adjusted.diagnostics() + diagnostics_dict = diagnostics.set_index(["metric", "var"])["val"].to_dict() + self.assertTrue( + any( + "ipw_model_glance" in str(k) + or "n_iter_" in str(k) + or "intercept_" in str(k) + for k in diagnostics_dict.keys() + ) + ) + + def test_diagnostics_includes_multi_class(self) -> None: + """Test diagnostics includes multi_class from IPW fit. + + Verifies lines 1874-1879 in sample_class.py. 
+ """ + np.random.seed(42) + sample = Sample.from_frame( + pd.DataFrame( + { + "a": np.random.randn(100), + "id": range(100), + "w": [1.0] * 100, + } + ), + weight_column="w", + ) + target = Sample.from_frame( + pd.DataFrame( + { + "a": np.random.randn(100) + 1, + "id": range(100, 200), + "w": [1.0] * 100, + } + ), + weight_column="w", + ) + adjusted = sample.set_target(target).adjust(method="ipw", max_de=1.5) + diagnostics = adjusted.diagnostics() + metrics = diagnostics["metric"].unique() + self.assertTrue( + "ipw_multi_class" in metrics, + f"Expected 'ipw_multi_class' in diagnostics metrics. Found: {metrics}", + ) + + def test_diagnostics_multi_class_converted_to_string(self) -> None: + """Test diagnostics converts non-string multi_class to string. + + Verifies lines 1878-1879 in sample_class.py where multi_class is + converted to string if it's not already a string. + """ + np.random.seed(42) + sample = Sample.from_frame( + pd.DataFrame( + { + "a": np.random.randn(100), + "id": range(100), + "w": [1.0] * 100, + } + ), + weight_column="w", + ) + target = Sample.from_frame( + pd.DataFrame( + { + "a": np.random.randn(100) + 1, + "id": range(100, 200), + "w": [1.0] * 100, + } + ), + weight_column="w", + ) + adjusted = sample.set_target(target).adjust(method="ipw", max_de=1.5) + + # Modify the model's fit object to have a non-string multi_class + # to test the conversion path (lines 1878-1879) + if adjusted._adjustment_model is not None: + fit = adjusted._adjustment_model.get("fit") + if fit is not None: + # Temporarily override multi_class with a non-string value + original_multi_class = getattr(fit, "multi_class", None) + try: + # Set multi_class to a non-string value (e.g., an int) + fit.multi_class = 123 + diagnostics = adjusted.diagnostics() + # Check that ipw_multi_class is present and is a string + multi_class_rows = diagnostics[ + diagnostics["metric"] == "ipw_multi_class" + ] + self.assertGreater(len(multi_class_rows), 0) + # The value should be converted 
to string "123" + self.assertEqual(multi_class_rows["var"].iloc[0], "123") + finally: + # Restore original value + if original_multi_class is not None: + fit.multi_class = original_multi_class + + def test_diagnostics_n_iter_array_larger_than_one(self) -> None: + """Test diagnostics handles n_iter_ array with size > 1. + + Verifies line 1838 in sample_class.py where array_as_np.size == 1 + check is performed. When size > 1, the value should be skipped. + """ + np.random.seed(42) + sample = Sample.from_frame( + pd.DataFrame( + { + "a": np.random.randn(100), + "id": range(100), + "w": [1.0] * 100, + } + ), + weight_column="w", + ) + target = Sample.from_frame( + pd.DataFrame( + { + "a": np.random.randn(100) + 1, + "id": range(100, 200), + "w": [1.0] * 100, + } + ), + weight_column="w", + ) + adjusted = sample.set_target(target).adjust(method="ipw", max_de=1.5) + + # Modify the model's fit object to have n_iter_ as an array with size > 1 + # to test the path where we skip the value (line 1838) + if adjusted._adjustment_model is not None: + fit = adjusted._adjustment_model.get("fit") + if fit is not None: + original_n_iter = getattr(fit, "n_iter_", None) + try: + # Set n_iter_ to an array with size > 1 + fit.n_iter_ = np.array([10, 20, 30]) + diagnostics = adjusted.diagnostics() + # Check that diagnostics still works + self.assertIsNotNone(diagnostics) + # n_iter_ should NOT be in ipw_model_glance since size > 1 + n_iter_rows = diagnostics[ + (diagnostics["metric"] == "ipw_model_glance") + & (diagnostics["var"] == "n_iter_") + ] + self.assertEqual(len(n_iter_rows), 0) + finally: + # Restore original value + if original_n_iter is not None: + fit.n_iter_ = original_n_iter + + def test_diagnostics_n_iter_intercept_none_continue(self) -> None: + """Test diagnostics continues when n_iter_ or intercept_ is None. + + Verifies line 1838 in sample_class.py where continue is called + when array_val is None for n_iter_ or intercept_ attributes. 
+ """ + np.random.seed(42) + sample = Sample.from_frame( + pd.DataFrame( + { + "a": np.random.randn(100), + "id": range(100), + "w": [1.0] * 100, + } + ), + weight_column="w", + ) + target = Sample.from_frame( + pd.DataFrame( + { + "a": np.random.randn(100) + 1, + "id": range(100, 200), + "w": [1.0] * 100, + } + ), + weight_column="w", + ) + adjusted = sample.set_target(target).adjust(method="ipw", max_de=1.5) + + # Modify the model's fit object to have n_iter_ and intercept_ set to None + # to test the continue path (line 1838) + if adjusted._adjustment_model is not None: + fit = adjusted._adjustment_model.get("fit") + if fit is not None: + original_n_iter = getattr(fit, "n_iter_", None) + original_intercept = getattr(fit, "intercept_", None) + try: + # Set n_iter_ and intercept_ to None to trigger line 1838 + fit.n_iter_ = None + fit.intercept_ = None + diagnostics = adjusted.diagnostics() + # Check that diagnostics still works + self.assertIsNotNone(diagnostics) + # n_iter_ should NOT be in ipw_model_glance since it's None + n_iter_rows = diagnostics[ + (diagnostics["metric"] == "ipw_model_glance") + & (diagnostics["var"] == "n_iter_") + ] + self.assertEqual(len(n_iter_rows), 0) + # intercept_ should NOT be in ipw_model_glance since it's None + intercept_rows = diagnostics[ + (diagnostics["metric"] == "ipw_model_glance") + & (diagnostics["var"] == "intercept_") + ] + self.assertEqual(len(intercept_rows), 0) + finally: + # Restore original values + if original_n_iter is not None: + fit.n_iter_ = original_n_iter + if original_intercept is not None: + fit.intercept_ = original_intercept + + +class TestSampleQuickAdjustmentDetailsNRows(balance.testutil.BalanceTestCase): + """Test cases for _quick_adjustment_details with n_rows=None (line 308).""" + + def test_quick_adjustment_details_with_n_rows_none(self) -> None: + """Test _quick_adjustment_details when n_rows is None uses df shape. + + Verifies lines 307-308 in sample_class.py. 
+ """ + sample = Sample.from_frame( + pd.DataFrame({"a": [1, 2, 3], "id": [1, 2, 3], "w": [1.0, 2.0, 3.0]}), + weight_column="w", + ) + target = Sample.from_frame( + pd.DataFrame({"a": [1, 2], "id": [4, 5], "w": [1.0, 1.0]}), + weight_column="w", + ) + adjusted = sample.set_target(target).adjust(method="null") + + # Call _quick_adjustment_details with n_rows=None (default) + details = adjusted._quick_adjustment_details(n_rows=None) + + # Should include method and design effect info + self.assertTrue(any("method:" in d for d in details)) + self.assertTrue(any("design effect" in d for d in details)) + + +class TestSampleModelNoAdjustmentModel(balance.testutil.BalanceTestCase): + """Test cases for model() returning None when _adjustment_model is None.""" + + def test_model_returns_none_when_adjustment_model_attr_missing(self) -> None: + """Test model() returns None when _adjustment_model attribute is None. + + Verifies that for an unadjusted sample, model() returns None. + """ + sample = Sample.from_frame(pd.DataFrame({"a": [1, 2, 3], "id": [1, 2, 3]})) + + # For an unadjusted sample, model() should return None + result = sample.model() + self.assertIsNone(result) + + def test_model_returns_adjustment_model_when_set(self) -> None: + """Test model() returns the adjustment model when set. + + Verifies that model() returns the correct model dictionary after adjustment. 
+ """ + sample = Sample.from_frame(pd.DataFrame({"a": [1, 2, 3], "id": [1, 2, 3]})) + target = Sample.from_frame(pd.DataFrame({"a": [1, 2, 3], "id": [4, 5, 6]})) + adjusted = sample.set_target(target).adjust(method="null") + + result = adjusted.model() + self.assertIsNotNone(result) + self.assertIsInstance(result, dict) + self.assertIn("method", result) + + +class TestSampleDesignEffectDiagnosticsExceptionTypes(balance.testutil.BalanceTestCase): + """Test cases for _design_effect_diagnostics exception handling (lines 349-351).""" + + def test_design_effect_diagnostics_type_error(self) -> None: + """Test _design_effect_diagnostics handles TypeError gracefully. + + Verifies lines 349-351 in sample_class.py. + """ + sample = Sample.from_frame( + pd.DataFrame({"a": [1, 2, 3], "id": [1, 2, 3], "w": [1.0, 2.0, 3.0]}), + weight_column="w", + ) + + # Mock design_effect to raise TypeError + original_design_effect = sample.design_effect + try: + sample.design_effect = MagicMock(side_effect=TypeError("test error")) + result = sample._design_effect_diagnostics() + self.assertEqual(result, (None, None, None)) + finally: + sample.design_effect = original_design_effect + + def test_design_effect_diagnostics_value_error(self) -> None: + """Test _design_effect_diagnostics handles ValueError gracefully. + + Verifies lines 349-351 in sample_class.py. + """ + sample = Sample.from_frame( + pd.DataFrame({"a": [1, 2, 3], "id": [1, 2, 3], "w": [1.0, 2.0, 3.0]}), + weight_column="w", + ) + + # Mock design_effect to raise ValueError + original_design_effect = sample.design_effect + try: + sample.design_effect = MagicMock(side_effect=ValueError("test error")) + result = sample._design_effect_diagnostics() + self.assertEqual(result, (None, None, None)) + finally: + sample.design_effect = original_design_effect + + def test_design_effect_diagnostics_zero_division_error(self) -> None: + """Test _design_effect_diagnostics handles ZeroDivisionError gracefully. 
+ + Verifies lines 349-351 in sample_class.py. + """ + sample = Sample.from_frame( + pd.DataFrame({"a": [1, 2, 3], "id": [1, 2, 3], "w": [1.0, 2.0, 3.0]}), + weight_column="w", + ) + + # Mock design_effect to raise ZeroDivisionError + original_design_effect = sample.design_effect + try: + sample.design_effect = MagicMock( + side_effect=ZeroDivisionError("test error") + ) + result = sample._design_effect_diagnostics() + self.assertEqual(result, (None, None, None)) + finally: + sample.design_effect = original_design_effect diff --git a/tests/test_stats_and_plots.py b/tests/test_stats_and_plots.py index dd0e72504..3545ab25a 100644 --- a/tests/test_stats_and_plots.py +++ b/tests/test_stats_and_plots.py @@ -7,6 +7,7 @@ from __future__ import annotations +import warnings from typing import Any, cast import balance.testutil @@ -132,7 +133,7 @@ def test_prop_above_and_below(self) -> None: prop_above_and_below(pd.Series((1, 2, 3, 4)), above=None, below=None), None ) - # Test with only below=None (line 249) + # Test with only below=None result_only_above = prop_above_and_below( pd.Series((1, 2, 3, 4)), above=(1, 2), below=None ) @@ -147,7 +148,7 @@ def test_prop_above_and_below(self) -> None: ) ) - # Test with only above=None (line 257) + # Test with only above=None result_only_below = prop_above_and_below( pd.Series((1, 2, 3, 4)), above=None, below=(0.5, 1) ) @@ -747,9 +748,10 @@ def test_descriptive_stats(self) -> None: descriptive_stats(pd.DataFrame(x), w, stat="std").iloc[0, 0], ) # shows that descriptive_stats can calculate std_mean and that it's smaller than std (as expected.) 
- self.assertTrue( - descriptive_stats(pd.DataFrame(x), w, stat="std").iloc[0, 0] - > descriptive_stats(pd.DataFrame(x), w, stat="std_mean").iloc[0, 0] + self.assertGreater( + # std > std_mean + descriptive_stats(pd.DataFrame(x), w, stat="std").iloc[0, 0], + descriptive_stats(pd.DataFrame(x), w, stat="std_mean").iloc[0, 0], ) x = [1, 2, 3, 4] @@ -2074,13 +2076,10 @@ def test_balance_df_emd_cvmd_ks_requires_target(self) -> None: class TestWeightedEcdfValidation(balance.testutil.BalanceTestCase): - """Test cases for _weighted_ecdf validation (lines 299, 301, 303, 305, 308).""" + """Test cases for _weighted_ecdf validation.""" def test_weighted_ecdf_raises_on_non_1d_values(self) -> None: - """Test _weighted_ecdf raises ValueError for non-1D values. - - Verifies line 299 in weighted_comparisons_stats.py. - """ + """Test _weighted_ecdf raises ValueError for non-1D values.""" from balance.stats_and_plots.weighted_comparisons_stats import _weighted_ecdf values_2d = np.array([[1, 2], [3, 4]]) @@ -2090,10 +2089,7 @@ def test_weighted_ecdf_raises_on_non_1d_values(self) -> None: self.assertIn("must be 1D arrays", str(ctx.exception)) def test_weighted_ecdf_raises_on_empty_values(self) -> None: - """Test _weighted_ecdf raises ValueError for empty values. - - Verifies line 301 in weighted_comparisons_stats.py. - """ + """Test _weighted_ecdf raises ValueError for empty values.""" from balance.stats_and_plots.weighted_comparisons_stats import _weighted_ecdf values = np.array([]) @@ -2103,10 +2099,7 @@ def test_weighted_ecdf_raises_on_empty_values(self) -> None: self.assertIn("must not be empty", str(ctx.exception)) def test_weighted_ecdf_raises_on_shape_mismatch(self) -> None: - """Test _weighted_ecdf raises ValueError for shape mismatch. - - Verifies line 303 in weighted_comparisons_stats.py. 
- """ + """Test _weighted_ecdf raises ValueError for shape mismatch.""" from balance.stats_and_plots.weighted_comparisons_stats import _weighted_ecdf values = np.array([1.0, 2.0, 3.0]) @@ -2116,10 +2109,7 @@ def test_weighted_ecdf_raises_on_shape_mismatch(self) -> None: self.assertIn("must match", str(ctx.exception)) def test_weighted_ecdf_raises_on_negative_weights(self) -> None: - """Test _weighted_ecdf raises ValueError for negative weights. - - Verifies line 305 in weighted_comparisons_stats.py. - """ + """Test _weighted_ecdf raises ValueError for negative weights.""" from balance.stats_and_plots.weighted_comparisons_stats import _weighted_ecdf values = np.array([1.0, 2.0, 3.0]) @@ -2129,10 +2119,7 @@ def test_weighted_ecdf_raises_on_negative_weights(self) -> None: self.assertIn("non-negative", str(ctx.exception)) def test_weighted_ecdf_raises_on_zero_sum_weights(self) -> None: - """Test _weighted_ecdf raises ValueError for zero sum weights. - - Verifies line 308 in weighted_comparisons_stats.py. - """ + """Test _weighted_ecdf raises ValueError for zero sum weights.""" from balance.stats_and_plots.weighted_comparisons_stats import _weighted_ecdf values = np.array([1.0, 2.0, 3.0]) @@ -2146,10 +2133,7 @@ class TestCombinedWeightsValidation(balance.testutil.BalanceTestCase): """Test cases for _combined_weights validation (line 377).""" def test_combined_weights_raises_on_zero_sum(self) -> None: - """Test _combined_weights raises ValueError for zero sum. - - Verifies line 377 in weighted_comparisons_stats.py. 
- """ + """Test _combined_weights raises ValueError for zero sum.""" from balance.stats_and_plots.weighted_comparisons_stats import _combined_weights values = np.array([1.0, 2.0]) @@ -2163,10 +2147,7 @@ class TestDistributionMetricsColumnWarnings(balance.testutil.BalanceTestCase): """Test cases for distribution metrics column warnings (lines 777, 910, 1054).""" def test_emd_warns_on_mismatched_columns(self) -> None: - """Test emd warns when sample and target have different columns. - - Verifies line 777 in weighted_comparisons_stats.py. - """ + """Test emd warns when sample and target have different columns.""" sample_df = pd.DataFrame({"a": [1.0, 2.0], "b": [3.0, 4.0]}) target_df = pd.DataFrame({"a": [1.0, 2.0], "c": [5.0, 6.0]}) @@ -2175,10 +2156,7 @@ def test_emd_warns_on_mismatched_columns(self) -> None: self.assertTrue(any("same column names" in msg for msg in log.output)) def test_cvmd_warns_on_mismatched_columns(self) -> None: - """Test cvmd warns when sample and target have different columns. - - Verifies line 910 in weighted_comparisons_stats.py. - """ + """Test cvmd warns when sample and target have different columns.""" sample_df = pd.DataFrame({"a": [1.0, 2.0], "b": [3.0, 4.0]}) target_df = pd.DataFrame({"a": [1.0, 2.0], "c": [5.0, 6.0]}) @@ -2187,10 +2165,7 @@ def test_cvmd_warns_on_mismatched_columns(self) -> None: self.assertTrue(any("same column names" in msg for msg in log.output)) def test_ks_warns_on_mismatched_columns(self) -> None: - """Test ks warns when sample and target have different columns. - - Verifies line 1054 in weighted_comparisons_stats.py. 
- """ + """Test ks warns when sample and target have different columns.""" sample_df = pd.DataFrame({"a": [1.0, 2.0], "b": [3.0, 4.0]}) target_df = pd.DataFrame({"a": [1.0, 2.0], "c": [5.0, 6.0]}) @@ -2203,22 +2178,24 @@ class TestWeightedStatsPrepareArgs(balance.testutil.BalanceTestCase): """Test cases for _prepare_weighted_stat_args edge cases (lines 67, 283, 287).""" def test_prepare_weighted_stat_args_with_matrix_input(self) -> None: - """Test _prepare_weighted_stat_args handles np.matrix input. - - Verifies lines 66-67 in weighted_stats.py. - """ + """Test _prepare_weighted_stat_args handles np.matrix input.""" from balance.stats_and_plots.weighted_stats import _prepare_weighted_stat_args - v = np.matrix([[1, 2], [3, 4]]) - w = np.array([1.0, 1.0]) - result_v, result_w = _prepare_weighted_stat_args(v, w, inf_rm=True) + # Suppress PendingDeprecationWarning from np.matrix as it's expected + # when testing backward compatibility with deprecated matrix input + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", + message="the matrix subclass is not the recommended way", + category=PendingDeprecationWarning, + ) + v = np.matrix([[1, 2], [3, 4]]) + w = np.array([1.0, 1.0]) + result_v, result_w = _prepare_weighted_stat_args(v, w, inf_rm=True) self.assertIsInstance(result_v, pd.DataFrame) def test_ci_of_weighted_mean_with_series(self) -> None: - """Test ci_of_weighted_mean with pd.Series input. - - Verifies lines 282-283 in weighted_stats.py. - """ + """Test ci_of_weighted_mean with pd.Series input.""" from balance.stats_and_plots.weighted_stats import ci_of_weighted_mean v = pd.Series([1.0, 2.0, 3.0, 4.0]) @@ -2227,10 +2204,7 @@ def test_ci_of_weighted_mean_with_series(self) -> None: self.assertIsInstance(result, pd.Series) def test_ci_of_weighted_mean_with_dataframe(self) -> None: - """Test ci_of_weighted_mean with pd.DataFrame input. - - Verifies lines 284-285 in weighted_stats.py. 
- """ + """Test ci_of_weighted_mean with pd.DataFrame input.""" from balance.stats_and_plots.weighted_stats import ci_of_weighted_mean v = pd.DataFrame({"a": [1.0, 2.0, 3.0], "b": [4.0, 5.0, 6.0]}) @@ -2245,10 +2219,7 @@ class TestDescriptiveStatsEdgeCases(balance.testutil.BalanceTestCase): """Test cases for descriptive_stats edge cases (lines 572, 677, 680, 683).""" def test_descriptive_stats_with_numeric_only_true(self) -> None: - """Test descriptive_stats filters to numeric columns only. - - Verifies line 572 in weighted_stats.py. - """ + """Test descriptive_stats filters to numeric columns only.""" from balance.stats_and_plots.weighted_stats import descriptive_stats df = pd.DataFrame({"a": [1.0, 2.0], "b": ["x", "y"]}) @@ -2258,10 +2229,7 @@ def test_descriptive_stats_with_numeric_only_true(self) -> None: self.assertNotIn("b", result.columns) def test_relative_frequency_table_with_series(self) -> None: - """Test relative_frequency_table with pd.Series input. - - Verifies lines 678-681 in weighted_stats.py. - """ + """Test relative_frequency_table with pd.Series input.""" from balance.stats_and_plots.weighted_stats import relative_frequency_table s = pd.Series(["a", "a", "b", "c"]) @@ -2270,10 +2238,7 @@ def test_relative_frequency_table_with_series(self) -> None: self.assertIn("prop", result.columns) def test_relative_frequency_table_with_unnamed_series(self) -> None: - """Test relative_frequency_table with unnamed pd.Series. - - Verifies lines 679-680 in weighted_stats.py. - """ + """Test relative_frequency_table with unnamed pd.Series.""" from balance.stats_and_plots.weighted_stats import relative_frequency_table s = pd.Series(["a", "a", "b"]) @@ -2282,10 +2247,7 @@ def test_relative_frequency_table_with_unnamed_series(self) -> None: self.assertIn("group", result.columns) def test_relative_frequency_table_raises_on_invalid_type(self) -> None: - """Test relative_frequency_table raises on invalid input type. - - Verifies lines 682-683 in weighted_stats.py. 
- """ + """Test relative_frequency_table raises on invalid input type.""" from balance.stats_and_plots.weighted_stats import relative_frequency_table # Lists are not valid input - should raise an error (TypeError or AttributeError) @@ -2297,10 +2259,7 @@ class TestInputValidationIsinstanceSample(balance.testutil.BalanceTestCase): """Test cases for _isinstance_sample ImportError handling (lines 188-189).""" def test_isinstance_sample_handles_import_error(self) -> None: - """Test _isinstance_sample returns False on ImportError. - - Verifies lines 188-189 in input_validation.py. - """ + """Test _isinstance_sample returns False on ImportError.""" from balance.utils.input_validation import _isinstance_sample # Should return False for non-Sample objects @@ -2311,3 +2270,293 @@ def test_isinstance_sample_handles_import_error(self) -> None: sample = Sample.from_frame(pd.DataFrame({"id": [1, 2], "a": [3, 4]})) result = _isinstance_sample(sample) self.assertTrue(result) + + +class TestRelativeFrequencyTableSeries(balance.testutil.BalanceTestCase): + """Test cases for relative_frequency_table with Series.""" + + def test_relative_frequency_table_with_series(self) -> None: + """Test relative_frequency_table with Series input.""" + from balance.stats_and_plots.weighted_stats import relative_frequency_table + + s = pd.Series(["a", "a", "b", "c"]) + result = relative_frequency_table(s) + self.assertIsInstance(result, pd.DataFrame) + self.assertIn("prop", result.columns) + + def test_relative_frequency_table_with_unnamed_series(self) -> None: + """Test relative_frequency_table with unnamed Series.""" + from balance.stats_and_plots.weighted_stats import relative_frequency_table + + s = pd.Series(["a", "a", "b"]) + s.name = None + result = relative_frequency_table(s) + self.assertIn("group", result.columns) + + def test_relative_frequency_table_raises_on_invalid_type(self) -> None: + """Test relative_frequency_table raises on invalid input type.""" + from 
balance.stats_and_plots.weighted_stats import relative_frequency_table + + with self.assertRaises((TypeError, AttributeError)): + relative_frequency_table([1, 2, 3]) # type: ignore + + +class TestEmdEmptyCategories(balance.testutil.BalanceTestCase): + """Test cases for emd with empty discrete categories.""" + + def test_emd_raises_on_empty_discrete_categories(self) -> None: + """Test emd raises ValueError when discrete columns have no categories.""" + sample_df = pd.DataFrame({"cat": pd.Categorical([])}) + target_df = pd.DataFrame({"cat": pd.Categorical([])}) + + with self.assertRaisesRegex(ValueError, "at least one"): + weighted_comparisons_stats.emd(sample_df, target_df) + + +class TestEmdEmptyNumericColumn(balance.testutil.BalanceTestCase): + """Test cases for emd with empty numeric columns.""" + + def test_emd_raises_on_empty_numeric_column(self) -> None: + """Test emd raises ValueError when numeric columns become empty after dropna.""" + sample_df = pd.DataFrame({"num": [np.nan, np.nan, np.nan]}) + target_df = pd.DataFrame({"num": [1.0, 2.0, 3.0]}) + + with self.assertRaisesRegex(ValueError, "at least one"): + weighted_comparisons_stats.emd(sample_df, target_df) + + +class TestCvmdEmptyCategories(balance.testutil.BalanceTestCase): + """Test cases for cvmd with empty discrete categories.""" + + def test_cvmd_raises_on_empty_discrete_categories(self) -> None: + """Test cvmd raises ValueError when discrete columns have no categories.""" + sample_df = pd.DataFrame({"cat": pd.Categorical([])}) + target_df = pd.DataFrame({"cat": pd.Categorical([])}) + + with self.assertRaisesRegex(ValueError, "at least one"): + weighted_comparisons_stats.cvmd(sample_df, target_df) + + +class TestCvmdEmptyNumericColumn(balance.testutil.BalanceTestCase): + """Test cases for cvmd with empty numeric columns.""" + + def test_cvmd_raises_on_empty_numeric_column(self) -> None: + """Test cvmd raises ValueError when numeric columns become empty after dropna.""" + sample_df = 
pd.DataFrame({"num": [np.nan, np.nan, np.nan]}) + target_df = pd.DataFrame({"num": [1.0, 2.0, 3.0]}) + + with self.assertRaisesRegex(ValueError, "at least one"): + weighted_comparisons_stats.cvmd(sample_df, target_df) + + +class TestKsEmptyCategories(balance.testutil.BalanceTestCase): + """Test cases for ks with empty discrete categories.""" + + def test_ks_raises_on_empty_discrete_categories(self) -> None: + """Test ks raises ValueError when discrete columns have no categories.""" + sample_df = pd.DataFrame({"cat": pd.Categorical([])}) + target_df = pd.DataFrame({"cat": pd.Categorical([])}) + + with self.assertRaisesRegex(ValueError, "at least one"): + weighted_comparisons_stats.ks(sample_df, target_df) + + +class TestKsEmptyNumericColumn(balance.testutil.BalanceTestCase): + """Test cases for ks with empty numeric columns.""" + + def test_ks_raises_on_empty_numeric_column(self) -> None: + """Test ks raises ValueError when numeric columns become empty after dropna.""" + sample_df = pd.DataFrame({"num": [np.nan, np.nan, np.nan]}) + target_df = pd.DataFrame({"num": [1.0, 2.0, 3.0]}) + + with self.assertRaisesRegex(ValueError, "at least one"): + weighted_comparisons_stats.ks(sample_df, target_df) + + +class TestCiOfWeightedMeanIndexTypes(balance.testutil.BalanceTestCase): + """Test cases for ci_of_weighted_mean with different input types.""" + + def test_ci_of_weighted_mean_with_series_uses_index(self) -> None: + """Test ci_of_weighted_mean uses Series index when v is a Series.""" + from balance.stats_and_plots.weighted_stats import ci_of_weighted_mean + + v = pd.Series([1.0, 2.0, 3.0, 4.0], index=["a", "b", "c", "d"]) + result = ci_of_weighted_mean(v, round_ndigits=3) + self.assertIsInstance(result.index, pd.Index) + + def test_ci_of_weighted_mean_with_list_has_none_index(self) -> None: + """Test ci_of_weighted_mean uses None index when v is a list.""" + from balance.stats_and_plots.weighted_stats import ci_of_weighted_mean + + v = [1.0, 2.0, 3.0, 4.0] + result = 
ci_of_weighted_mean(v, round_ndigits=3) + self.assertIsInstance(result, pd.Series) + self.assertEqual(len(result), 1) + + def test_ci_of_weighted_mean_with_array_has_none_index(self) -> None: + """Test ci_of_weighted_mean uses None index when v is a numpy array.""" + from balance.stats_and_plots.weighted_stats import ci_of_weighted_mean + + v = np.array([1.0, 2.0, 3.0, 4.0]) + result = ci_of_weighted_mean(v, round_ndigits=3) + self.assertIsInstance(result, pd.Series) + + def test_ci_of_weighted_mean_with_dataframe_uses_columns(self) -> None: + """Test ci_of_weighted_mean uses DataFrame columns when v is a DataFrame.""" + from balance.stats_and_plots.weighted_stats import ci_of_weighted_mean + + v = pd.DataFrame({"col1": [1.0, 2.0, 3.0], "col2": [4.0, 5.0, 6.0]}) + result = ci_of_weighted_mean(v, round_ndigits=3) + self.assertIsInstance(result, pd.Series) + self.assertEqual(list(result.index), ["col1", "col2"]) + + +class TestRelativeFrequencyTableDataFrameWithColumn(balance.testutil.BalanceTestCase): + """Test cases for relative_frequency_table with DataFrame and column.""" + + def test_relative_frequency_table_dataframe_uses_first_column(self) -> None: + """Test relative_frequency_table uses first column when column=None.""" + from balance.stats_and_plots.weighted_stats import relative_frequency_table + + df = pd.DataFrame({"a": ["x", "x", "y"], "b": ["p", "q", "r"]}) + result = relative_frequency_table(df, column=None) + self.assertIn("a", result.columns) + self.assertIn("prop", result.columns) + self.assertNotIn("b", result.columns) + + +class TestCiOfWeightedMeanOriginalTypes(balance.testutil.BalanceTestCase): + """Test cases for ci_of_weighted_mean with different original input types.""" + + def test_ci_of_weighted_mean_with_series_input_uses_columns(self) -> None: + """Test ci_of_weighted_mean uses columns when original input is a Series. + + This covers line 286 where original_v_is_series is True. 
+ """ + from balance.stats_and_plots.weighted_stats import ci_of_weighted_mean + + v = pd.Series([1.0, 2.0, 3.0, 4.0], name="test_col") + result = ci_of_weighted_mean(v, round_ndigits=3) + self.assertIsInstance(result, pd.Series) + self.assertEqual(len(result), 1) + + def test_ci_of_weighted_mean_with_list_input_uses_none_index(self) -> None: + """Test ci_of_weighted_mean uses None index when original input is a list. + + This covers line 290 where ci_index = None. + """ + from balance.stats_and_plots.weighted_stats import ci_of_weighted_mean + + v = [1.0, 2.0, 3.0, 4.0] + result = ci_of_weighted_mean(v, round_ndigits=3) + self.assertIsInstance(result, pd.Series) + self.assertEqual(len(result), 1) + # When ci_index is None, the result index should be RangeIndex + self.assertTrue( + isinstance(result.index, pd.RangeIndex) or (result.index.tolist() == [0]) + ) + + def test_ci_of_weighted_mean_with_numpy_array_uses_none_index(self) -> None: + """Test ci_of_weighted_mean uses None index when original input is numpy array. + + This covers line 290 where ci_index = None. + """ + from balance.stats_and_plots.weighted_stats import ci_of_weighted_mean + + v = np.array([1.0, 2.0, 3.0, 4.0]) + result = ci_of_weighted_mean(v, round_ndigits=3) + self.assertIsInstance(result, pd.Series) + self.assertEqual(len(result), 1) + + +class TestRelativeFrequencyTableTypeError(balance.testutil.BalanceTestCase): + """Test cases for relative_frequency_table with invalid types.""" + + def test_relative_frequency_table_raises_type_error_on_invalid_type(self) -> None: + """Test relative_frequency_table raises TypeError on invalid input type. + + This covers line 683 where TypeError is raised for non-DataFrame/Series input. 
+ """ + from balance.stats_and_plots.weighted_stats import relative_frequency_table + + # Create an object that has a shape attribute but is not a DataFrame or Series + class FakeDataFrame: + @property + def shape(self) -> tuple: + return (3,) + + fake_df = FakeDataFrame() + + with self.assertRaises(TypeError) as context: + relative_frequency_table(fake_df) # type: ignore[arg-type] + + self.assertIn("DataFrame or Series", str(context.exception)) + + +class TestEmptyCategoriesError(balance.testutil.BalanceTestCase): + """Tests for the ValueError raised in emd, cvmd, and ks + when discrete columns have no categories. + """ + + def test_emd_raises_on_empty_categories(self) -> None: + """Test emd raises ValueError when _sorted_unique_categories returns empty list. + + This covers line 817 in weighted_comparisons_stats.py. + """ + from unittest.mock import patch + + # Create DataFrames with categorical column (discrete) + sample_df = pd.DataFrame({"a": ["x", "y"]}) + target_df = pd.DataFrame({"a": ["x", "y"]}) + + # Mock _sorted_unique_categories to return empty list + with patch( + "balance.stats_and_plots.weighted_comparisons_stats._sorted_unique_categories", + return_value=[], + ): + with self.assertRaisesRegex( + ValueError, "Discrete columns must contain at least one category" + ): + weighted_comparisons_stats.emd(sample_df, target_df) + + def test_cvmd_raises_on_empty_categories(self) -> None: + """Test cvmd raises ValueError when _sorted_unique_categories returns empty list. + + This covers line 950 in weighted_comparisons_stats.py. 
+ """ + from unittest.mock import patch + + # Create DataFrames with categorical column (discrete) + sample_df = pd.DataFrame({"a": ["x", "y"]}) + target_df = pd.DataFrame({"a": ["x", "y"]}) + + # Mock _sorted_unique_categories to return empty list + with patch( + "balance.stats_and_plots.weighted_comparisons_stats._sorted_unique_categories", + return_value=[], + ): + with self.assertRaisesRegex( + ValueError, "Discrete columns must contain at least one category" + ): + weighted_comparisons_stats.cvmd(sample_df, target_df) + + def test_ks_raises_on_empty_categories(self) -> None: + """Test ks raises ValueError when _sorted_unique_categories returns empty list. + + This covers line 1090 in weighted_comparisons_stats.py. + """ + from unittest.mock import patch + + # Create DataFrames with categorical column (discrete) + sample_df = pd.DataFrame({"a": ["x", "y"]}) + target_df = pd.DataFrame({"a": ["x", "y"]}) + + # Mock _sorted_unique_categories to return empty list + with patch( + "balance.stats_and_plots.weighted_comparisons_stats._sorted_unique_categories", + return_value=[], + ): + with self.assertRaisesRegex( + ValueError, "Discrete columns must contain at least one category" + ): + weighted_comparisons_stats.ks(sample_df, target_df) diff --git a/tests/test_util_input_validation.py b/tests/test_util_input_validation.py index 58e997a93..68cd37172 100644 --- a/tests/test_util_input_validation.py +++ b/tests/test_util_input_validation.py @@ -16,7 +16,10 @@ from balance import util as balance_util from balance.sample_class import Sample from balance.util import _verify_value_type -from balance.utils.input_validation import _extract_series_and_weights +from balance.utils.input_validation import ( + _extract_series_and_weights, + _isinstance_sample, +) class TestUtil( @@ -115,8 +118,6 @@ def test__isinstance_sample(self) -> None: - Regular pandas DataFrames - Sample objects created from DataFrames """ - from balance.util import _isinstance_sample - s_df = pd.DataFrame( 
{ "a": (0, 1, 2), @@ -666,3 +667,149 @@ def test_empty_list_treated_as_none(self) -> None: result = balance_util.choose_variables(df1, df2, variables=[]) # Should return intersection (only 'a') self.assertEqual(result, ["a"]) + + +class TestIsinstanceSample(balance.testutil.BalanceTestCase): + """Test _isinstance_sample function behavior.""" + + def test_isinstance_sample_returns_false_for_non_sample(self) -> None: + """Test _isinstance_sample returns False for non-Sample objects.""" + self.assertFalse(_isinstance_sample("not a sample")) + self.assertFalse(_isinstance_sample(123)) + self.assertFalse(_isinstance_sample([1, 2, 3])) + self.assertFalse(_isinstance_sample(pd.DataFrame({"a": [1, 2]}))) + + def test_isinstance_sample_returns_true_for_sample(self) -> None: + """Test _isinstance_sample returns True for Sample objects.""" + sample = Sample.from_frame(pd.DataFrame({"id": [1, 2], "a": [3, 4]})) + self.assertTrue(_isinstance_sample(sample)) + + +class TestCoerceToNumericAndValidate(balance.testutil.BalanceTestCase): + """Test cases for _coerce_to_numeric_and_validate function.""" + + def test_successful_numeric_conversion(self) -> None: + """Test successful conversion of numeric series.""" + from balance.utils.input_validation import _coerce_to_numeric_and_validate + + series = pd.Series([1.0, 2.0, 3.0]) + weights = np.array([1.0, 2.0, 3.0]) + + result_vals, result_weights = _coerce_to_numeric_and_validate( + series, weights, "test" + ) + + np.testing.assert_array_equal(result_vals, np.array([1.0, 2.0, 3.0])) + np.testing.assert_array_equal(result_weights, np.array([1.0, 2.0, 3.0])) + + def test_conversion_with_coercible_strings(self) -> None: + """Test conversion when series contains numeric strings.""" + from balance.utils.input_validation import _coerce_to_numeric_and_validate + + series = pd.Series(["1", "2", "3"]) + weights = np.array([1.0, 2.0, 3.0]) + + result_vals, result_weights = _coerce_to_numeric_and_validate( + series, weights, "test" + ) + + 
np.testing.assert_array_equal(result_vals, np.array([1.0, 2.0, 3.0]))
+        np.testing.assert_array_equal(result_weights, np.array([1.0, 2.0, 3.0]))
+
+    def test_partial_conversion_with_some_non_numeric(self) -> None:
+        """Test conversion when some values can't be converted to numeric."""
+        from balance.utils.input_validation import _coerce_to_numeric_and_validate
+
+        # Mix of convertible and non-convertible values
+        series = pd.Series([1.0, "abc", 3.0])
+        weights = np.array([1.0, 2.0, 3.0])
+
+        result_vals, result_weights = _coerce_to_numeric_and_validate(
+            series, weights, "test"
+        )
+
+        # "abc" should be coerced to NaN and dropped
+        np.testing.assert_array_equal(result_vals, np.array([1.0, 3.0]))
+        np.testing.assert_array_equal(result_weights, np.array([1.0, 3.0]))
+
+    def test_raises_on_all_non_numeric(self) -> None:
+        """Test that ValueError is raised when all values fail numeric conversion.
+
+        This test exercises the code path that was previously unreachable. When
+        all values in a series cannot be converted to numeric, the function
+        should raise ValueError.
+ """ + from balance.utils.input_validation import _coerce_to_numeric_and_validate + + # All values are non-numeric strings + series = pd.Series(["abc", "def", "ghi"]) + weights = np.array([1.0, 2.0, 3.0]) + + with self.assertRaisesRegex( + ValueError, "must contain at least one valid numeric value" + ): + _coerce_to_numeric_and_validate(series, weights, "test") + + def test_raises_on_empty_series(self) -> None: + """Test that ValueError is raised when series is empty.""" + from balance.utils.input_validation import _coerce_to_numeric_and_validate + + series = pd.Series([], dtype=float) + weights = np.array([]) + + with self.assertRaisesRegex( + ValueError, "must contain at least one valid numeric value" + ): + _coerce_to_numeric_and_validate(series, weights, "empty series") + + def test_handles_nan_values(self) -> None: + """Test that NaN values are properly dropped during conversion.""" + from balance.utils.input_validation import _coerce_to_numeric_and_validate + + series = pd.Series([1.0, np.nan, 3.0]) + weights = np.array([1.0, 2.0, 3.0]) + + result_vals, result_weights = _coerce_to_numeric_and_validate( + series, weights, "test" + ) + + np.testing.assert_array_equal(result_vals, np.array([1.0, 3.0])) + np.testing.assert_array_equal(result_weights, np.array([1.0, 3.0])) + + def test_raises_on_all_nan_values(self) -> None: + """Test that ValueError is raised when all values are NaN after conversion.""" + from balance.utils.input_validation import _coerce_to_numeric_and_validate + + series = pd.Series([np.nan, np.nan, np.nan]) + weights = np.array([1.0, 2.0, 3.0]) + + with self.assertRaisesRegex( + ValueError, "must contain at least one valid numeric value" + ): + _coerce_to_numeric_and_validate(series, weights, "all_nan") + + def test_preserves_weight_alignment(self) -> None: + """Test that weights are correctly aligned after dropping invalid values.""" + from balance.utils.input_validation import _coerce_to_numeric_and_validate + + # Series with some 
non-convertible values at specific positions + series = pd.Series([1.0, "bad", 3.0, "bad", 5.0], index=[0, 1, 2, 3, 4]) + weights = np.array([10.0, 20.0, 30.0, 40.0, 50.0]) + + result_vals, result_weights = _coerce_to_numeric_and_validate( + series, weights, "test" + ) + + # Only positions 0, 2, 4 should remain + np.testing.assert_array_equal(result_vals, np.array([1.0, 3.0, 5.0])) + np.testing.assert_array_equal(result_weights, np.array([10.0, 30.0, 50.0])) + + def test_error_message_includes_label(self) -> None: + """Test that error message includes the provided label.""" + from balance.utils.input_validation import _coerce_to_numeric_and_validate + + series = pd.Series(["abc", "def"]) + weights = np.array([1.0, 2.0]) + + with self.assertRaisesRegex(ValueError, "my_custom_label"): + _coerce_to_numeric_and_validate(series, weights, "my_custom_label") diff --git a/tests/test_weighted_comparisons_plots.py b/tests/test_weighted_comparisons_plots.py index d8b28cd7f..a166ec006 100644 --- a/tests/test_weighted_comparisons_plots.py +++ b/tests/test_weighted_comparisons_plots.py @@ -1092,6 +1092,7 @@ def test_plot_qq_unweighted(self) -> None: """ from balance.stats_and_plots.weighted_comparisons_plots import plot_qq + np.random.seed(42) test_df = pd.DataFrame({"v1": np.random.uniform(size=50)}) fig, ax = plt.subplots(1, 1, figsize=(7.2, 7.2)) @@ -1562,6 +1563,7 @@ def test_plotly_functions_return_none_by_default(self) -> None: ) # Create test data + np.random.seed(42) test_df = pd.DataFrame( { "v1": np.random.normal(size=50), @@ -1601,3 +1603,348 @@ def test_plotly_functions_return_none_by_default(self) -> None: return_dict_of_figures=False, ) self.assertIsNone(result_bar) + + +class TestSeabornPlotDistDefaultNames(balance.testutil.BalanceTestCase): + """Test cases for seaborn_plot_dist with default names.""" + + def tearDown(self) -> None: + plt.close("all") + super().tearDown() + + def test_seaborn_plot_dist_generates_default_names(self) -> None: + """Test 
seaborn_plot_dist generates default names when names=None.""" + df1 = pd.DataFrame({"v1": [1.0, 2.0, 3.0]}) + df2 = pd.DataFrame({"v1": [2.0, 3.0, 4.0]}) + dfs: List[DataFrameWithWeight] = [ + {"df": df1, "weight": None}, + {"df": df2, "weight": None}, + ] + result = weighted_comparisons_plots.seaborn_plot_dist( + dfs, + names=None, + variables=["v1"], + dist_type="hist", + return_axes=True, + ) + self.assertIsNotNone(result) + + def test_seaborn_plot_dist_single_df_default_hist(self) -> None: + """Test seaborn_plot_dist with single df defaults to hist dist_type.""" + df1 = pd.DataFrame({"v1": [1.0, 2.0, 3.0]}) + dfs: List[DataFrameWithWeight] = [{"df": df1, "weight": None}] + result = weighted_comparisons_plots.seaborn_plot_dist( + dfs, + names=["self"], + variables=["v1"], + dist_type=None, + return_axes=True, + ) + self.assertIsNotNone(result) + + def test_seaborn_plot_dist_multiple_df_default_qq(self) -> None: + """Test seaborn_plot_dist with multiple dfs defaults to qq dist_type.""" + df1 = pd.DataFrame({"v1": [1.0, 2.0, 3.0]}) + df2 = pd.DataFrame({"v1": [2.0, 3.0, 4.0]}) + dfs: List[DataFrameWithWeight] = [ + {"df": df1, "weight": None}, + {"df": df2, "weight": None}, + ] + result = weighted_comparisons_plots.seaborn_plot_dist( + dfs, + names=["self", "target"], + variables=["v1"], + dist_type=None, + return_axes=True, + ) + self.assertIsNotNone(result) + + +class TestSeabornPlotDistNoNonmissing(balance.testutil.BalanceTestCase): + """Test cases for seaborn_plot_dist with no nonmissing values.""" + + def tearDown(self) -> None: + plt.close("all") + super().tearDown() + + def test_seaborn_plot_dist_skips_all_missing_variable(self) -> None: + """Test seaborn_plot_dist skips variable with no nonmissing values.""" + import logging + + df1 = pd.DataFrame({"v1": [np.nan, np.nan, np.nan]}) + df2 = pd.DataFrame({"v1": [np.nan, np.nan, np.nan]}) + dfs: List[DataFrameWithWeight] = [ + {"df": df1, "weight": None}, + {"df": df2, "weight": None}, + ] + with 
self.assertLogs(level=logging.WARNING) as log: + weighted_comparisons_plots.seaborn_plot_dist( + dfs, + names=["self", "target"], + variables=["v1"], + dist_type="qq", + ) + self.assertTrue(any("No nonmissing values" in msg for msg in log.output)) + + +class TestSeabornPlotDistQQNumeric(balance.testutil.BalanceTestCase): + """Test cases for seaborn_plot_dist qq with numeric variables.""" + + def tearDown(self) -> None: + plt.close("all") + super().tearDown() + + def test_seaborn_plot_dist_qq_with_numeric(self) -> None: + """Test seaborn_plot_dist uses plot_qq for numeric variables with qq dist_type.""" + np.random.seed(42) + df1 = pd.DataFrame({"v1": np.random.randn(50)}) + df2 = pd.DataFrame({"v1": np.random.randn(50)}) + dfs: List[DataFrameWithWeight] = [ + {"df": df1, "weight": None}, + {"df": df2, "weight": None}, + ] + result = weighted_comparisons_plots.seaborn_plot_dist( + dfs, + names=["self", "target"], + variables=["v1"], + dist_type="qq", + numeric_n_values_threshold=3, + return_axes=True, + ) + self.assertIsNotNone(result) + + +class TestPlotlyPlotDensityWeightsNone(balance.testutil.BalanceTestCase): + """Test cases for plotly_plot_density without weights column.""" + + def test_plotly_plot_density_uses_ones_when_no_weight(self) -> None: + """Test plotly_plot_density uses ones when no weight column present.""" + from balance.stats_and_plots.weighted_comparisons_plots import ( + plotly_plot_density, + ) + + df_no_weight = pd.DataFrame({"v1": [1.0, 2.0, 3.0, 4.0]}) + dict_of_dfs = { + "self": df_no_weight, + "target": df_no_weight.copy(), + } + result = plotly_plot_density( + dict_of_dfs, + variables=["v1"], + plot_it=False, + return_dict_of_figures=True, + ) + self.assertIsNotNone(result) + self.assertIn("v1", result) + + +class TestPlotlyPlotDensityPlotIt(balance.testutil.BalanceTestCase): + """Test cases for plotly_plot_density with plot_it=True.""" + + def test_plotly_plot_density_plot_it_true(self) -> None: + """Test plotly_plot_density with 
plot_it=True.""" + from unittest.mock import patch + + from balance.stats_and_plots.weighted_comparisons_plots import ( + plotly_plot_density, + ) + + df = pd.DataFrame({"v1": [1.0, 2.0, 3.0], "weight": [1.0, 1.0, 1.0]}) + dict_of_dfs = {"self": df, "target": df.copy()} + + with patch("plotly.offline.iplot") as mock_iplot: + plotly_plot_density( + dict_of_dfs, + variables=["v1"], + plot_it=True, + return_dict_of_figures=False, + ) + mock_iplot.assert_called() + + +class TestPlotlyPlotQQPlotIt(balance.testutil.BalanceTestCase): + """Test cases for plotly_plot_qq with plot_it=True.""" + + def test_plotly_plot_qq_plot_it_true(self) -> None: + """Test plotly_plot_qq with plot_it=True.""" + from unittest.mock import patch + + from balance.stats_and_plots.weighted_comparisons_plots import plotly_plot_qq + + np.random.seed(42) + df = pd.DataFrame({"v1": np.random.randn(50), "weight": np.ones(50)}) + dict_of_dfs = {"self": df, "target": df.copy()} + + with patch("plotly.offline.iplot") as mock_iplot: + plotly_plot_qq( + dict_of_dfs, + variables=["v1"], + plot_it=True, + return_dict_of_figures=False, + ) + mock_iplot.assert_called() + + +class TestPlotlyPlotDistNoSampleKey(balance.testutil.BalanceTestCase): + """Test cases for plotly_plot_dist without 'sample' key.""" + + def test_plotly_plot_dist_random_key_for_numeric(self) -> None: + """Test plotly_plot_dist uses random key when 'sample' not present.""" + from balance.stats_and_plots.weighted_comparisons_plots import plotly_plot_dist + + np.random.seed(42) + df = pd.DataFrame({"v1": np.random.randn(30), "weight": np.ones(30)}) + dict_of_dfs = {"self": df, "target": df.copy()} + result = plotly_plot_dist( + dict_of_dfs, + variables=["v1"], + dist_type="kde", + plot_it=False, + return_dict_of_figures=True, + ) + self.assertIsNotNone(result) + + +class TestPlotlyPlotDistWithSampleKey(balance.testutil.BalanceTestCase): + """Test cases for plotly_plot_dist with 'sample' key present.""" + + def 
test_plotly_plot_dist_uses_sample_key_for_numeric(self) -> None: + """Test plotly_plot_dist uses 'sample' key when present for numeric vars. + + Verifies line 1268 in weighted_comparisons_plots.py. + """ + from balance.stats_and_plots.weighted_comparisons_plots import plotly_plot_dist + + np.random.seed(42) + df = pd.DataFrame({"v1": np.random.randn(30), "weight": np.ones(30)}) + dict_of_dfs = {"sample": df, "target": df.copy()} + result = plotly_plot_dist( + dict_of_dfs, + variables=["v1"], + dist_type="kde", + plot_it=False, + return_dict_of_figures=True, + ) + self.assertIsNotNone(result) + self.assertIn("v1", result) + + +class TestPlotlyPlotDistNoNonmissing(balance.testutil.BalanceTestCase): + """Test cases for plotly_plot_dist with no nonmissing values.""" + + def test_plotly_plot_dist_skips_all_missing_variable(self) -> None: + """Test plotly_plot_dist skips variable with no nonmissing values.""" + import logging + + from balance.stats_and_plots.weighted_comparisons_plots import plotly_plot_dist + + df = pd.DataFrame({"v1": [np.nan, np.nan, np.nan], "weight": [1.0, 1.0, 1.0]}) + dict_of_dfs = {"self": df, "target": df.copy()} + with self.assertLogs(level=logging.WARNING) as log: + plotly_plot_dist( + dict_of_dfs, + variables=["v1"], + dist_type="kde", + plot_it=False, + return_dict_of_figures=True, + ) + self.assertTrue(any("No nonmissing values" in msg for msg in log.output)) + + +class TestPlotlyPlotDistUnsupportedDistType(balance.testutil.BalanceTestCase): + """Test cases for plotly_plot_dist with unsupported dist_type.""" + + def test_plotly_plot_dist_raises_on_unsupported_dist_type(self) -> None: + """Test plotly_plot_dist raises NotImplementedError for unsupported dist_type.""" + from balance.stats_and_plots.weighted_comparisons_plots import plotly_plot_dist + + df = pd.DataFrame({"v1": [1.0, 2.0, 3.0], "weight": [1.0, 1.0, 1.0]}) + dict_of_dfs = {"self": df, "target": df.copy()} + with self.assertRaises(NotImplementedError): + plotly_plot_dist( + 
dict_of_dfs, + variables=["v1"], + # pyre-ignore[6]: Testing invalid dist_type intentionally + dist_type="unknown_type", + plot_it=False, + return_dict_of_figures=True, + ) + + +class TestPlotlyPlotDistQQDistType(balance.testutil.BalanceTestCase): + """Test cases for plotly_plot_dist with qq dist_type.""" + + def test_plotly_plot_dist_with_qq_dist_type(self) -> None: + """Test plotly_plot_dist uses plotly_plot_qq for qq dist_type.""" + from balance.stats_and_plots.weighted_comparisons_plots import plotly_plot_dist + + np.random.seed(42) + df = pd.DataFrame({"v1": np.random.randn(30), "weight": np.ones(30)}) + dict_of_dfs = {"self": df, "target": df.copy()} + result = plotly_plot_dist( + dict_of_dfs, + variables=["v1"], + dist_type="qq", + plot_it=False, + return_dict_of_figures=True, + ) + self.assertIsNotNone(result) + self.assertIn("v1", result) + + +class TestPlotDistPlotlyQQDistType(balance.testutil.BalanceTestCase): + """Test cases for plot_dist with plotly library and qq dist_type.""" + + def test_plot_dist_plotly_with_qq_dist_type(self) -> None: + """Test plot_dist with plotly library and qq dist_type.""" + np.random.seed(42) + df = pd.DataFrame({"v1": np.random.randn(30)}) + dfs: List[DataFrameWithWeight] = [ + {"df": df, "weight": pd.Series(np.ones(30))}, + {"df": df.copy(), "weight": pd.Series(np.ones(30))}, + ] + result = weighted_comparisons_plots.plot_dist( + dfs, + names=["self", "target"], + library="plotly", + dist_type="qq", + plot_it=False, + return_dict_of_figures=True, + ) + self.assertIsNotNone(result) + + +class TestPlotDistPlotlyUnsupportedDistType(balance.testutil.BalanceTestCase): + """Test cases for plot_dist with plotly library and unsupported dist_type.""" + + def test_plot_dist_plotly_raises_on_hist_dist_type(self) -> None: + """Test plot_dist with plotly library raises on hist dist_type.""" + df = pd.DataFrame({"v1": [1.0, 2.0, 3.0]}) + dfs: List[DataFrameWithWeight] = [ + {"df": df, "weight": None}, + {"df": df.copy(), "weight": 
None}, + ] + with self.assertRaises(ValueError) as context: + weighted_comparisons_plots.plot_dist( + dfs, + names=["self", "target"], + library="plotly", + dist_type="hist", + ) + self.assertIn("plotly library does not support", str(context.exception)) + + def test_plot_dist_plotly_raises_on_ecdf_dist_type(self) -> None: + """Test plot_dist with plotly library raises on ecdf dist_type.""" + df = pd.DataFrame({"v1": [1.0, 2.0, 3.0]}) + dfs: List[DataFrameWithWeight] = [ + {"df": df, "weight": None}, + {"df": df.copy(), "weight": None}, + ] + with self.assertRaises(ValueError) as context: + weighted_comparisons_plots.plot_dist( + dfs, + names=["self", "target"], + library="plotly", + dist_type="ecdf", + ) + self.assertIn("plotly library does not support", str(context.exception))