
Commit b2193eb

Add clustered standard errors support to linear models and OrthoIV

- Implement clustered variance calculation in StatsModelsLinearRegression and StatsModels2SLS
- Add cov_type='clustered' parameter to OrthoIV estimator
- Add tests validating against statsmodels implementation

Signed-off-by: Mikayel Sukiasyan <[email protected]>

1 parent f54fa02 commit b2193eb

4 files changed (+316, -23 lines)


econml/dml/_rlearner.py

Lines changed: 2 additions & 2 deletions

@@ -98,8 +98,8 @@ def __init__(self, model_final):
     def fit(self, Y, T, X=None, W=None, Z=None, nuisances=None,
             sample_weight=None, freq_weight=None, sample_var=None, groups=None):
         Y_res, T_res = nuisances
-        self._model_final.fit(X, T, T_res, Y_res, sample_weight=sample_weight,
-                              freq_weight=freq_weight, sample_var=sample_var)
+        self._model_final.fit(X, T, T_res, Y_res, **(filter_none_kwargs(sample_weight=sample_weight,
+                              freq_weight=freq_weight, sample_var=sample_var, groups=groups)))
         return self

     def predict(self, X=None):
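For context, `filter_none_kwargs` is the econml utility (from `econml.utilities`) that drops keyword arguments whose value is None, so a final model whose `fit` signature has not been extended to accept `groups` never receives an unexpected `groups=None`. A minimal sketch of the behavior being relied on, not the library's exact implementation:

```python
def filter_none_kwargs(**kwargs):
    """Keep only the keyword arguments whose value is not None."""
    return {key: value for key, value in kwargs.items() if value is not None}

# With groups=None (no clustering requested), only sample_weight is forwarded
# to self._model_final.fit, so older final models keep working unchanged.
filter_none_kwargs(sample_weight=[1.0, 2.0], freq_weight=None, sample_var=None, groups=None)
# -> {'sample_weight': [1.0, 2.0]}
```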

econml/iv/dml/_dml.py

Lines changed: 5 additions & 3 deletions

@@ -157,7 +157,7 @@ def fit(self, Y, T, X=None, W=None, Z=None, nuisances=None,
         XT_res = self._combine(X, T_res)
         XZ_res = self._combine(X, Z_res)
         filtered_kwargs = filter_none_kwargs(sample_weight=sample_weight,
-                                             freq_weight=freq_weight, sample_var=sample_var)
+                                             freq_weight=freq_weight, sample_var=sample_var, groups=groups)

         self._model_final.fit(XZ_res, XT_res, Y_res, **filtered_kwargs)

@@ -376,14 +376,16 @@ def __init__(self, *,
                  mc_iters=None,
                  mc_agg='mean',
                  random_state=None,
-                 allow_missing=False):
+                 allow_missing=False,
+                 cov_type="HC0"):
         self.model_y_xw = clone(model_y_xw, safe=False)
         self.model_t_xw = clone(model_t_xw, safe=False)
         self.model_t_xwz = clone(model_t_xwz, safe=False)
         self.model_z_xw = clone(model_z_xw, safe=False)
         self.projection = projection
         self.featurizer = clone(featurizer, safe=False)
         self.fit_cate_intercept = fit_cate_intercept
+        self.cov_type = cov_type

         super().__init__(discrete_outcome=discrete_outcome,
                          discrete_instrument=discrete_instrument,

@@ -403,7 +405,7 @@ def _gen_featurizer(self):
         return clone(self.featurizer, safe=False)

     def _gen_model_final(self):
-        return StatsModels2SLS(cov_type="HC0")
+        return StatsModels2SLS(cov_type=self.cov_type)

     def _gen_ortho_learner_model_final(self):
         return _OrthoIVModelFinal(self._gen_model_final(), self._gen_featurizer(), self.fit_cate_intercept)
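Taken together, these changes let a caller request cluster-robust inference on the final 2SLS stage. The snippet below is an illustrative sketch rather than code from the commit: the data generation is made up, and it assumes the existing OrthoIV API in which `groups` is already an accepted argument of `fit` (used for grouped cross-fitting) and default inference is driven by the final statsmodels-style model.

```python
import numpy as np
from econml.iv.dml import OrthoIV

rng = np.random.default_rng(0)
n = 2000
groups = rng.integers(0, 25, size=n)               # 25 clusters
W = rng.normal(size=(n, 3))                        # controls
Z = rng.binomial(1, 0.5, size=n)                   # binary instrument
T = 0.7 * Z + 0.5 * W[:, 0] + rng.normal(size=n)   # treatment
Y = 2.0 * T + W[:, 1] + rng.normal(size=n)         # outcome

# cov_type="clustered" is forwarded to the final StatsModels2SLS model,
# and the group labels passed to fit() are used for the clustered variance.
est = OrthoIV(discrete_instrument=True, cov_type="clustered")
est.fit(Y, T, Z=Z, W=W, groups=groups)
print(est.effect_interval())  # interval widths now reflect within-cluster correlation
```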

econml/sklearn_extensions/linear_model.py

Lines changed: 129 additions & 18 deletions

@@ -1693,7 +1693,7 @@ class StatsModelsLinearRegression(_StatsModelsWrapper):
     fit_intercept : bool, default True
         Whether to fit an intercept in this model
     cov_type : string, default "HC0"
-        The covariance approach to use. Supported values are "HCO", "HC1", and "nonrobust".
+        The covariance approach to use. Supported values are "HC0", "HC1", "nonrobust", and "clustered".
     enable_federation : bool, default False
         Whether to enable federation (aggregating this model's results with other models in a distributed setting).
         This requires additional memory proportional to the number of columns in X to the fourth power.

@@ -1704,10 +1704,10 @@ def __init__(self, fit_intercept=True, cov_type="HC0", *, enable_federation=Fals
         self.fit_intercept = fit_intercept
         self.enable_federation = enable_federation

-    def _check_input(self, X, y, sample_weight, freq_weight, sample_var):
+    def _check_input(self, X, y, sample_weight, freq_weight, sample_var, groups=None):
         """Check dimensions and other assertions."""
-        X, y, sample_weight, freq_weight, sample_var = check_input_arrays(
-            X, y, sample_weight, freq_weight, sample_var, dtype='numeric')
+        X, y, sample_weight, freq_weight, sample_var, groups = check_input_arrays(
+            X, y, sample_weight, freq_weight, sample_var, groups, dtype='numeric')
         if X is None:
             X = np.empty((y.shape[0], 0))
         if self.fit_intercept:

@@ -1720,6 +1720,8 @@ def _check_input(self, X, y, sample_weight, freq_weight, sample_var):
             freq_weight = np.ones(y.shape[0])
         if sample_var is None:
             sample_var = np.zeros(y.shape)
+        if groups is None:
+            groups = np.arange(y.shape[0])

         # check freq_weight should be integer and should be accompanied by sample_var
         if np.any(np.not_equal(np.mod(freq_weight, 1), 0)):

@@ -1753,7 +1755,7 @@ def _check_input(self, X, y, sample_weight, freq_weight, sample_var):

         # check array shape
         assert (X.shape[0] == y.shape[0] == sample_weight.shape[0] ==
-                freq_weight.shape[0] == sample_var.shape[0]), "Input lengths not compatible!"
+                freq_weight.shape[0] == sample_var.shape[0] == groups.shape[0]), "Input lengths not compatible!"
         if y.ndim >= 2:
             assert (y.ndim == sample_var.ndim and
                     y.shape[1] == sample_var.shape[1]), "Input shapes not compatible: {}, {}!".format(

@@ -1767,9 +1769,9 @@ def _check_input(self, X, y, sample_weight, freq_weight, sample_var):
         else:
             weighted_y = y * np.sqrt(sample_weight).reshape(-1, 1)
             sample_var = sample_var * (sample_weight.reshape(-1, 1))
-        return weighted_X, weighted_y, freq_weight, sample_var
+        return weighted_X, weighted_y, freq_weight, sample_var, groups

-    def fit(self, X, y, sample_weight=None, freq_weight=None, sample_var=None):
+    def fit(self, X, y, sample_weight=None, freq_weight=None, sample_var=None, groups=None):
         """
         Fits the model.

@@ -1788,13 +1790,15 @@ def fit(self, X, y, sample_weight=None, freq_weight=None, sample_var=None):
         sample_var : {(N,), (N, p)} nd array_like or None
             Variance of the outcome(s) of the original freq_weight[i] observations that were used to
             compute the mean outcome represented by observation i.
+        groups : (N,) array_like or None
+            Group labels for clustered standard errors.

         Returns
         -------
         self : StatsModelsLinearRegression
         """
         # TODO: Add other types of covariance estimation (e.g. Newey-West (HAC), HC2, HC3)
-        X, y, freq_weight, sample_var = self._check_input(X, y, sample_weight, freq_weight, sample_var)
+        X, y, freq_weight, sample_var, groups = self._check_input(X, y, sample_weight, freq_weight, sample_var, groups)

         WX = X * np.sqrt(freq_weight).reshape(-1, 1)

@@ -1840,6 +1844,8 @@ def fit(self, X, y, sample_weight=None, freq_weight=None, sample_var=None):
             self.XXXy = np.einsum('nx,ny->yx', WX, wy)
             self.XXXX = np.einsum('nw,nx->wx', WX, WX)
             self.sample_var = np.average(sv, weights=freq_weight, axis=0) * n_obs
+        elif self.cov_type == 'clustered':
+            raise AttributeError("Clustered standard errors are not supported with federation enabled.")

         sigma_inv = np.linalg.pinv(self.XX)

@@ -1871,8 +1877,10 @@ def fit(self, X, y, sample_weight=None, freq_weight=None, sample_var=None):
             for j in range(self._n_out):
                 weighted_sigma = np.matmul(WX.T, WX * var_i[:, [j]])
                 self._var.append(correction * np.matmul(sigma_inv, np.matmul(weighted_sigma, sigma_inv)))
+        elif (self.cov_type == 'clustered'):
+            self._var = self._compute_clustered_variance_linear(WX, y - np.matmul(X, param), sigma_inv, groups)
         else:
-            raise AttributeError("Unsupported cov_type. Must be one of nonrobust, HC0, HC1.")
+            raise AttributeError("Unsupported cov_type. Must be one of nonrobust, HC0, HC1, clustered.")

         self._param_var = np.array(self._var)

@@ -1937,7 +1945,6 @@ def aggregate(models: List[StatsModelsLinearRegression]):
                 agg_model._var = correction * np.matmul(sigma_inv, np.matmul(weighted_sigma.squeeze(0), sigma_inv))
             else:
                 agg_model._var = [correction * np.matmul(sigma_inv, np.matmul(ws, sigma_inv)) for ws in weighted_sigma]
-
         else:
             assert agg_model.cov_type == 'nonrobust' or agg_model.cov_type is None
             sigma = XXyy - 2 * np.einsum('yx,xy->y', XXXy, param) + np.einsum('wx,wy,xy->y', XXXX, param, param)

@@ -1954,6 +1961,54 @@ def aggregate(models: List[StatsModelsLinearRegression]):

         return agg_model

+    def _compute_clustered_variance_linear(self, WX, eps_i, sigma_inv, groups):
+        """
+        Compute clustered standard errors for linear regression.
+
+        Parameters
+        ----------
+        WX : array_like
+            Weighted design matrix
+        eps_i : array_like
+            Residuals
+        sigma_inv : array_like
+            Inverse of X.T @ X
+        groups : array_like
+            Group labels for clustering
+
+        Returns
+        -------
+        var : array_like or list
+            Clustered variance matrix
+        """
+        n, k = WX.shape
+        group_ids, inverse_idx = np.unique(groups, return_inverse=True)
+        n_groups = len(group_ids)
+
+        # Group correction factor
+        group_correction = (n_groups / (n_groups - 1))
+
+        if eps_i.ndim < 2:
+            # Single outcome case
+            WX_e = WX * eps_i.reshape(-1, 1)
+            group_sums = np.zeros((n_groups, k))
+            np.add.at(group_sums, inverse_idx, WX_e)
+            s = group_sums.T @ group_sums
+
+            return group_correction * np.matmul(sigma_inv, np.matmul(s, sigma_inv))
+        else:
+            # Multiple outcome case
+            var_list = []
+            for j in range(eps_i.shape[1]):
+                WX_e = WX * eps_i[:, [j]]
+                group_sums = np.zeros((n_groups, k))
+                np.add.at(group_sums, inverse_idx, WX_e)
+                s = group_sums.T @ group_sums
+
+                var_list.append(group_correction * np.matmul(sigma_inv, np.matmul(s, sigma_inv)))
+
+            return var_list
+

 class StatsModelsRLM(_StatsModelsWrapper):
     """

@@ -2040,23 +2095,28 @@ class StatsModels2SLS(_StatsModelsWrapper):

     Parameters
     ----------
-    cov_type : {'HC0', 'HC1', 'nonrobust', or None}, default 'HC0'
-        Indicates how the covariance matrix is estimated.
+    cov_type : {'HC0', 'HC1', 'nonrobust', 'clustered', or None}, default 'HC0'
+        Indicates how the covariance matrix is estimated. 'clustered' requires groups to be provided in fit().
     """

     def __init__(self, cov_type="HC0"):
         self.fit_intercept = False
         self.cov_type = cov_type
         return

-    def _check_input(self, Z, T, y, sample_weight):
+    def _check_input(self, Z, T, y, sample_weight, groups=None):
         """Check dimensions and other assertions."""
         # set default values for None
         if sample_weight is None:
             sample_weight = np.ones(y.shape[0])
+        if groups is None:
+            groups = np.arange(y.shape[0])
+        else:
+            groups = np.asarray(groups)

         # check array shape
-        assert (T.shape[0] == Z.shape[0] == y.shape[0] == sample_weight.shape[0]), "Input lengths not compatible!"
+        assert (T.shape[0] == Z.shape[0] == y.shape[0] == sample_weight.shape[0] == groups.shape[0]), \
+            "Input lengths not compatible!"

         # check dimension of instruments is more than dimension of treatments
         if Z.shape[1] < T.shape[1]:

@@ -2075,7 +2135,7 @@ def _check_input(self, Z, T, y, sample_weight):
             weighted_y = y * np.sqrt(sample_weight).reshape(-1, 1)
         return weighted_Z, weighted_T, weighted_y

-    def fit(self, Z, T, y, sample_weight=None, freq_weight=None, sample_var=None):
+    def fit(self, Z, T, y, sample_weight=None, freq_weight=None, sample_var=None, groups=None):
         """
         Fits the model.

@@ -2096,7 +2156,8 @@ def fit(self, Z, T, y, sample_weight=None, freq_weight=None, sample_var=None):
         sample_var : {(N,), (N, p)} nd array_like or None
             Variance of the outcome(s) of the original freq_weight[i] observations that were used to
             compute the mean outcome represented by observation i.
-
+        groups : (N,) array_like or None
+            Group labels for clustered standard errors. Required when cov_type='clustered'.

         Returns
         -------

@@ -2105,7 +2166,7 @@
         assert freq_weight is None, "freq_weight is not supported yet for this class!"
         assert sample_var is None, "sample_var is not supported yet for this class!"

-        Z, T, y = self._check_input(Z, T, y, sample_weight)
+        Z, T, y = self._check_input(Z, T, y, sample_weight, groups)

         self._n_out = 0 if y.ndim < 2 else y.shape[1]

@@ -2164,8 +2225,58 @@
                 weighted_sigma = np.matmul(that.T, that * var_i[:, [j]])
                 self._var.append(correction * np.matmul(thatT_that_inv,
                                                         np.matmul(weighted_sigma, thatT_that_inv)))
+        elif (self.cov_type == 'clustered'):
+            self._var = self._compute_clustered_variance(that, y - np.dot(T, param), thatT_that_inv, groups)
         else:
-            raise AttributeError("Unsupported cov_type. Must be one of nonrobust, HC0, HC1.")
+            raise AttributeError("Unsupported cov_type. Must be one of nonrobust, HC0, HC1, clustered.")

         self._param_var = np.array(self._var)
         return self
+
+    def _compute_clustered_variance(self, that, eps_i, thatT_that_inv, groups):
+        """
+        Compute clustered standard errors.
+
+        Parameters
+        ----------
+        that : array_like
+            Fitted values from first stage
+        eps_i : array_like
+            Residuals
+        thatT_that_inv : array_like
+            Inverse of that.T @ that
+        groups : array_like
+            Group labels for clustering
+
+        Returns
+        -------
+        var : array_like or list
+            Clustered variance matrix
+        """
+        n, k = that.shape
+        group_ids, inverse_idx = np.unique(groups, return_inverse=True)
+        n_groups = len(group_ids)
+
+        # Group correction factor
+        group_correction = (n_groups / (n_groups - 1))
+
+        if eps_i.ndim < 2:
+            # Single outcome case
+            that_e = that * eps_i.reshape(-1, 1)
+            group_sums = np.zeros((n_groups, k))
+            np.add.at(group_sums, inverse_idx, that_e)
+            s = group_sums.T @ group_sums
+
+            return group_correction * np.matmul(thatT_that_inv, np.matmul(s, thatT_that_inv))
+        else:
+            # Multiple outcome case
+            var_list = []
+            for j in range(eps_i.shape[1]):
+                that_e = that * eps_i[:, [j]]
+                group_sums = np.zeros((n_groups, k))
+                np.add.at(group_sums, inverse_idx, that_e)
+                s = group_sums.T @ group_sums
+
+                var_list.append(group_correction * np.matmul(thatT_that_inv, np.matmul(s, thatT_that_inv)))
+
+            return var_list
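In matrix notation, both `_compute_clustered_variance_linear` and `_compute_clustered_variance` return the standard cluster-robust sandwich estimator. Writing $X$ for the weighted design matrix (`WX`, or the first-stage fitted values `that` in the 2SLS case), $e$ for the residuals, and $g = 1, \dots, G$ for the clusters:

$$
\widehat{\operatorname{Var}}(\hat\beta) = \frac{G}{G-1}\,\bigl(X^\top X\bigr)^{-1}\Bigl(\sum_{g=1}^{G} X_g^\top e_g\, e_g^\top X_g\Bigr)\bigl(X^\top X\bigr)^{-1}
$$

Here $X_g^\top e_g$ is the per-cluster score sum accumulated by `np.add.at(group_sums, inverse_idx, ...)`, and the leading $G/(G-1)$ is the `group_correction` factor.

The commit message mentions tests validating against statsmodels; the sketch below shows the kind of comparison that is possible. It is illustrative, not the added test file: the data are made up, it assumes the wrapper's documented `coef_stderr_` accessor, and statsmodels applies an extra finite-sample correction by default, so agreement should be close but not bit-identical.

```python
import numpy as np
import statsmodels.api as sm
from econml.sklearn_extensions.linear_model import StatsModelsLinearRegression

rng = np.random.default_rng(0)
n, k, n_clusters = 5000, 3, 50
groups = rng.integers(0, n_clusters, size=n)
cluster_effects = rng.normal(size=(n_clusters, k))
X = rng.normal(size=(n, k)) + cluster_effects[groups]   # within-cluster correlation in X
y = X @ np.array([1.0, -2.0, 0.5]) + rng.normal(size=n_clusters)[groups] + rng.normal(size=n)

# econml wrapper with the new clustered covariance
lr = StatsModelsLinearRegression(cov_type="clustered")
lr.fit(X, y, groups=groups)

# statsmodels reference fit with cluster-robust covariance
res = sm.OLS(y, sm.add_constant(X)).fit(cov_type="cluster", cov_kwds={"groups": groups})

# coef_stderr_ is the wrapper's standard-error accessor (assumed per econml docs)
print(lr.coef_stderr_)   # clustered SEs of the slope coefficients
print(res.bse[1:])       # statsmodels clustered SEs, excluding the constant
```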
