Skip to content

Commit 4a5f482

Browse files
committed
ENH: passing fit_params through all train functions
Signed-off-by: Atharva Kelkar <[email protected]>
1 parent 11cdb3e commit 4a5f482

File tree

4 files changed

+27
-27
lines changed

4 files changed

+27
-27
lines changed

econml/dr/_drlearner.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ def __init__(self,
7070
def _combine(self, X, W):
7171
return np.hstack([arr for arr in [X, W] if arr is not None])
7272

73-
def train(self, is_selecting, folds, Y, T, X=None, W=None, *, sample_weight=None, groups=None):
73+
def train(self, is_selecting, folds, Y, T, X=None, W=None, *, sample_weight=None, groups=None, **fit_params):
7474
if Y.ndim != 1 and (Y.ndim != 2 or Y.shape[1] != 1):
7575
raise ValueError("The outcome matrix must be of shape ({0}, ) or ({0}, 1), "
7676
"instead got {1}.".format(len(X), Y.shape))
@@ -80,7 +80,7 @@ def train(self, is_selecting, folds, Y, T, X=None, W=None, *, sample_weight=None
8080
raise AttributeError("Provided crossfit folds contain training splits that " +
8181
"don't contain all treatments")
8282
XW = self._combine(X, W)
83-
filtered_kwargs = filter_none_kwargs(sample_weight=sample_weight)
83+
filtered_kwargs = filter_none_kwargs(sample_weight=sample_weight, **fit_params)
8484

8585
self._model_propensity.train(is_selecting, folds, XW, inverse_onehot(T), groups=groups, **filtered_kwargs)
8686
self._model_regression.train(is_selecting, folds, np.hstack([XW, T]), Y, groups=groups, **filtered_kwargs)

econml/iv/dml/_dml.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -53,16 +53,16 @@ def __init__(self,
5353
else:
5454
self._model_z_xw = model_z
5555

56-
def train(self, is_selecting, folds, Y, T, X=None, W=None, Z=None, sample_weight=None, groups=None):
57-
self._model_y_xw.train(is_selecting, folds, X=X, W=W, Target=Y, sample_weight=sample_weight, groups=groups)
58-
self._model_t_xw.train(is_selecting, folds, X=X, W=W, Target=T, sample_weight=sample_weight, groups=groups)
56+
def train(self, is_selecting, folds, Y, T, X=None, W=None, Z=None, sample_weight=None, groups=None, **fit_params):
57+
self._model_y_xw.train(is_selecting, folds, X=X, W=W, Target=Y, sample_weight=sample_weight, groups=groups, **fit_params)
58+
self._model_t_xw.train(is_selecting, folds, X=X, W=W, Target=T, sample_weight=sample_weight, groups=groups, **fit_params)
5959
if self._projection:
6060
# concat W and Z
6161
WZ = _combine(W, Z, Y.shape[0])
6262
self._model_t_xwz.train(is_selecting, folds, X=X, W=WZ, Target=T,
63-
sample_weight=sample_weight, groups=groups)
63+
sample_weight=sample_weight, groups=groups, **fit_params)
6464
else:
65-
self._model_z_xw.train(is_selecting, folds, X=X, W=W, Target=Z, sample_weight=sample_weight, groups=groups)
65+
self._model_z_xw.train(is_selecting, folds, X=X, W=W, Target=Z, sample_weight=sample_weight, groups=groups, **fit_params)
6666
return self
6767

6868
def score(self, Y, T, X=None, W=None, Z=None, sample_weight=None, groups=None):
@@ -720,15 +720,15 @@ def __init__(self, model_y_xw: ModelSelector, model_t_xw: ModelSelector, model_t
720720
self._model_t_xw = model_t_xw
721721
self._model_t_xwz = model_t_xwz
722722

723-
def train(self, is_selecting, folds, Y, T, X=None, W=None, Z=None, sample_weight=None, groups=None):
723+
def train(self, is_selecting, folds, Y, T, X=None, W=None, Z=None, sample_weight=None, groups=None, **fit_params):
724724
self._model_y_xw.train(is_selecting, folds, X, W, Y, **
725-
filter_none_kwargs(sample_weight=sample_weight, groups=groups))
725+
filter_none_kwargs(sample_weight=sample_weight, groups=groups, **fit_params))
726726
self._model_t_xw.train(is_selecting, folds, X, W, T, **
727-
filter_none_kwargs(sample_weight=sample_weight, groups=groups))
727+
filter_none_kwargs(sample_weight=sample_weight, groups=groups, **fit_params))
728728
# concat W and Z
729729
WZ = _combine(W, Z, Y.shape[0])
730730
self._model_t_xwz.train(is_selecting, folds, X, WZ, T,
731-
**filter_none_kwargs(sample_weight=sample_weight, groups=groups))
731+
**filter_none_kwargs(sample_weight=sample_weight, groups=groups, **fit_params))
732732
return self
733733

