Skip to content

Commit a790fb2

Browse files
committed
Fix StatsModels2SLS clustered SE bug: return groups from _check_input to handle groups=None
Signed-off-by: Mikayel Sukiasyan <[email protected]>
1 parent b2193eb commit a790fb2

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

econml/sklearn_extensions/linear_model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2133,7 +2133,7 @@ def _check_input(self, Z, T, y, sample_weight, groups=None):
21332133
weighted_y = y * np.sqrt(sample_weight)
21342134
else:
21352135
weighted_y = y * np.sqrt(sample_weight).reshape(-1, 1)
2136-
return weighted_Z, weighted_T, weighted_y
2136+
return weighted_Z, weighted_T, weighted_y, groups
21372137

21382138
def fit(self, Z, T, y, sample_weight=None, freq_weight=None, sample_var=None, groups=None):
21392139
"""
@@ -2166,7 +2166,7 @@ def fit(self, Z, T, y, sample_weight=None, freq_weight=None, sample_var=None, gr
21662166
assert freq_weight is None, "freq_weight is not supported yet for this class!"
21672167
assert sample_var is None, "sample_var is not supported yet for this class!"
21682168

2169-
Z, T, y = self._check_input(Z, T, y, sample_weight, groups)
2169+
Z, T, y, groups = self._check_input(Z, T, y, sample_weight, groups)
21702170

21712171
self._n_out = 0 if y.ndim < 2 else y.shape[1]
21722172

0 commit comments

Comments
 (0)