We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent f54fa02 commit 11cdb3eCopy full SHA for 11cdb3e
econml/dml/_rlearner.py
@@ -50,12 +50,12 @@ def __init__(self, model_y: ModelSelector, model_t: ModelSelector):
50
self._model_y = model_y
51
self._model_t = model_t
52
53
- def train(self, is_selecting, folds, Y, T, X=None, W=None, Z=None, sample_weight=None, groups=None):
+ def train(self, is_selecting, folds, Y, T, X=None, W=None, Z=None, sample_weight=None, groups=None, **fit_params):
54
assert Z is None, "Cannot accept instrument!"
55
self._model_t.train(is_selecting, folds, X, W, T, **
56
- filter_none_kwargs(sample_weight=sample_weight, groups=groups))
+ filter_none_kwargs(sample_weight=sample_weight, groups=groups, **fit_params))
57
self._model_y.train(is_selecting, folds, X, W, Y, **
58
59
return self
60
61
def score(self, Y, T, X=None, W=None, Z=None, sample_weight=None, groups=None,
econml/dml/dml.py
@@ -97,7 +97,7 @@ def __init__(self, model: SingleModelSelector, discrete_target):
97
self._model = clone(model, safe=False)
98
self._discrete_target = discrete_target
99
100
- def train(self, is_selecting, folds, X, W, Target, sample_weight=None, groups=None):
+ def train(self, is_selecting, folds, X, W, Target, sample_weight=None, groups=None, **fit_params):
101
if self._discrete_target:
102
# In this case, the Target is the one-hot-encoding of the treatment variable
103
# We need to go back to the label representation of the one-hot so as to call
@@ -108,7 +108,7 @@ def train(self, is_selecting, folds, X, W, Target, sample_weight=None, groups=No
108
Target = inverse_onehot(Target)
109
110
self._model.train(is_selecting, folds, _combine(X, W, Target.shape[0]), Target,
111
- **filter_none_kwargs(groups=groups, sample_weight=sample_weight))
+ **filter_none_kwargs(groups=groups, sample_weight=sample_weight, **fit_params))
112
113
114
@property
0 commit comments