734734
def score(self, Y, T, X=None, W=None, Z=None, sample_weight=None, groups=None):

econml/iv/dr/_dr.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -56,28 +56,28 @@ def __init__(self, *, prel_model_effect, model_y_xw, model_t_xw, model_z,
5656
else:
5757
self._model_z_xw = model_z
5858

59-
def train(self, is_selecting, folds, Y, T, X=None, W=None, Z=None, sample_weight=None, groups=None):
59+
def train(self, is_selecting, folds, Y, T, X=None, W=None, Z=None, sample_weight=None, groups=None, **fit_params):
6060
# T and Z only allow single continuous or binary, keep the shape of (n,) for continuous and (n,1) for binary
6161
T = T.ravel() if not self._discrete_treatment else T
6262
Z = Z.ravel() if not self._discrete_instrument else Z
6363

64-
self._model_y_xw.train(is_selecting, folds, X=X, W=W, Target=Y, sample_weight=sample_weight, groups=groups)
65-
self._model_t_xw.train(is_selecting, folds, X=X, W=W, Target=T, sample_weight=sample_weight, groups=groups)
64+
self._model_y_xw.train(is_selecting, folds, X=X, W=W, Target=Y, sample_weight=sample_weight, groups=groups, **fit_params)
65+
self._model_t_xw.train(is_selecting, folds, X=X, W=W, Target=T, sample_weight=sample_weight, groups=groups, **fit_params)
6666

6767
if self._projection:
6868
WZ = _combine(W, Z, Y.shape[0])
6969
self._model_t_xwz.train(is_selecting, folds, X=X, W=WZ, Target=T,
70-
sample_weight=sample_weight, groups=groups)
70+
sample_weight=sample_weight, groups=groups, **fit_params)
7171
else:
72-
self._model_z_xw.train(is_selecting, folds, X=X, W=W, Target=Z, sample_weight=sample_weight, groups=groups)
72+
self._model_z_xw.train(is_selecting, folds, X=X, W=W, Target=Z, sample_weight=sample_weight, groups=groups, **fit_params)
7373

7474
# TODO: prel_model_effect could allow sample_var and freq_weight?
7575
if self._discrete_instrument:
7676
Z = inverse_onehot(Z)
7777
if self._discrete_treatment:
7878
T = inverse_onehot(T)
7979
self._prel_model_effect.fit(Y, T, Z=Z, X=X,
80-
W=W, sample_weight=sample_weight, groups=groups)
80+
W=W, sample_weight=sample_weight, groups=groups, **fit_params)
8181
return self
8282

8383
def score(self, Y, T, X=None, W=None, Z=None, sample_weight=None, groups=None):
@@ -215,11 +215,11 @@ def _get_target(self, T_res, Z_res, T, Z):
215215

