Skip to content

Commit 6b77ce6

Browse files
neuralsorcerer and meta-codesync[bot]
authored and committed
Allow configuring IPW logistic regression (#138)
Summary:
Changes:
- Added an optional `logistic_regression_kwargs` argument to `ipw` so callers can override LogisticRegression settings while retaining sensible defaults for existing workflows.
- Extended the CLI to parse `--ipw_logistic_regression_kwargs`, persist the parsed options, and forward them through batch processing to IPW adjustments.
- Added a regression test confirming that solver overrides supplied through the new keyword hook reach the fitted LogisticRegression model.

Why?
- Closes #130

Pull Request resolved: #138
Reviewed By: wesleytlee
Differential Revision: D86750632
Pulled By: talgalili
fbshipit-source-id: 12405247db26cc53ba7e932f4e14ac23008bb62d
1 parent 21d657a commit 6b77ce6

File tree

3 files changed

+182
-100
lines changed

3 files changed

+182
-100
lines changed

balance/cli.py

Lines changed: 64 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -8,12 +8,13 @@
88
from __future__ import absolute_import, division, print_function, unicode_literals
99

1010
import inspect
11+
import json
1112
import logging
1213

1314
from argparse import ArgumentParser, Namespace
1415
from pathlib import Path
1516

16-
from typing import Dict, List, Optional, Tuple, Type, Union
17+
from typing import Any, Dict, List, Optional, Tuple, Type, Union
1718

1819
import balance
1920

@@ -40,10 +41,25 @@ def __init__(self, args) -> None:
4041
self._lambda_max,
4142
self._num_lambdas,
4243
self._weight_trimming_mean_ratio,
44+
self._logistic_regression_kwargs,
4345
self._sample_cls,
4446
self._sample_package_name,
4547
self._sample_package_version,
46-
) = (None, None, None, None, None, None, None, None, None, None, None, None)
48+
) = (
49+
None,
50+
None,
51+
None,
52+
None,
53+
None,
54+
None,
55+
None,
56+
None,
57+
None,
58+
None,
59+
None,
60+
None,
61+
None,
62+
)
4763

4864
def check_input_columns(self, columns: Union[List[str], pd.Index]) -> None:
4965
needed_columns = []
@@ -133,6 +149,24 @@ def standardize_types(self) -> bool:
133149
def weight_trimming_mean_ratio(self) -> float:
134150
return self.args.weight_trimming_mean_ratio
135151

152+
def logistic_regression_kwargs(self) -> Optional[Dict[str, Any]]:
    """Return the LogisticRegression keyword overrides from the parsed args.

    Reads ``self.args.ipw_logistic_regression_kwargs`` and normalises it:

    * ``None`` is returned as ``None``.
    * A value that is already a ``dict`` is returned unchanged (same object).
    * Any other value is decoded with ``json.loads`` and must yield a JSON
      object.

    Returns:
        The keyword arguments as a dict, or ``None`` when none were supplied.

    Raises:
        ValueError: If the value cannot be decoded as JSON, or if it decodes
            to something other than a JSON object.
    """
    raw_kwargs = self.args.ipw_logistic_regression_kwargs
    if raw_kwargs is None:
        return None
    if isinstance(raw_kwargs, dict):
        return raw_kwargs
    try:
        parsed = json.loads(raw_kwargs)
    except (json.JSONDecodeError, TypeError) as exc:
        # json.loads raises TypeError (not JSONDecodeError) when given
        # something other than str/bytes/bytearray; report both failure
        # modes through the same ValueError so callers see one error type.
        raise ValueError(
            "--ipw_logistic_regression_kwargs must be a JSON object string"
        ) from exc
    if not isinstance(parsed, dict):
        raise ValueError(
            "--ipw_logistic_regression_kwargs must decode to a JSON object"
        )
    return parsed
