We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent b2193eb commit a790fb2Copy full SHA for a790fb2
econml/sklearn_extensions/linear_model.py
@@ -2133,7 +2133,7 @@ def _check_input(self, Z, T, y, sample_weight, groups=None):
2133
weighted_y = y * np.sqrt(sample_weight)
2134
else:
2135
weighted_y = y * np.sqrt(sample_weight).reshape(-1, 1)
2136
- return weighted_Z, weighted_T, weighted_y
+ return weighted_Z, weighted_T, weighted_y, groups
2137
2138
def fit(self, Z, T, y, sample_weight=None, freq_weight=None, sample_var=None, groups=None):
2139
"""
@@ -2166,7 +2166,7 @@ def fit(self, Z, T, y, sample_weight=None, freq_weight=None, sample_var=None, gr
2166
assert freq_weight is None, "freq_weight is not supported yet for this class!"
2167
assert sample_var is None, "sample_var is not supported yet for this class!"
2168
2169
- Z, T, y = self._check_input(Z, T, y, sample_weight, groups)
+ Z, T, y, groups = self._check_input(Z, T, y, sample_weight, groups)
2170
2171
self._n_out = 0 if y.ndim < 2 else y.shape[1]
2172
0 commit comments