diff --git a/CHANGELOG.md b/CHANGELOG.md index 84df9b04..dab07165 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,24 @@ +# 0.17.0 (TBD) + +## Breaking Changes + +- **CLI: unmentioned columns now go to `ignore_columns` instead of `outcome_columns`** + - Previously, when `--outcome_columns` was not explicitly set, all columns that + were not the id, weight, or a covariate were automatically classified as + outcome columns. Now those columns are placed into `ignore_columns` instead. + - Columns that are explicitly mentioned — the id column, weight column, + covariate columns, and outcome columns — are **not** ignored. + +## Documentation + +- **Improved `keep_columns` documentation** + - Updated docstrings for `has_keep_columns()`, `keep_columns()`, and the + `--keep_columns` argument to clarify that keep columns control which columns + appear in the final output CSV. Keep columns that are not id, weight, + covariate, or outcome columns will be placed into ``ignore_columns`` during + processing but are still retained and available in the output. + + # 0.16.0 (2026-02-09) ## New Features diff --git a/balance/__init__.py b/balance/__init__.py index 19f7ac67..9810bb01 100644 --- a/balance/__init__.py +++ b/balance/__init__.py @@ -19,7 +19,7 @@ from balance.util import TruncationFormatter # noqa global __version__ -__version__ = "0.16.0" +__version__ = "0.16.1" WELCOME_MESSAGE = f""" balance (Version {__version__}) loaded: diff --git a/balance/cli.py b/balance/cli.py index 249ebb55..4f5ba92e 100644 --- a/balance/cli.py +++ b/balance/cli.py @@ -280,6 +280,15 @@ def batch_columns(self) -> List[str]: def has_keep_columns(self) -> bool: """Return True when output keep columns are supplied. + Keep columns control which columns appear in the final output CSV. + After adjustment, the output DataFrame is subsetted to contain + **only** these columns (see :meth:`adapt_output`). + + Note that keep columns that are not the id, weight, a covariate, + or an explicit outcome column will be placed into + ``ignore_columns`` by :meth:`process_batch`. They are still + carried through the ``Sample`` and available in the output. + Returns: ``True`` if keep columns are set, otherwise ``False``. @@ -294,6 +303,12 @@ def has_keep_columns(self) -> bool: def keep_columns(self) -> List[str] | None: """Return the subset of columns to keep in outputs. + These columns are used to filter the final output DataFrame. + Keep columns that are not the id, weight, a covariate, or an + explicit outcome column will be placed into ``ignore_columns`` + during processing but are still retained by the ``Sample`` and + included in the output. + Returns: List of columns to keep or ``None`` if unspecified. @@ -667,33 +682,22 @@ def process_batch( ), } - # Stuff everything that is not id, weight, or covariate into outcomes + # Build the set of explicitly mentioned columns. Any column not in + # this set is placed into ignore_columns so it is carried through + # the Sample but does not participate in the adjustment. outcome_columns = self.outcome_columns() - ignore_columns = None - if outcome_columns is None: - outcome_columns = [ - column - for column in batch_df.columns - if column - not in { - self.id_column(), - self.weight_column(), - *self.covariate_columns(), - } - ] - else: - ignore_columns = [ - column - for column in batch_df.columns - if column - not in { - self.id_column(), - self.weight_column(), - *self.covariate_columns(), - *outcome_columns, - } - ] - outcome_columns = tuple(outcome_columns) + explicitly_mentioned: set[str] = { + self.id_column(), + self.weight_column(), + *self.covariate_columns(), + } + if outcome_columns is not None: + explicitly_mentioned.update(outcome_columns) + + ignore_columns = [ + column for column in batch_df.columns if column not in explicitly_mentioned + ] + outcome_columns = tuple(outcome_columns) if outcome_columns else None # definitions for diagnostics covariate_columns_for_diagnostics = self.covariate_columns_for_diagnostics() @@ -1218,8 +1222,9 @@ def add_arguments_to_parser(parser: ArgumentParser) -> ArgumentParser: required=False, default=None, help=( - "Set of columns used as outcomes. If not supplied, all columns that are " - "not in id, weight, or covariate columns are treated as outcomes." + "Comma-separated columns used as outcomes. If not supplied, " + "columns that are not id, weight, or covariates are placed into " + "ignore_columns (carried through but not used in adjustment)." ), ) parser.add_argument( @@ -1255,7 +1260,10 @@ def add_arguments_to_parser(parser: ArgumentParser) -> ArgumentParser: "--keep_columns", type=str, required=False, - help="Set of columns we include in the output csv file", + help=( + "Comma-separated columns to include in the output csv file. " + "The output will be subsetted to only these columns." + ), ) parser.add_argument( "--keep_row_column", diff --git a/tests/test_cli.py b/tests/test_cli.py index 18c44ac1..9b81d485 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -147,7 +147,7 @@ def _make_batch_df(self) -> pd.DataFrame: # pyre-ignore[3]: Intentionally returning a dynamically created class def _recording_sample_cls(self): class RecordingSample: - calls: list[tuple[str, ...]] = [] + calls: list[tuple[str, ...] | None] = [] ignore_calls: list[list[str] | None] = [] def __init__(self, df: pd.DataFrame) -> None: @@ -164,8 +164,6 @@ def from_frame( ignore_columns: list[str] | None = None, **kwargs: object, ) -> "RecordingSample": - if outcome_columns is None: - raise AssertionError("Expected outcome_columns to be provided.") cls.calls.append(outcome_columns) cls.ignore_calls.append(ignore_columns) return cls(df) @@ -190,7 +188,7 @@ def diagnostics( return RecordingSample - def test_cli_outcome_columns_default_inference_preserves_order(self) -> None: + def test_cli_unmentioned_columns_go_to_ignore(self) -> None: RecordingSample = self._recording_sample_cls() cli = self._make_cli() cli.process_batch( @@ -200,12 +198,15 @@ def test_cli_outcome_columns_default_inference_preserves_order(self) -> None: ) self.assertEqual( RecordingSample.calls, + [None, None], + ) + self.assertEqual( + RecordingSample.ignore_calls, [ - ("is_respondent", "outcome_b", "outcome_a", "extra"), - ("is_respondent", "outcome_b", "outcome_a", "extra"), + ["is_respondent", "outcome_b", "outcome_a", "extra"], + ["is_respondent", "outcome_b", "outcome_a", "extra"], ], ) - self.assertEqual(RecordingSample.ignore_calls, [None, None]) def test_cli_outcome_columns_explicit_selection(self) -> None: RecordingSample = self._recording_sample_cls() @@ -1324,6 +1325,52 @@ def test_check_input_columns_raises_when_keep_column_missing(self) -> None: with self.assertRaises(AssertionError): cli.check_input_columns(columns) + def test_keep_columns_preserved_in_adjusted_output(self) -> None: + """Test that --keep_columns columns survive adjustment via ignore_columns. + + A keep column that is not id, weight, covariate, or outcome should be + routed to ignore_columns by process_batch, carried through the Sample, + and available for adapt_output to subset without KeyError. + """ + with ( + tempfile.TemporaryDirectory() as temp_dir, + tempfile.NamedTemporaryFile("w", suffix=".csv", delete=False) as in_file, + ): + in_contents = ( + "x,y,is_respondent,id,weight,extra_col\n" + + ("1.0,50,1,1,1,abc\n" * 100) + + ("2.0,60,0,1,1,def\n" * 100) + ) + in_file.write(in_contents) + in_file.close() + out_file = os.path.join(temp_dir, "out.csv") + + parser = make_parser() + args = parser.parse_args( + [ + "--input_file", + in_file.name, + "--output_file", + out_file, + "--covariate_columns", + "x,y", + "--keep_columns", + "id,weight,extra_col", + ] + ) + cli = BalanceCLI(args) + cli.update_attributes_for_main_used_by_adjust() + cli.main() + + self.assertTrue(os.path.isfile(out_file)) + pd_out = pd.read_csv(out_file) + # adapt_output should have subsetted to exactly these columns + self.assertEqual( + sorted(pd_out.columns.tolist()), ["extra_col", "id", "weight"] + ) + # extra_col values should be preserved from the sample rows + self.assertTrue((pd_out["extra_col"] == "abc").all()) + class TestBalanceCLI_num_lambdas(balance.testutil.BalanceTestCase): """Test cases for num_lambdas method (line 401)."""