169+
136170
def split_sample(self, df: pd.DataFrame) -> Tuple[pd.DataFrame, pd.DataFrame]:
137171
in_sample = df[self.sample_column()] == 1
138172
sample_df = df[in_sample]
@@ -152,6 +186,7 @@ def process_batch(
152186
lambda_max: Optional[float] = 10,
153187
num_lambdas: Optional[int] = 250,
154188
weight_trimming_mean_ratio: float = 20,
189+
logistic_regression_kwargs: Optional[Dict[str, Any]] = None,
155190
sample_cls: Type[balance_sample_cls] = balance_sample_cls,
156191
sample_package_name: str = __package__,
157192
) -> Dict[str, pd.DataFrame]:
@@ -215,6 +250,7 @@ def process_batch(
215250
lambda_max=lambda_max,
216251
num_lambdas=num_lambdas,
217252
weight_trimming_mean_ratio=weight_trimming_mean_ratio,
253+
logistic_regression_kwargs=logistic_regression_kwargs,
218254
)
219255
logger.info("Succeeded with adjusting sample to target")
220256
logger.info("%s adjusted object: %s" % (sample_package_name, str(adjusted)))
@@ -353,6 +389,7 @@ def update_attributes_for_main_used_by_adjust(self) -> None:
353389
one_hot_encoding = self.one_hot_encoding()
354390
max_de = self.max_de()
355391
weight_trimming_mean_ratio = self.weight_trimming_mean_ratio()
392+
logistic_regression_kwargs = self.logistic_regression_kwargs()
356393
sample_cls, sample_package_name, sample_package_version = (
357394
balance_sample_cls,
358395
__package__,
@@ -370,6 +407,7 @@ def update_attributes_for_main_used_by_adjust(self) -> None:
370407
self._lambda_max,
371408
self._num_lambdas,
372409
self._weight_trimming_mean_ratio,
410+
self._logistic_regression_kwargs,
373411
self._sample_cls,
374412
self._sample_package_name,
375413
self._sample_package_version,
@@ -383,6 +421,7 @@ def update_attributes_for_main_used_by_adjust(self) -> None:
383421
lambda_max,
384422
num_lambdas,
385423
weight_trimming_mean_ratio,
424+
logistic_regression_kwargs,
386425
sample_cls,
387426
sample_package_name,
388427
sample_package_version,
@@ -400,6 +439,7 @@ def main(self) -> None:
400439
lambda_max,
401440
num_lambdas,
402441
weight_trimming_mean_ratio,
442+
logistic_regression_kwargs,
403443
sample_cls,
404444
sample_package_name,
405445
sample_package_version,
@@ -413,6 +453,7 @@ def main(self) -> None:
413453
self._lambda_max,
414454
self._num_lambdas,
415455
self._weight_trimming_mean_ratio,
456+
self._logistic_regression_kwargs,
416457
self._sample_cls,
417458
self._sample_package_name,
418459
self._sample_package_version,
@@ -434,6 +475,7 @@ def main(self) -> None:
434475
"lambda_max",
435476
"num_lambdas",
436477
"weight_trimming_mean_ratio",
478+
"logistic_regression_kwargs",
437479
"sample_cls",
438480
"sample_package_name",
439481
"sample_package_version",
@@ -448,6 +490,7 @@ def main(self) -> None:
448490
lambda_max,
449491
num_lambdas,
450492
weight_trimming_mean_ratio,
493+
logistic_regression_kwargs,
451494
sample_cls,
452495
sample_package_name,
453496
sample_package_version,
@@ -476,6 +519,7 @@ def main(self) -> None:
476519
lambda_max,
477520
num_lambdas,
478521
weight_trimming_mean_ratio,
522+
logistic_regression_kwargs,
479523
sample_cls,
480524
sample_package_name,
481525
)
@@ -501,6 +545,7 @@ def main(self) -> None:
501545
lambda_max,
502546
num_lambdas,
503547
weight_trimming_mean_ratio,
548+
logistic_regression_kwargs,
504549
sample_cls,
505550
sample_package_name,
506551
)
@@ -686,6 +731,16 @@ def add_arguments_to_parser(parser: ArgumentParser) -> ArgumentParser:
686731
"If not supplied it defaults to 250."
687732
),
688733
)
734+
parser.add_argument(
735+
"--ipw_logistic_regression_kwargs",
736+
type=str,
737+
required=False,
738+
default=None,
739+
help=(
740+
"JSON object string with additional keyword arguments passed to "
741+
"sklearn.linear_model.LogisticRegression when method is ipw."
742+
),
743+
)
689744
parser.add_argument(
690745
"--weight_trimming_mean_ratio",
691746
type=_float_or_none,
@@ -710,15 +765,13 @@ def add_arguments_to_parser(parser: ArgumentParser) -> ArgumentParser:
710765
# TODO: Ideally we would like transformations argument to be able to get three types of values: None (for no transformations),
711766
# "default" for default transformations or a dictionary of transformations.
712767
# However, as a first step I added the option for "default" (which is also the default) and None (for no transformations).
713-
(
714-
parser.add_argument(
715-
"--transformations",
716-
default="default",
717-
required=False,
718-
help=(
719-
"Define the transformations for the covariates. Can be set to None for no transformations or"
720-
"'default' for default transformations."
721-
),
768+
parser.add_argument(
769+
"--transformations",
770+
default="default",
771+
required=False,
772+
help=(
773+
"Define the transformations for the covariates. Can be set to None for no transformations or"
774+
"'default' for default transformations."
722775
),
723776
)
724777
# TODO: we currently support only the option of a string formula (or None), not a list of formulas.

0 commit comments

Comments
 (0)