Skip to content

Commit 18148fc

Browse files
talgalilifacebook-github-bot
authored andcommitted
Route unmentioned columns to ignore_columns (#333)
Summary: Previously, when --outcome_columns was not explicitly set, all columns that were not the id, weight, or a covariate were automatically classified as outcome_columns. This was surprising — e.g. columns listed in --keep_columns would silently become outcome columns. Now, unmentioned columns are placed into ignore_columns instead. Columns that are explicitly mentioned (id, weight, covariates, and any columns specified via --outcome_columns, --keep_columns, --batch_columns, or --keep_row_column) are not ignored. Pipelines that relied on the automatic outcome classification must now pass those column names explicitly via --outcome_columns. Also improved docstrings for has_keep_columns(), keep_columns(), and the --keep_columns argument to clarify their behavior. Differential Revision: D92749122
1 parent 330b0a1 commit 18148fc

File tree

4 files changed

+113
-37
lines changed

4 files changed

+113
-37
lines changed

CHANGELOG.md

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,24 @@
1+
# 0.17.0 (TBD)
2+
3+
## Breaking Changes
4+
5+
- **CLI: unmentioned columns now go to `ignore_columns` instead of `outcome_columns`**
6+
- Previously, when `--outcome_columns` was not explicitly set, all columns that
7+
were not the id, weight, or a covariate were automatically classified as
8+
outcome columns. Now those columns are placed into `ignore_columns` instead.
9+
- Columns that are explicitly mentioned — the id column, weight column,
10+
covariate columns, and outcome columns — are **not** ignored.
11+
12+
## Documentation
13+
14+
- **Improved `keep_columns` documentation**
15+
- Updated docstrings for `has_keep_columns()`, `keep_columns()`, and the
16+
`--keep_columns` argument to clarify that keep columns control which columns
17+
appear in the final output CSV. Keep columns that are not id, weight,
18+
covariate, or outcome columns will be placed into ``ignore_columns`` during
19+
processing but are still retained and available in the output.
20+
21+
122
# 0.16.0 (2026-02-09)
223

324
## New Features

balance/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from balance.util import TruncationFormatter # noqa
2020

2121
global __version__
22-
__version__ = "0.16.0"
22+
__version__ = "0.16.1"
2323

2424
WELCOME_MESSAGE = f"""
2525
balance (Version {__version__}) loaded:

balance/cli.py

Lines changed: 37 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,15 @@ def batch_columns(self) -> List[str]:
280280
def has_keep_columns(self) -> bool:
281281
"""Return True when output keep columns are supplied.
282282
283+
Keep columns control which columns appear in the final output CSV.
284+
After adjustment, the output DataFrame is subsetted to contain
285+
**only** these columns (see :meth:`adapt_output`).
286+
287+
Note that keep columns that are not the id, weight, a covariate,
288+
or an explicit outcome column will be placed into
289+
``ignore_columns`` by :meth:`process_batch`. They are still
290+
carried through the ``Sample`` and available in the output.
291+
283292
Returns:
284293
``True`` if keep columns are set, otherwise ``False``.
285294
@@ -294,6 +303,12 @@ def has_keep_columns(self) -> bool:
294303
def keep_columns(self) -> List[str] | None:
295304
"""Return the subset of columns to keep in outputs.
296305
306+
These columns are used to filter the final output DataFrame.
307+
Keep columns that are not the id, weight, a covariate, or an
308+
explicit outcome column will be placed into ``ignore_columns``
309+
during processing but are still retained by the ``Sample`` and
310+
included in the output.
311+
297312
Returns:
298313
List of columns to keep or ``None`` if unspecified.
299314
@@ -667,33 +682,22 @@ def process_batch(
667682
),
668683
}
669684

670-
# Stuff everything that is not id, weight, or covariate into outcomes
685+
# Build the set of explicitly mentioned columns. Any column not in
686+
# this set is placed into ignore_columns so it is carried through
687+
# the Sample but does not participate in the adjustment.
671688
outcome_columns = self.outcome_columns()
672-
ignore_columns = None
673-
if outcome_columns is None:
674-
outcome_columns = [
675-
column
676-
for column in batch_df.columns
677-
if column
678-
not in {
679-
self.id_column(),
680-
self.weight_column(),
681-
*self.covariate_columns(),
682-
}
683-
]
684-
else:
685-
ignore_columns = [
686-
column
687-
for column in batch_df.columns
688-
if column
689-
not in {
690-
self.id_column(),
691-
self.weight_column(),
692-
*self.covariate_columns(),
693-
*outcome_columns,
694-
}
695-
]
696-
outcome_columns = tuple(outcome_columns)
689+
explicitly_mentioned: set[str] = {
690+
self.id_column(),
691+
self.weight_column(),
692+
*self.covariate_columns(),
693+
}
694+
if outcome_columns is not None:
695+
explicitly_mentioned.update(outcome_columns)
696+
697+
ignore_columns = [
698+
column for column in batch_df.columns if column not in explicitly_mentioned
699+
]
700+
outcome_columns = tuple(outcome_columns) if outcome_columns else None
697701

698702
# definitions for diagnostics
699703
covariate_columns_for_diagnostics = self.covariate_columns_for_diagnostics()
@@ -1218,8 +1222,9 @@ def add_arguments_to_parser(parser: ArgumentParser) -> ArgumentParser:
12181222
required=False,
12191223
default=None,
12201224
help=(
1221-
"Set of columns used as outcomes. If not supplied, all columns that are "
1222-
"not in id, weight, or covariate columns are treated as outcomes."
1225+
"Comma-separated columns used as outcomes. If not supplied, "
1226+
"columns that are not id, weight, or covariates are placed into "
1227+
"ignore_columns (carried through but not used in adjustment)."
12231228
),
12241229
)
12251230
parser.add_argument(
@@ -1255,7 +1260,10 @@ def add_arguments_to_parser(parser: ArgumentParser) -> ArgumentParser:
12551260
"--keep_columns",
12561261
type=str,
12571262
required=False,
1258-
help="Set of columns we include in the output csv file",
1263+
help=(
1264+
"Comma-separated columns to include in the output csv file. "
1265+
"The output will be subsetted to only these columns."
1266+
),
12591267
)
12601268
parser.add_argument(
12611269
"--keep_row_column",

tests/test_cli.py

Lines changed: 54 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ def _make_batch_df(self) -> pd.DataFrame:
147147
# pyre-ignore[3]: Intentionally returning a dynamically created class
148148
def _recording_sample_cls(self):
149149
class RecordingSample:
150-
calls: list[tuple[str, ...]] = []
150+
calls: list[tuple[str, ...] | None] = []
151151
ignore_calls: list[list[str] | None] = []
152152

153153
def __init__(self, df: pd.DataFrame) -> None:
@@ -164,8 +164,6 @@ def from_frame(
164164
ignore_columns: list[str] | None = None,
165165
**kwargs: object,
166166
) -> "RecordingSample":
167-
if outcome_columns is None:
168-
raise AssertionError("Expected outcome_columns to be provided.")
169167
cls.calls.append(outcome_columns)
170168
cls.ignore_calls.append(ignore_columns)
171169
return cls(df)
@@ -190,7 +188,7 @@ def diagnostics(
190188

191189
return RecordingSample
192190

193-
def test_cli_outcome_columns_default_inference_preserves_order(self) -> None:
191+
def test_cli_unmentioned_columns_go_to_ignore(self) -> None:
194192
RecordingSample = self._recording_sample_cls()
195193
cli = self._make_cli()
196194
cli.process_batch(
@@ -200,12 +198,15 @@ def test_cli_outcome_columns_default_inference_preserves_order(self) -> None:
200198
)
201199
self.assertEqual(
202200
RecordingSample.calls,
201+
[None, None],
202+
)
203+
self.assertEqual(
204+
RecordingSample.ignore_calls,
203205
[
204-
("is_respondent", "outcome_b", "outcome_a", "extra"),
205-
("is_respondent", "outcome_b", "outcome_a", "extra"),
206+
["is_respondent", "outcome_b", "outcome_a", "extra"],
207+
["is_respondent", "outcome_b", "outcome_a", "extra"],
206208
],
207209
)
208-
self.assertEqual(RecordingSample.ignore_calls, [None, None])
209210

210211
def test_cli_outcome_columns_explicit_selection(self) -> None:
211212
RecordingSample = self._recording_sample_cls()
@@ -1324,6 +1325,52 @@ def test_check_input_columns_raises_when_keep_column_missing(self) -> None:
13241325
with self.assertRaises(AssertionError):
13251326
cli.check_input_columns(columns)
13261327

1328+
def test_keep_columns_preserved_in_adjusted_output(self) -> None:
1329+
"""Test that --keep_columns columns survive adjustment via ignore_columns.
1330+
1331+
A keep column that is not id, weight, covariate, or outcome should be
1332+
routed to ignore_columns by process_batch, carried through the Sample,
1333+
and available for adapt_output to subset without KeyError.
1334+
"""
1335+
with (
1336+
tempfile.TemporaryDirectory() as temp_dir,
1337+
tempfile.NamedTemporaryFile("w", suffix=".csv", delete=False) as in_file,
1338+
):
1339+
in_contents = (
1340+
"x,y,is_respondent,id,weight,extra_col\n"
1341+
+ ("1.0,50,1,1,1,abc\n" * 100)
1342+
+ ("2.0,60,0,1,1,def\n" * 100)
1343+
)
1344+
in_file.write(in_contents)
1345+
in_file.close()
1346+
out_file = os.path.join(temp_dir, "out.csv")
1347+
1348+
parser = make_parser()
1349+
args = parser.parse_args(
1350+
[
1351+
"--input_file",
1352+
in_file.name,
1353+
"--output_file",
1354+
out_file,
1355+
"--covariate_columns",
1356+
"x,y",
1357+
"--keep_columns",
1358+
"id,weight,extra_col",
1359+
]
1360+
)
1361+
cli = BalanceCLI(args)
1362+
cli.update_attributes_for_main_used_by_adjust()
1363+
cli.main()
1364+
1365+
self.assertTrue(os.path.isfile(out_file))
1366+
pd_out = pd.read_csv(out_file)
1367+
# adapt_output should have subsetted to exactly these columns
1368+
self.assertEqual(
1369+
sorted(pd_out.columns.tolist()), ["extra_col", "id", "weight"]
1370+
)
1371+
# extra_col values should be preserved from the sample rows
1372+
self.assertTrue((pd_out["extra_col"] == "abc").all())
1373+
13271374

13281375
class TestBalanceCLI_num_lambdas(balance.testutil.BalanceTestCase):
13291376
"""Test cases for num_lambdas method (line 401)."""

0 commit comments

Comments
 (0)