@@ -1845,7 +1845,31 @@ def fit(self, X, y, sample_weight=None, freq_weight=None, sample_var=None, group
18451845 self .XXXX = np .einsum ('nw,nx->wx' , WX , WX )
18461846 self .sample_var = np .average (sv , weights = freq_weight , axis = 0 ) * n_obs
18471847 elif self .cov_type == 'clustered' :
1848- raise AttributeError ("Clustered standard errors are not supported with federation enabled." )
1848+ group_ids , inverse_idx = np .unique (groups , return_inverse = True )
1849+ n_groups = len (group_ids )
1850+ k = WX .shape [1 ]
1851+
1852+ S_local = np .einsum ('ni,nj->nij' , WX , X ) # (N, k, k)
1853+ S_flat = S_local .reshape (S_local .shape [0 ], - 1 ) # (N, k*k)
1854+ group_S_flat = np .zeros ((n_groups , k * k ))
1855+ np .add .at (group_S_flat , inverse_idx , S_flat )
1856+ group_S = group_S_flat .reshape (n_groups , k , k ) # (G, k, k)
1857+
1858+ y2d = y .reshape (- 1 , 1 ) if y .ndim < 2 else y # (N, p)
1859+ TY_local = y2d [:, :, None ] * WX [:, None , :] # (N, p, k)
1860+ TY_flat = TY_local .reshape (TY_local .shape [0 ], - 1 ) # (N, p*k)
1861+ group_T_flat = np .zeros ((n_groups , y2d .shape [1 ] * k ))
1862+ np .add .at (group_T_flat , inverse_idx , TY_flat )
1863+ group_t = group_T_flat .reshape (n_groups , y2d .shape [1 ], k ).transpose (1 , 0 , 2 ) # (p, G, k)
1864+
1865+ TT = np .einsum ('ygk,ygl->ykl' , group_t , group_t ) # (p, k, k)
1866+ ST = np .einsum ('gvw,ygx->yvwx' , group_S , group_t ) # (p, k, k, k)
1867+ SS = np .einsum ('gvu,gwx->vuwx' , group_S , group_S ) # (k, k, k, k)
1868+
1869+ self .CL_TT = TT
1870+ self .CL_ST = ST
1871+ self .CL_SS = SS
1872+ self ._n_groups = n_groups
18491873
18501874 sigma_inv = np .linalg .pinv (self .XX )
18511875
@@ -1878,7 +1902,14 @@ def fit(self, X, y, sample_weight=None, freq_weight=None, sample_var=None, group
18781902 weighted_sigma = np .matmul (WX .T , WX * var_i [:, [j ]])
18791903 self ._var .append (correction * np .matmul (sigma_inv , np .matmul (weighted_sigma , sigma_inv )))
18801904 elif (self .cov_type == 'clustered' ):
1881- self ._var = self ._compute_clustered_variance_linear (WX , y - np .matmul (X , param ), sigma_inv , groups )
1905+ f_weight = np .sqrt (freq_weight ) if y .ndim < 2 else np .sqrt (freq_weight ).reshape (- 1 , 1 )
1906+ centered_y = y - np .matmul (X , param )
1907+ self ._var = self ._compute_clustered_variance_linear (
1908+ WX ,
1909+ centered_y * f_weight ,
1910+ sigma_inv ,
1911+ groups
1912+ )
18821913 else :
18831914 raise AttributeError ("Unsupported cov_type. Must be one of nonrobust, HC0, HC1, clustered." )
18841915
@@ -1917,11 +1948,6 @@ def aggregate(models: List[StatsModelsLinearRegression]):
19171948
19181949 XX = np .sum ([model .XX for model in models ], axis = 0 )
19191950 Xy = np .sum ([model .Xy for model in models ], axis = 0 )
1920- XXyy = np .sum ([model .XXyy for model in models ], axis = 0 )
1921- XXXy = np .sum ([model .XXXy for model in models ], axis = 0 )
1922- XXXX = np .sum ([model .XXXX for model in models ], axis = 0 )
1923-
1924- sample_var = np .sum ([model .sample_var for model in models ], axis = 0 )
19251951 n_obs = np .sum ([model ._n_obs for model in models ], axis = 0 )
19261952
19271953 sigma_inv = np .linalg .pinv (XX )
@@ -1938,27 +1964,66 @@ def aggregate(models: List[StatsModelsLinearRegression]):
19381964 else : # both HC1 and nonrobust use the same correction factor
19391965 correction = (n_obs / (n_obs - df ))
19401966
1941- if agg_model .cov_type in ['HC0' , 'HC1' ]:
1942- weighted_sigma = XXyy - 2 * np .einsum ('yvwx,vy->ywx' , XXXy , param ) + \
1943- np .einsum ('uvwx,uy,vy->ywx' , XXXX , param , param ) + sample_var
1967+ (agg_model .XX , agg_model .Xy , agg_model ._n_obs ) = (XX , Xy , n_obs )
1968+
1969+ if agg_model .cov_type == 'clustered' :
1970+ TT = np .sum ([m .CL_TT for m in models ], axis = 0 ) # (p, k, k)
1971+ ST = np .sum ([m .CL_ST for m in models ], axis = 0 ) # (p, k, k, k)
1972+ SS = np .sum ([m .CL_SS for m in models ], axis = 0 ) # (k, k, k, k)
1973+ G = int (np .sum ([m ._n_groups for m in models ])) # total clusters
1974+
1975+ (agg_model .CL_TT , agg_model .CL_ST , agg_model .CL_SS , agg_model ._n_groups ) = (TT , ST , SS , G )
1976+
1977+ if G <= 1 :
1978+ warnings .warn ("Number of clusters <= 1. Using biased clustered variance calculation!" )
1979+ group_correction = 1.0
1980+ else :
1981+ group_correction = (G / (G - 1 ))
1982+
1983+ param_T = param .T # (p, k)
1984+ # subtract cross terms of t_g and S_g @ beta
1985+ cross_tmp = np .einsum ('yvwu,yw->yvu' , ST , param_T ) # (p, k, k) with axes (y, v, u)
1986+ cross_left = np .swapaxes (cross_tmp , 1 , 2 ) # (p, k, k) with axes (y, u, v)
1987+ cross_right = np .transpose (cross_left , (0 , 2 , 1 )) # (p, k, k)
1988+ # add quadratic term for (S_g @ beta)(S_g @ beta)^T
1989+ quad = np .einsum ('uvwx,yw,yx->yuv' ,
1990+ np .transpose (SS , (0 , 2 , 1 , 3 )),
1991+ param_T ,
1992+ param_T )
1993+ S = TT - cross_left - cross_right + quad # (p, k, k)
1994+
19441995 if agg_model ._n_out == 0 :
1945- agg_model ._var = correction * np .matmul (sigma_inv , np .matmul (weighted_sigma .squeeze (0 ), sigma_inv ))
1996+ V = group_correction * (sigma_inv @ S .squeeze (0 ) @ sigma_inv )
1997+ agg_model ._var = V
19461998 else :
1947- agg_model ._var = [correction * np .matmul (sigma_inv , np .matmul (ws , sigma_inv )) for ws in weighted_sigma ]
1999+ agg_model ._var = [group_correction * (sigma_inv @ S [j ] @ sigma_inv ) for j in range (S .shape [0 ])]
2000+ agg_model ._param_var = np .array (agg_model ._var )
19482001 else :
1949- assert agg_model .cov_type == 'nonrobust' or agg_model .cov_type is None
1950- sigma = XXyy - 2 * np .einsum ('yx,xy->y' , XXXy , param ) + np .einsum ('wx,wy,xy->y' , XXXX , param , param )
1951- var_i = (sample_var + sigma ) / n_obs
2002+ assert agg_model .cov_type in ['HC0' , 'HC1' , 'nonrobust' , None ]
2003+ XXyy = np .sum ([model .XXyy for model in models ], axis = 0 )
2004+ XXXy = np .sum ([model .XXXy for model in models ], axis = 0 )
2005+ XXXX = np .sum ([model .XXXX for model in models ], axis = 0 )
2006+ sample_var = np .sum ([model .sample_var for model in models ], axis = 0 )
2007+
2008+ (agg_model .sample_var , agg_model .XXyy , agg_model .XXXy , agg_model .XXXX ) = sample_var , XXyy , XXXy , XXXX
2009+
2010+ if agg_model .cov_type in ['HC0' , 'HC1' ]:
2011+ weighted_sigma = XXyy - 2 * np .einsum ('yvwx,vy->ywx' , XXXy , param ) + \
2012+ np .einsum ('uvwx,uy,vy->ywx' , XXXX , param , param ) + sample_var
2013+ matrices = [weighted_sigma .squeeze (0 )] if agg_model ._n_out == 0 else list (weighted_sigma )
2014+ agg_model ._var = [correction * np .matmul (sigma_inv , np .matmul (ws , sigma_inv ))
2015+ for ws in matrices ]
2016+ else : # non-robust
2017+ sigma = XXyy - 2 * np .einsum ('yx,xy->y' , XXXy , param ) + np .einsum ('wx,wy,xy->y' , XXXX , param , param )
2018+ var_i = (sample_var + sigma ) / n_obs
2019+ matrices = [var_i ] if agg_model ._n_out == 0 else list (var_i )
2020+ agg_model ._var = [correction * var * sigma_inv for var in matrices ]
2021+
19522022 if agg_model ._n_out == 0 :
1953- agg_model ._var = correction * var_i * sigma_inv
1954- else :
1955- agg_model ._var = [correction * var * sigma_inv for var in var_i ]
2023+ agg_model ._var = agg_model ._var [0 ]
19562024
19572025 agg_model ._param_var = np .array (agg_model ._var )
19582026
1959- (agg_model .XX , agg_model .Xy , agg_model .XXyy , agg_model .XXXy , agg_model .XXXX ,
1960- agg_model .sample_var , agg_model ._n_obs ) = XX , Xy , XXyy , XXXy , XXXX , sample_var , n_obs
1961-
19622027 return agg_model
19632028
19642029 def _compute_clustered_variance_linear (self , WX , eps_i , sigma_inv , groups ):
0 commit comments