-
Notifications
You must be signed in to change notification settings - Fork 51
Route unmentioned columns to ignore_columns #333
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -280,6 +280,15 @@ def batch_columns(self) -> List[str]: | |
| def has_keep_columns(self) -> bool: | ||
| """Return True when output keep columns are supplied. | ||
|
|
||
| Keep columns control which columns appear in the final output CSV. | ||
| After adjustment, the output DataFrame is subsetted to contain | ||
| **only** these columns (see :meth:`adapt_output`). | ||
|
|
||
| Note that keep columns that are not the id, weight, a covariate, | ||
| or an explicit outcome column will be placed into | ||
| ``ignore_columns`` by :meth:`process_batch`. They are still | ||
| carried through the ``Sample`` and available in the output. | ||
|
|
||
| Returns: | ||
| ``True`` if keep columns are set, otherwise ``False``. | ||
|
|
||
|
|
@@ -294,6 +303,12 @@ def has_keep_columns(self) -> bool: | |
| def keep_columns(self) -> List[str] | None: | ||
| """Return the subset of columns to keep in outputs. | ||
|
|
||
| These columns are used to filter the final output DataFrame. | ||
| Keep columns that are not the id, weight, a covariate, or an | ||
| explicit outcome column will be placed into ``ignore_columns`` | ||
| during processing but are still retained by the ``Sample`` and | ||
| included in the output. | ||
|
|
||
| Returns: | ||
| List of columns to keep or ``None`` if unspecified. | ||
|
|
||
|
|
@@ -667,33 +682,22 @@ def process_batch( | |
| ), | ||
| } | ||
|
|
||
| # Stuff everything that is not id, weight, or covariate into outcomes | ||
| # Build the set of explicitly mentioned columns. Any column not in | ||
| # this set is placed into ignore_columns so it is carried through | ||
| # the Sample but does not participate in the adjustment. | ||
| outcome_columns = self.outcome_columns() | ||
| ignore_columns = None | ||
| if outcome_columns is None: | ||
| outcome_columns = [ | ||
| column | ||
| for column in batch_df.columns | ||
| if column | ||
| not in { | ||
| self.id_column(), | ||
| self.weight_column(), | ||
| *self.covariate_columns(), | ||
| } | ||
| ] | ||
| else: | ||
| ignore_columns = [ | ||
| column | ||
| for column in batch_df.columns | ||
| if column | ||
| not in { | ||
| self.id_column(), | ||
| self.weight_column(), | ||
| *self.covariate_columns(), | ||
| *outcome_columns, | ||
| } | ||
| ] | ||
| outcome_columns = tuple(outcome_columns) | ||
| explicitly_mentioned: set[str] = { | ||
| self.id_column(), | ||
| self.weight_column(), | ||
| *self.covariate_columns(), | ||
| } | ||
| if outcome_columns is not None: | ||
| explicitly_mentioned.update(outcome_columns) | ||
|
|
||
| ignore_columns = [ | ||
| column for column in batch_df.columns if column not in explicitly_mentioned | ||
| ] | ||
| outcome_columns = tuple(outcome_columns) if outcome_columns else None | ||
talgalili marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| # definitions for diagnostics | ||
| covariate_columns_for_diagnostics = self.covariate_columns_for_diagnostics() | ||
|
|
@@ -1218,8 +1222,9 @@ def add_arguments_to_parser(parser: ArgumentParser) -> ArgumentParser: | |
| required=False, | ||
| default=None, | ||
| help=( | ||
| "Set of columns used as outcomes. If not supplied, all columns that are " | ||
| "not in id, weight, or covariate columns are treated as outcomes." | ||
| "Comma-separated columns used as outcomes. If not supplied, " | ||
| "columns that are not id, weight, or covariates are placed into " | ||
| "ignore_columns (carried through but not used in adjustment)." | ||
| ), | ||
| ) | ||
| parser.add_argument( | ||
|
|
@@ -1255,7 +1260,10 @@ def add_arguments_to_parser(parser: ArgumentParser) -> ArgumentParser: | |
| "--keep_columns", | ||
| type=str, | ||
| required=False, | ||
| help="Set of columns we include in the output csv file", | ||
| help=( | ||
| "Comma-separated columns to include in the output csv file. " | ||
| "The output will be subsetted to only these columns." | ||
| ), | ||
|
Comment on lines
1260
to
+1266
|
||
| ) | ||
| parser.add_argument( | ||
| "--keep_row_column", | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -147,7 +147,7 @@ def _make_batch_df(self) -> pd.DataFrame: | |
| # pyre-ignore[3]: Intentionally returning a dynamically created class | ||
| def _recording_sample_cls(self): | ||
| class RecordingSample: | ||
| calls: list[tuple[str, ...]] = [] | ||
| calls: list[tuple[str, ...] | None] = [] | ||
| ignore_calls: list[list[str] | None] = [] | ||
|
|
||
| def __init__(self, df: pd.DataFrame) -> None: | ||
|
|
@@ -164,8 +164,6 @@ def from_frame( | |
| ignore_columns: list[str] | None = None, | ||
| **kwargs: object, | ||
| ) -> "RecordingSample": | ||
| if outcome_columns is None: | ||
| raise AssertionError("Expected outcome_columns to be provided.") | ||
| cls.calls.append(outcome_columns) | ||
| cls.ignore_calls.append(ignore_columns) | ||
| return cls(df) | ||
|
|
@@ -190,7 +188,7 @@ def diagnostics( | |
|
|
||
| return RecordingSample | ||
|
|
||
| def test_cli_outcome_columns_default_inference_preserves_order(self) -> None: | ||
| def test_cli_unmentioned_columns_go_to_ignore(self) -> None: | ||
| RecordingSample = self._recording_sample_cls() | ||
| cli = self._make_cli() | ||
| cli.process_batch( | ||
|
|
@@ -200,12 +198,15 @@ def test_cli_outcome_columns_default_inference_preserves_order(self) -> None: | |
| ) | ||
| self.assertEqual( | ||
| RecordingSample.calls, | ||
| [None, None], | ||
| ) | ||
| self.assertEqual( | ||
| RecordingSample.ignore_calls, | ||
| [ | ||
| ("is_respondent", "outcome_b", "outcome_a", "extra"), | ||
| ("is_respondent", "outcome_b", "outcome_a", "extra"), | ||
| ["is_respondent", "outcome_b", "outcome_a", "extra"], | ||
| ["is_respondent", "outcome_b", "outcome_a", "extra"], | ||
| ], | ||
| ) | ||
| self.assertEqual(RecordingSample.ignore_calls, [None, None]) | ||
|
|
||
| def test_cli_outcome_columns_explicit_selection(self) -> None: | ||
| RecordingSample = self._recording_sample_cls() | ||
|
|
@@ -1324,6 +1325,52 @@ def test_check_input_columns_raises_when_keep_column_missing(self) -> None: | |
| with self.assertRaises(AssertionError): | ||
| cli.check_input_columns(columns) | ||
|
|
||
| def test_keep_columns_preserved_in_adjusted_output(self) -> None: | ||
| """Test that --keep_columns columns survive adjustment via ignore_columns. | ||
|
|
||
| A keep column that is not id, weight, covariate, or outcome should be | ||
| routed to ignore_columns by process_batch, carried through the Sample, | ||
| and available for adapt_output to subset without KeyError. | ||
| """ | ||
| with ( | ||
| tempfile.TemporaryDirectory() as temp_dir, | ||
| tempfile.NamedTemporaryFile("w", suffix=".csv", delete=False) as in_file, | ||
| ): | ||
|
Comment on lines
+1335
to
+1338
|
||
| in_contents = ( | ||
| "x,y,is_respondent,id,weight,extra_col\n" | ||
| + ("1.0,50,1,1,1,abc\n" * 100) | ||
| + ("2.0,60,0,1,1,def\n" * 100) | ||
| ) | ||
| in_file.write(in_contents) | ||
| in_file.close() | ||
| out_file = os.path.join(temp_dir, "out.csv") | ||
|
|
||
| parser = make_parser() | ||
| args = parser.parse_args( | ||
| [ | ||
| "--input_file", | ||
| in_file.name, | ||
| "--output_file", | ||
| out_file, | ||
| "--covariate_columns", | ||
| "x,y", | ||
| "--keep_columns", | ||
| "id,weight,extra_col", | ||
| ] | ||
| ) | ||
| cli = BalanceCLI(args) | ||
| cli.update_attributes_for_main_used_by_adjust() | ||
| cli.main() | ||
|
|
||
| self.assertTrue(os.path.isfile(out_file)) | ||
| pd_out = pd.read_csv(out_file) | ||
| # adapt_output should have subsetted to exactly these columns | ||
| self.assertEqual( | ||
| sorted(pd_out.columns.tolist()), ["extra_col", "id", "weight"] | ||
| ) | ||
| # extra_col values should be preserved from the sample rows | ||
| self.assertTrue((pd_out["extra_col"] == "abc").all()) | ||
|
|
||
|
|
||
| class TestBalanceCLI_num_lambdas(balance.testutil.BalanceTestCase): | ||
| """Test cases for num_lambdas method (line 401).""" | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.