Skip to content

Commit fcbaf38

Browse files
committed
ENH: Pass fit_params through train methods
1 parent f54fa02 commit fcbaf38

File tree

2 files changed

+5
-5
lines changed

2 files changed

+5
-5
lines changed

econml/dml/_rlearner.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,12 +50,12 @@ def __init__(self, model_y: ModelSelector, model_t: ModelSelector):
5050
self._model_y = model_y
5151
self._model_t = model_t
5252

53-
def train(self, is_selecting, folds, Y, T, X=None, W=None, Z=None, sample_weight=None, groups=None):
53+
def train(self, is_selecting, folds, Y, T, X=None, W=None, Z=None, sample_weight=None, groups=None, **fit_params):
5454
assert Z is None, "Cannot accept instrument!"
5555
self._model_t.train(is_selecting, folds, X, W, T, **
56-
filter_none_kwargs(sample_weight=sample_weight, groups=groups))
56+
filter_none_kwargs(sample_weight=sample_weight, groups=groups, **fit_params))
5757
self._model_y.train(is_selecting, folds, X, W, Y, **
58-
filter_none_kwargs(sample_weight=sample_weight, groups=groups))
58+
filter_none_kwargs(sample_weight=sample_weight, groups=groups, **fit_params))
5959
return self
6060

6161
def score(self, Y, T, X=None, W=None, Z=None, sample_weight=None, groups=None,

econml/dml/dml.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ def __init__(self, model: SingleModelSelector, discrete_target):
9797
self._model = clone(model, safe=False)
9898
self._discrete_target = discrete_target
9999

100-
def train(self, is_selecting, folds, X, W, Target, sample_weight=None, groups=None):
100+
def train(self, is_selecting, folds, X, W, Target, sample_weight=None, groups=None, **fit_params):
101101
if self._discrete_target:
102102
# In this case, the Target is the one-hot-encoding of the treatment variable
103103
# We need to go back to the label representation of the one-hot so as to call
@@ -108,7 +108,7 @@ def train(self, is_selecting, folds, X, W, Target, sample_weight=None, groups=No
108108
Target = inverse_onehot(Target)
109109

110110
self._model.train(is_selecting, folds, _combine(X, W, Target.shape[0]), Target,
111-
**filter_none_kwargs(groups=groups, sample_weight=sample_weight))
111+
**filter_none_kwargs(groups=groups, sample_weight=sample_weight, **fit_params))
112112
return self
113113

114114
@property

0 commit comments

Comments
 (0)