@@ -56,28 +56,28 @@ def __init__(self, *, prel_model_effect, model_y_xw, model_t_xw, model_z,
5656 else :
5757 self ._model_z_xw = model_z
5858
59- def train (self , is_selecting , folds , Y , T , X = None , W = None , Z = None , sample_weight = None , groups = None ):
59+ def train (self , is_selecting , folds , Y , T , X = None , W = None , Z = None , sample_weight = None , groups = None , ** fit_params ):
6060 # T and Z only allow single continuous or binary, keep the shape of (n,) for continuous and (n,1) for binary
6161 T = T .ravel () if not self ._discrete_treatment else T
6262 Z = Z .ravel () if not self ._discrete_instrument else Z
6363
64- self ._model_y_xw .train (is_selecting , folds , X = X , W = W , Target = Y , sample_weight = sample_weight , groups = groups )
65- self ._model_t_xw .train (is_selecting , folds , X = X , W = W , Target = T , sample_weight = sample_weight , groups = groups )
64+ self ._model_y_xw .train (is_selecting , folds , X = X , W = W , Target = Y , sample_weight = sample_weight , groups = groups , ** fit_params )
65+ self ._model_t_xw .train (is_selecting , folds , X = X , W = W , Target = T , sample_weight = sample_weight , groups = groups , ** fit_params )
6666
6767 if self ._projection :
6868 WZ = _combine (W , Z , Y .shape [0 ])
6969 self ._model_t_xwz .train (is_selecting , folds , X = X , W = WZ , Target = T ,
70- sample_weight = sample_weight , groups = groups )
70+ sample_weight = sample_weight , groups = groups , ** fit_params )
7171 else :
72- self ._model_z_xw .train (is_selecting , folds , X = X , W = W , Target = Z , sample_weight = sample_weight , groups = groups )
72+ self ._model_z_xw .train (is_selecting , folds , X = X , W = W , Target = Z , sample_weight = sample_weight , groups = groups , ** fit_params )
7373
7474 # TODO: prel_model_effect could allow sample_var and freq_weight?
7575 if self ._discrete_instrument :
7676 Z = inverse_onehot (Z )
7777 if self ._discrete_treatment :
7878 T = inverse_onehot (T )
7979 self ._prel_model_effect .fit (Y , T , Z = Z , X = X ,
80- W = W , sample_weight = sample_weight , groups = groups )
80+ W = W , sample_weight = sample_weight , groups = groups , ** fit_params )
8181 return self
8282
8383 def score (self , Y , T , X = None , W = None , Z = None , sample_weight = None , groups = None ):
@@ -215,11 +215,11 @@ def _get_target(self, T_res, Z_res, T, Z):
215215
216216 def train (self , is_selecting , folds ,
217217 prel_theta , Y_res , T_res , Z_res ,
218- Y , T , X = None , W = None , Z = None , sample_weight = None , groups = None ):
218+ Y , T , X = None , W = None , Z = None , sample_weight = None , groups = None , ** fit_params ):
219219 # T and Z only allow single continuous or binary, keep the shape of (n,) for continuous and (n,1) for binary
220220 target = self ._get_target (T_res , Z_res , T , Z )
221221 self ._model_tz_xw .train (is_selecting , folds , X = X , W = W , Target = target ,
222- sample_weight = sample_weight , groups = groups )
222+ sample_weight = sample_weight , groups = groups , ** fit_params )
223223
224224 return self
225225
@@ -2386,16 +2386,16 @@ def __init__(self,
23862386 self ._dummy_z = dummy_z
23872387 self ._prel_model_effect = prel_model_effect
23882388
2389- def train (self , is_selecting , folds , Y , T , X = None , W = None , Z = None , sample_weight = None , groups = None ):
2390- self ._model_y_xw .train (is_selecting , folds , X = X , W = W , Target = Y , sample_weight = sample_weight , groups = groups )
2389+ def train (self , is_selecting , folds , Y , T , X = None , W = None , Z = None , sample_weight = None , groups = None , ** fit_params ):
2390+ self ._model_y_xw .train (is_selecting , folds , X = X , W = W , Target = Y , sample_weight = sample_weight , groups = groups , ** fit_params )
23912391 # concat W and Z
23922392 WZ = _combine (W , Z , Y .shape [0 ])
2393- self ._model_t_xwz .train (is_selecting , folds , X = X , W = WZ , Target = T , sample_weight = sample_weight , groups = groups )
2394- self ._dummy_z .train (is_selecting , folds , X = X , W = W , Target = Z , sample_weight = sample_weight , groups = groups )
2393+ self ._model_t_xwz .train (is_selecting , folds , X = X , W = WZ , Target = T , sample_weight = sample_weight , groups = groups , ** fit_params )
2394+ self ._dummy_z .train (is_selecting , folds , X = X , W = W , Target = Z , sample_weight = sample_weight , groups = groups , ** fit_params )
23952395 # we need to undo the one-hot encoding for calling effect,
23962396 # since it expects raw values
23972397 self ._prel_model_effect .fit (Y , inverse_onehot (T ), Z = inverse_onehot (Z ), X = X , W = W ,
2398- sample_weight = sample_weight , groups = groups )
2398+ sample_weight = sample_weight , groups = groups , ** fit_params )
23992399 return self
24002400
24012401 def score (self , Y , T , X = None , W = None , Z = None , sample_weight = None , groups = None ):
0 commit comments