88from __future__ import absolute_import , division , print_function , unicode_literals
99
1010import inspect
11+ import json
1112import logging
1213
1314from argparse import ArgumentParser , Namespace
1415from pathlib import Path
1516
16- from typing import Dict , List , Optional , Tuple , Type , Union
17+ from typing import Any , Dict , List , Optional , Tuple , Type , Union
1718
1819import balance
1920
@@ -40,10 +41,25 @@ def __init__(self, args) -> None:
4041 self ._lambda_max ,
4142 self ._num_lambdas ,
4243 self ._weight_trimming_mean_ratio ,
44+ self ._logistic_regression_kwargs ,
4345 self ._sample_cls ,
4446 self ._sample_package_name ,
4547 self ._sample_package_version ,
46- ) = (None , None , None , None , None , None , None , None , None , None , None , None )
48+ ) = (
49+ None ,
50+ None ,
51+ None ,
52+ None ,
53+ None ,
54+ None ,
55+ None ,
56+ None ,
57+ None ,
58+ None ,
59+ None ,
60+ None ,
61+ None ,
62+ )
4763
4864 def check_input_columns (self , columns : Union [List [str ], pd .Index ]) -> None :
4965 needed_columns = []
@@ -133,6 +149,24 @@ def standardize_types(self) -> bool:
def weight_trimming_mean_ratio(self) -> float:
    """Return the ``weight_trimming_mean_ratio`` value stored on ``self.args``."""
    cli_args = self.args
    return cli_args.weight_trimming_mean_ratio
135151
def logistic_regression_kwargs(self) -> Optional[Dict[str, Any]]:
    """Return the extra logistic-regression keyword arguments, if any.

    Reads ``self.args.ipw_logistic_regression_kwargs`` and normalizes it:

    * ``None`` is returned unchanged (the option was not supplied).
    * A ``dict`` is returned as-is (the same object, not a copy).
    * Any other value is decoded with ``json.loads`` and must yield a
      JSON object.

    Returns:
        The keyword-argument mapping, or ``None`` when the option is unset.

    Raises:
        ValueError: If the value cannot be decoded as JSON (including
            values of a type ``json.loads`` does not accept), or if it
            decodes to something other than a JSON object.
    """
    raw_kwargs = self.args.ipw_logistic_regression_kwargs
    if raw_kwargs is None:
        return None
    if isinstance(raw_kwargs, dict):
        return raw_kwargs
    try:
        parsed = json.loads(raw_kwargs)
    except (TypeError, json.JSONDecodeError) as exc:
        # json.loads raises TypeError for inputs that are not
        # str/bytes/bytearray (e.g. a list or number set on the namespace
        # programmatically); report it as the same ValueError as malformed
        # JSON so callers only have one failure type to handle.
        raise ValueError(
            "--ipw_logistic_regression_kwargs must be a JSON object string"
        ) from exc
    if not isinstance(parsed, dict):
        raise ValueError(
            "--ipw_logistic_regression_kwargs must decode to a JSON object"
        )
    return parsed