216216
def train(self, is_selecting, folds,
217217
prel_theta, Y_res, T_res, Z_res,
218-
Y, T, X=None, W=None, Z=None, sample_weight=None, groups=None):
218+
Y, T, X=None, W=None, Z=None, sample_weight=None, groups=None, **fit_params):
219219
# T and Z only allow single continuous or binary, keep the shape of (n,) for continuous and (n,1) for binary
220220
target = self._get_target(T_res, Z_res, T, Z)
221221
self._model_tz_xw.train(is_selecting, folds, X=X, W=W, Target=target,
222-
sample_weight=sample_weight, groups=groups)
222+
sample_weight=sample_weight, groups=groups, **fit_params)
223223

224224
return self
225225

@@ -2386,16 +2386,16 @@ def __init__(self,
23862386
self._dummy_z = dummy_z
23872387
self._prel_model_effect = prel_model_effect
23882388

2389-
def train(self, is_selecting, folds, Y, T, X=None, W=None, Z=None, sample_weight=None, groups=None):
2390-
self._model_y_xw.train(is_selecting, folds, X=X, W=W, Target=Y, sample_weight=sample_weight, groups=groups)
2389+
def train(self, is_selecting, folds, Y, T, X=None, W=None, Z=None, sample_weight=None, groups=None, **fit_params):
2390+
self._model_y_xw.train(is_selecting, folds, X=X, W=W, Target=Y, sample_weight=sample_weight, groups=groups, **fit_params)
23912391
# concat W and Z
23922392
WZ = _combine(W, Z, Y.shape[0])
2393-
self._model_t_xwz.train(is_selecting, folds, X=X, W=WZ, Target=T, sample_weight=sample_weight, groups=groups)
2394-
self._dummy_z.train(is_selecting, folds, X=X, W=W, Target=Z, sample_weight=sample_weight, groups=groups)
2393+
self._model_t_xwz.train(is_selecting, folds, X=X, W=WZ, Target=T, sample_weight=sample_weight, groups=groups, **fit_params)
2394+
self._dummy_z.train(is_selecting, folds, X=X, W=W, Target=Z, sample_weight=sample_weight, groups=groups, **fit_params)
23952395
# we need to undo the one-hot encoding for calling effect,
23962396
# since it expects raw values
23972397
self._prel_model_effect.fit(Y, inverse_onehot(T), Z=inverse_onehot(Z), X=X, W=W,
2398-
sample_weight=sample_weight, groups=groups)
2398+
sample_weight=sample_weight, groups=groups, **fit_params)
23992399
return self
24002400

24012401
def score(self, Y, T, X=None, W=None, Z=None, sample_weight=None, groups=None):

econml/panel/dml/_dml.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def __init__(self, model_y, model_t, n_periods):
4040
self._model_t = model_t
4141
self.n_periods = n_periods
4242

43-
def train(self, is_selecting, folds, Y, T, X=None, W=None, sample_weight=None, groups=None):
43+
def train(self, is_selecting, folds, Y, T, X=None, W=None, sample_weight=None, groups=None, **fit_params):
4444
"""Fit a series of nuisance models for each period or period pairs."""
4545
assert Y.shape[0] % self.n_periods == 0, \
4646
"Length of training data should be an integer multiple of time periods."
@@ -87,13 +87,13 @@ def _translate_inds(t, inds):
8787
self._index_or_None(X, period_filters[t]),
8888
self._index_or_None(
8989
W, period_filters[t]),
90-
Y[period_filters[self.n_periods - 1]])
90+
Y[period_filters[self.n_periods - 1]], **fit_params)
9191
for j in np.arange(t, self.n_periods):
9292
self._model_t_trained[j][t].train(
9393
is_selecting, translated_folds,
9494
self._index_or_None(X, period_filters[t]),
9595
self._index_or_None(W, period_filters[t]),
96-
T[period_filters[j]])
96+
T[period_filters[j]], **fit_params)
9797
return self
9898

9999
def predict(self, Y, T, X=None, W=None, sample_weight=None, groups=None):

0 commit comments

Comments
 (0)