@@ -1693,7 +1693,7 @@ class StatsModelsLinearRegression(_StatsModelsWrapper):
16931693 fit_intercept : bool, default True
16941694 Whether to fit an intercept in this model
16951695 cov_type : string, default "HC0"
1696- The covariance approach to use. Supported values are "HCO ", "HC1", and "nonrobust ".
1696+ The covariance approach to use. Supported values are "HC0", "HC1", "nonrobust", and "clustered".
16971697 enable_federation : bool, default False
16981698 Whether to enable federation (aggregating this model's results with other models in a distributed setting).
16991699 This requires additional memory proportional to the number of columns in X to the fourth power.
@@ -1704,10 +1704,10 @@ def __init__(self, fit_intercept=True, cov_type="HC0", *, enable_federation=Fals
17041704 self .fit_intercept = fit_intercept
17051705 self .enable_federation = enable_federation
17061706
1707- def _check_input (self , X , y , sample_weight , freq_weight , sample_var ):
1707+ def _check_input (self , X , y , sample_weight , freq_weight , sample_var , groups = None ):
17081708 """Check dimensions and other assertions."""
1709- X , y , sample_weight , freq_weight , sample_var = check_input_arrays (
1710- X , y , sample_weight , freq_weight , sample_var , dtype = 'numeric' )
1709+ X , y , sample_weight , freq_weight , sample_var , groups = check_input_arrays (
1710+ X , y , sample_weight , freq_weight , sample_var , groups , dtype = 'numeric' )
17111711 if X is None :
17121712 X = np .empty ((y .shape [0 ], 0 ))
17131713 if self .fit_intercept :
@@ -1720,6 +1720,8 @@ def _check_input(self, X, y, sample_weight, freq_weight, sample_var):
17201720 freq_weight = np .ones (y .shape [0 ])
17211721 if sample_var is None :
17221722 sample_var = np .zeros (y .shape )
1723+ if groups is None :
1724+ groups = np .arange (y .shape [0 ])
17231725
17241726 # check freq_weight should be integer and should be accompanied by sample_var
17251727 if np .any (np .not_equal (np .mod (freq_weight , 1 ), 0 )):
@@ -1753,7 +1755,7 @@ def _check_input(self, X, y, sample_weight, freq_weight, sample_var):
17531755
17541756 # check array shape
17551757 assert (X .shape [0 ] == y .shape [0 ] == sample_weight .shape [0 ] ==
1756- freq_weight .shape [0 ] == sample_var .shape [0 ]), "Input lengths not compatible!"
1758+ freq_weight .shape [0 ] == sample_var .shape [0 ] == groups . shape [ 0 ] ), "Input lengths not compatible!"
17571759 if y .ndim >= 2 :
17581760 assert (y .ndim == sample_var .ndim and
17591761 y .shape [1 ] == sample_var .shape [1 ]), "Input shapes not compatible: {}, {}!" .format (
@@ -1767,9 +1769,9 @@ def _check_input(self, X, y, sample_weight, freq_weight, sample_var):
17671769 else :
17681770 weighted_y = y * np .sqrt (sample_weight ).reshape (- 1 , 1 )
17691771 sample_var = sample_var * (sample_weight .reshape (- 1 , 1 ))
1770- return weighted_X , weighted_y , freq_weight , sample_var
1772+ return weighted_X , weighted_y , freq_weight , sample_var , groups
17711773
1772- def fit (self , X , y , sample_weight = None , freq_weight = None , sample_var = None ):
1774+ def fit (self , X , y , sample_weight = None , freq_weight = None , sample_var = None , groups = None ):
17731775 """
17741776 Fits the model.
17751777
@@ -1788,13 +1790,15 @@ def fit(self, X, y, sample_weight=None, freq_weight=None, sample_var=None):
17881790 sample_var : {(N,), (N, p)} nd array_like or None
17891791 Variance of the outcome(s) of the original freq_weight[i] observations that were used to
17901792 compute the mean outcome represented by observation i.
1793+ groups : (N,) array_like or None
1794+ Group labels for clustered standard errors. If None, each observation is treated as its own group.
17911795
17921796 Returns
17931797 -------
17941798 self : StatsModelsLinearRegression
17951799 """
17961800 # TODO: Add other types of covariance estimation (e.g. Newey-West (HAC), HC2, HC3)
1797- X , y , freq_weight , sample_var = self ._check_input (X , y , sample_weight , freq_weight , sample_var )
1801+ X , y , freq_weight , sample_var , groups = self ._check_input (X , y , sample_weight , freq_weight , sample_var , groups )
17981802
17991803 WX = X * np .sqrt (freq_weight ).reshape (- 1 , 1 )
18001804
@@ -1840,6 +1844,8 @@ def fit(self, X, y, sample_weight=None, freq_weight=None, sample_var=None):
18401844 self .XXXy = np .einsum ('nx,ny->yx' , WX , wy )
18411845 self .XXXX = np .einsum ('nw,nx->wx' , WX , WX )
18421846 self .sample_var = np .average (sv , weights = freq_weight , axis = 0 ) * n_obs
1847+ elif self .cov_type == 'clustered' :
1848+ raise AttributeError ("Clustered standard errors are not supported with federation enabled." )
18431849
18441850 sigma_inv = np .linalg .pinv (self .XX )
18451851
@@ -1871,8 +1877,10 @@ def fit(self, X, y, sample_weight=None, freq_weight=None, sample_var=None):
18711877 for j in range (self ._n_out ):
18721878 weighted_sigma = np .matmul (WX .T , WX * var_i [:, [j ]])
18731879 self ._var .append (correction * np .matmul (sigma_inv , np .matmul (weighted_sigma , sigma_inv )))
1880+ elif (self .cov_type == 'clustered' ):
1881+ self ._var = self ._compute_clustered_variance_linear (WX , y - np .matmul (X , param ), sigma_inv , groups )
18741882 else :
1875- raise AttributeError ("Unsupported cov_type. Must be one of nonrobust, HC0, HC1." )
1883+ raise AttributeError ("Unsupported cov_type. Must be one of nonrobust, HC0, HC1, clustered ." )
18761884
18771885 self ._param_var = np .array (self ._var )
18781886
@@ -1937,7 +1945,6 @@ def aggregate(models: List[StatsModelsLinearRegression]):
19371945 agg_model ._var = correction * np .matmul (sigma_inv , np .matmul (weighted_sigma .squeeze (0 ), sigma_inv ))
19381946 else :
19391947 agg_model ._var = [correction * np .matmul (sigma_inv , np .matmul (ws , sigma_inv )) for ws in weighted_sigma ]
1940-
19411948 else :
19421949 assert agg_model .cov_type == 'nonrobust' or agg_model .cov_type is None
19431950 sigma = XXyy - 2 * np .einsum ('yx,xy->y' , XXXy , param ) + np .einsum ('wx,wy,xy->y' , XXXX , param , param )
@@ -1954,6 +1961,54 @@ def aggregate(models: List[StatsModelsLinearRegression]):
19541961
19551962 return agg_model
19561963
def _compute_clustered_variance_linear(self, WX, eps_i, sigma_inv, groups):
    """
    Compute the cluster-robust (sandwich) variance for linear regression.

    For every outcome the rows of ``WX * eps`` are summed within each group,
    the outer products of those group sums are accumulated into ``s``, and the
    result is ``G / (G - 1) * sigma_inv @ s @ sigma_inv`` where ``G`` is the
    number of distinct groups.

    Parameters
    ----------
    WX : (N, k) ndarray
        Weighted design matrix
    eps_i : (N,) or (N, p) ndarray
        Residuals
    sigma_inv : (k, k) array_like
        Inverse of X.T @ X
    groups : (N,) array_like or None
        Group labels for clustering. Any shape holding N labels is accepted
        (e.g. an (N, 1) column); it is flattened before use. If None, each
        observation is treated as its own group.

    Returns
    -------
    var : (k, k) ndarray or list of (k, k) ndarray
        Clustered variance matrix when ``eps_i`` is 1-dimensional, otherwise a
        list with one matrix per outcome column.

    Raises
    ------
    ValueError
        If the number of group labels differs from the number of rows of
        ``WX``, or if fewer than two distinct groups are present (the
        ``G / (G - 1)`` correction is undefined in that case).
    """
    n_obs, n_feat = WX.shape

    if groups is None:
        group_labels = np.arange(n_obs)
    else:
        # Flatten so that (N, 1) label columns index rows the same way as (N,) vectors.
        group_labels = np.asarray(groups).ravel()
    if group_labels.shape[0] != n_obs:
        raise ValueError(
            "groups must contain one label per observation: got {} labels for {} observations.".format(
                group_labels.shape[0], n_obs))

    group_ids, inverse_idx = np.unique(group_labels, return_inverse=True)
    inverse_idx = np.ravel(inverse_idx)
    n_groups = len(group_ids)
    if n_groups < 2:
        # Previously this surfaced as a bare ZeroDivisionError.
        raise ValueError("Clustered standard errors require at least two distinct groups; "
                         "got {}.".format(n_groups))

    # Group correction factor
    group_correction = n_groups / (n_groups - 1)

    def _sandwich(resid):
        # Sum the per-observation scores within each group, then build the sandwich.
        scores = WX * resid.reshape(-1, 1)
        group_sums = np.zeros((n_groups, n_feat))
        np.add.at(group_sums, inverse_idx, scores)
        s = group_sums.T @ group_sums
        return group_correction * np.matmul(sigma_inv, np.matmul(s, sigma_inv))

    if eps_i.ndim < 2:
        # Single outcome case
        return _sandwich(eps_i)
    # Multiple outcome case
    return [_sandwich(eps_i[:, j]) for j in range(eps_i.shape[1])]
2011+
19572012
19582013class StatsModelsRLM (_StatsModelsWrapper ):
19592014 """
@@ -2040,23 +2095,28 @@ class StatsModels2SLS(_StatsModelsWrapper):
20402095
20412096 Parameters
20422097 ----------
2043- cov_type : {'HC0', 'HC1', 'nonrobust', or None}, default 'HC0'
2044- Indicates how the covariance matrix is estimated.
2098+ cov_type : {'HC0', 'HC1', 'nonrobust', 'clustered', or None}, default 'HC0'
2099+ Indicates how the covariance matrix is estimated. 'clustered' requires groups to be provided in fit().
20452100 """
20462101
def __init__(self, cov_type="HC0"):
    """
    Store the estimator configuration.

    Parameters
    ----------
    cov_type : {'HC0', 'HC1', 'nonrobust', 'clustered', or None}, default 'HC0'
        Indicates how the covariance matrix is estimated. The value is only
        stored here; it is not validated by the constructor.
    """
    # fit_intercept is fixed to False for this class and is not a constructor option.
    self.fit_intercept = False
    self.cov_type = cov_type
20512106
2052- def _check_input (self , Z , T , y , sample_weight ):
2107+ def _check_input (self , Z , T , y , sample_weight , groups = None ):
20532108 """Check dimensions and other assertions."""
20542109 # set default values for None
20552110 if sample_weight is None :
20562111 sample_weight = np .ones (y .shape [0 ])
2112+ if groups is None :
2113+ groups = np .arange (y .shape [0 ])
2114+ else :
2115+ groups = np .asarray (groups )
20572116
20582117 # check array shape
2059- assert (T .shape [0 ] == Z .shape [0 ] == y .shape [0 ] == sample_weight .shape [0 ]), "Input lengths not compatible!"
2118+ assert (T .shape [0 ] == Z .shape [0 ] == y .shape [0 ] == sample_weight .shape [0 ] == groups .shape [0 ]), \
2119+ "Input lengths not compatible!"
20602120
20612121 # check dimension of instruments is more than dimension of treatments
20622122 if Z .shape [1 ] < T .shape [1 ]:
@@ -2075,7 +2135,7 @@ def _check_input(self, Z, T, y, sample_weight):
20752135 weighted_y = y * np .sqrt (sample_weight ).reshape (- 1 , 1 )
20762136 return weighted_Z , weighted_T , weighted_y
20772137
2078- def fit (self , Z , T , y , sample_weight = None , freq_weight = None , sample_var = None ):
2138+ def fit (self , Z , T , y , sample_weight = None , freq_weight = None , sample_var = None , groups = None ):
20792139 """
20802140 Fits the model.
20812141
@@ -2096,7 +2156,8 @@ def fit(self, Z, T, y, sample_weight=None, freq_weight=None, sample_var=None):
20962156 sample_var : {(N,), (N, p)} nd array_like or None
20972157 Variance of the outcome(s) of the original freq_weight[i] observations that were used to
20982158 compute the mean outcome represented by observation i.
2099-
2159+ groups : (N,) array_like or None
2160+ Group labels for clustered standard errors. Required when cov_type='clustered'.
21002161
21012162 Returns
21022163 -------
@@ -2105,7 +2166,7 @@ def fit(self, Z, T, y, sample_weight=None, freq_weight=None, sample_var=None):
21052166 assert freq_weight is None , "freq_weight is not supported yet for this class!"
21062167 assert sample_var is None , "sample_var is not supported yet for this class!"
21072168
2108- Z , T , y = self ._check_input (Z , T , y , sample_weight )
2169+ Z , T , y = self ._check_input (Z , T , y , sample_weight , groups )
21092170
21102171 self ._n_out = 0 if y .ndim < 2 else y .shape [1 ]
21112172
@@ -2164,8 +2225,58 @@ def fit(self, Z, T, y, sample_weight=None, freq_weight=None, sample_var=None):
21642225 weighted_sigma = np .matmul (that .T , that * var_i [:, [j ]])
21652226 self ._var .append (correction * np .matmul (thatT_that_inv ,
21662227 np .matmul (weighted_sigma , thatT_that_inv )))
2228+ elif (self .cov_type == 'clustered' ):
2229+ self ._var = self ._compute_clustered_variance (that , y - np .dot (T , param ), thatT_that_inv , groups )
21672230 else :
2168- raise AttributeError ("Unsupported cov_type. Must be one of nonrobust, HC0, HC1." )
2231+ raise AttributeError ("Unsupported cov_type. Must be one of nonrobust, HC0, HC1, clustered ." )
21692232
21702233 self ._param_var = np .array (self ._var )
21712234 return self
2235+
def _compute_clustered_variance(self, that, eps_i, thatT_that_inv, groups):
    """
    Compute the cluster-robust (sandwich) variance.

    For every outcome the rows of ``that * eps`` are summed within each group,
    the outer products of those group sums are accumulated into ``s``, and the
    result is ``G / (G - 1) * thatT_that_inv @ s @ thatT_that_inv`` where ``G``
    is the number of distinct groups.

    Parameters
    ----------
    that : (N, k) ndarray
        Fitted values from first stage
    eps_i : (N,) or (N, p) ndarray
        Residuals
    thatT_that_inv : (k, k) array_like
        Inverse of that.T @ that
    groups : (N,) array_like or None
        Group labels for clustering. Any shape holding N labels is accepted
        (e.g. an (N, 1) column); it is flattened before use. If None, each
        observation is treated as its own group.

    Returns
    -------
    var : (k, k) ndarray or list of (k, k) ndarray
        Clustered variance matrix when ``eps_i`` is 1-dimensional, otherwise a
        list with one matrix per outcome column.

    Raises
    ------
    ValueError
        If the number of group labels differs from the number of rows of
        ``that``, or if fewer than two distinct groups are present (the
        ``G / (G - 1)`` correction is undefined in that case).
    """
    n_obs, n_feat = that.shape

    if groups is None:
        # The caller may hand over the raw ``groups`` argument; without this default
        # np.unique(None) would report a single group and the correction would divide by zero.
        group_labels = np.arange(n_obs)
    else:
        # Flatten so that (N, 1) label columns index rows the same way as (N,) vectors.
        group_labels = np.asarray(groups).ravel()
    if group_labels.shape[0] != n_obs:
        raise ValueError(
            "groups must contain one label per observation: got {} labels for {} observations.".format(
                group_labels.shape[0], n_obs))

    group_ids, inverse_idx = np.unique(group_labels, return_inverse=True)
    inverse_idx = np.ravel(inverse_idx)
    n_groups = len(group_ids)
    if n_groups < 2:
        # Previously this surfaced as a bare ZeroDivisionError.
        raise ValueError("Clustered standard errors require at least two distinct groups; "
                         "got {}.".format(n_groups))

    # Group correction factor
    group_correction = n_groups / (n_groups - 1)

    def _sandwich(resid):
        # Sum the per-observation scores within each group, then build the sandwich.
        scores = that * resid.reshape(-1, 1)
        group_sums = np.zeros((n_groups, n_feat))
        np.add.at(group_sums, inverse_idx, scores)
        s = group_sums.T @ group_sums
        return group_correction * np.matmul(thatT_that_inv, np.matmul(s, thatT_that_inv))

    if eps_i.ndim < 2:
        # Single outcome case
        return _sandwich(eps_i)
    # Multiple outcome case
    return [_sandwich(eps_i[:, j]) for j in range(eps_i.shape[1])]
0 commit comments