169+
136170 def split_sample (self , df : pd .DataFrame ) -> Tuple [pd .DataFrame , pd .DataFrame ]:
137171 in_sample = df [self .sample_column ()] == 1
138172 sample_df = df [in_sample ]
@@ -152,6 +186,7 @@ def process_batch(
152186 lambda_max : Optional [float ] = 10 ,
153187 num_lambdas : Optional [int ] = 250 ,
154188 weight_trimming_mean_ratio : float = 20 ,
189+ logistic_regression_kwargs : Optional [Dict [str , Any ]] = None ,
155190 sample_cls : Type [balance_sample_cls ] = balance_sample_cls ,
156191 sample_package_name : str = __package__ ,
157192 ) -> Dict [str , pd .DataFrame ]:
@@ -215,6 +250,7 @@ def process_batch(
215250 lambda_max = lambda_max ,
216251 num_lambdas = num_lambdas ,
217252 weight_trimming_mean_ratio = weight_trimming_mean_ratio ,
253+ logistic_regression_kwargs = logistic_regression_kwargs ,
218254 )
219255 logger .info ("Succeeded with adjusting sample to target" )
220256 logger .info ("%s adjusted object: %s" % (sample_package_name , str (adjusted )))
@@ -353,6 +389,7 @@ def update_attributes_for_main_used_by_adjust(self) -> None:
353389 one_hot_encoding = self .one_hot_encoding ()
354390 max_de = self .max_de ()
355391 weight_trimming_mean_ratio = self .weight_trimming_mean_ratio ()
392+ logistic_regression_kwargs = self .logistic_regression_kwargs ()
356393 sample_cls , sample_package_name , sample_package_version = (
357394 balance_sample_cls ,
358395 __package__ ,
@@ -370,6 +407,7 @@ def update_attributes_for_main_used_by_adjust(self) -> None:
370407 self ._lambda_max ,
371408 self ._num_lambdas ,
372409 self ._weight_trimming_mean_ratio ,
410+ self ._logistic_regression_kwargs ,
373411 self ._sample_cls ,
374412 self ._sample_package_name ,
375413 self ._sample_package_version ,
@@ -383,6 +421,7 @@ def update_attributes_for_main_used_by_adjust(self) -> None:
383421 lambda_max ,
384422 num_lambdas ,
385423 weight_trimming_mean_ratio ,
424+ logistic_regression_kwargs ,
386425 sample_cls ,
387426 sample_package_name ,
388427 sample_package_version ,
@@ -400,6 +439,7 @@ def main(self) -> None:
400439 lambda_max ,
401440 num_lambdas ,
402441 weight_trimming_mean_ratio ,
442+ logistic_regression_kwargs ,
403443 sample_cls ,
404444 sample_package_name ,
405445 sample_package_version ,
@@ -413,6 +453,7 @@ def main(self) -> None:
413453 self ._lambda_max ,
414454 self ._num_lambdas ,
415455 self ._weight_trimming_mean_ratio ,
456+ self ._logistic_regression_kwargs ,
416457 self ._sample_cls ,
417458 self ._sample_package_name ,
418459 self ._sample_package_version ,
@@ -434,6 +475,7 @@ def main(self) -> None:
434475 "lambda_max" ,
435476 "num_lambdas" ,
436477 "weight_trimming_mean_ratio" ,
478+ "logistic_regression_kwargs" ,
437479 "sample_cls" ,
438480 "sample_package_name" ,
439481 "sample_package_version" ,
@@ -448,6 +490,7 @@ def main(self) -> None:
448490 lambda_max ,
449491 num_lambdas ,
450492 weight_trimming_mean_ratio ,
493+ logistic_regression_kwargs ,
451494 sample_cls ,
452495 sample_package_name ,
453496 sample_package_version ,
@@ -476,6 +519,7 @@ def main(self) -> None:
476519 lambda_max ,
477520 num_lambdas ,
478521 weight_trimming_mean_ratio ,
522+ logistic_regression_kwargs ,
479523 sample_cls ,
480524 sample_package_name ,
481525 )
@@ -501,6 +545,7 @@ def main(self) -> None:
501545 lambda_max ,
502546 num_lambdas ,
503547 weight_trimming_mean_ratio ,
548+ logistic_regression_kwargs ,
504549 sample_cls ,
505550 sample_package_name ,
506551 )
@@ -686,6 +731,16 @@ def add_arguments_to_parser(parser: ArgumentParser) -> ArgumentParser:
686731 "If not supplied it defaults to 250."
687732 ),
688733 )
734+ parser .add_argument (
735+ "--ipw_logistic_regression_kwargs" ,
736+ type = str ,
737+ required = False ,
738+ default = None ,
739+ help = (
740+ "JSON object string with additional keyword arguments passed to "
741+ "sklearn.linear_model.LogisticRegression when method is ipw."
742+ ),
743+ )
689744 parser .add_argument (
690745 "--weight_trimming_mean_ratio" ,
691746 type = _float_or_none ,
@@ -710,15 +765,13 @@ def add_arguments_to_parser(parser: ArgumentParser) -> ArgumentParser:
710765 # TODO: Ideally we would like transformations argument to be able to get three types of values: None (for no transformations),
711766 # "default" for default transformations or a dictionary of transformations.
712767 # However, as a first step I added the option for "default" (which is also the default) and None (for no transformations).
713- (
714- parser .add_argument (
715- "--transformations" ,
716- default = "default" ,
717- required = False ,
718- help = (
719- "Define the transformations for the covariates. Can be set to None for no transformations or"
720- "'default' for default transformations."
721- ),
768+ parser .add_argument (
769+ "--transformations" ,
770+ default = "default" ,
771+ required = False ,
772+ help = (
773+ "Define the transformations for the covariates. Can be set to None for no transformations or"
774+ "'default' for default transformations."
722775 ),
723776 )
724777 # TODO: we currently support only the option of a string formula (or None), not a list of formulas.
0 commit comments