Commit b08b338

1. Add support for federated learning in clustered SE computation; add corresponding test
2. Fix clustered SE computation issue with summarized data; add corresponding test

Signed-off-by: Mikayel Sukiasyan <[email protected]>
1 parent a04f6a5 commit b08b338
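The workflow this enables, roughly: each data owner fits a StatsModelsLinearRegression with cov_type='clustered' and enable_federation=True on its own rows, and the fitted models are combined with aggregate, which only exchanges the summed cluster moments added in this commit. A minimal sketch under those assumptions; the toy data, the two-site split, and the accessor choices are illustrative, not taken from the commit:

import numpy as np
from econml.sklearn_extensions.linear_model import StatsModelsLinearRegression

# Toy data: 20 clusters of 10 rows each, split across two "sites" so no cluster spans sites.
rng = np.random.default_rng(0)
X = rng.normal(size=(200, 3))
y = X @ np.array([1.0, -2.0, 0.5]) + rng.normal(size=200)
groups = np.repeat(np.arange(20), 10)

models = []
for sl in [slice(0, 100), slice(100, 200)]:
    m = StatsModelsLinearRegression(cov_type='clustered', enable_federation=True)
    m.fit(X[sl], y[sl], groups=groups[sl])   # each site fits on its own rows only
    models.append(m)

# Combine the per-site sufficient statistics; no row-level data is shared.
agg = StatsModelsLinearRegression.aggregate(models)
print(agg.coef_)         # pooled point estimates
print(agg._param_var)    # clustered variance assembled from CL_TT / CL_ST / CL_SS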

2 files changed: +163 −21 lines

econml/sklearn_extensions/linear_model.py

Lines changed: 86 additions & 21 deletions
@@ -1845,7 +1845,31 @@ def fit(self, X, y, sample_weight=None, freq_weight=None, sample_var=None, group
                 self.XXXX = np.einsum('nw,nx->wx', WX, WX)
                 self.sample_var = np.average(sv, weights=freq_weight, axis=0) * n_obs
             elif self.cov_type == 'clustered':
-                raise AttributeError("Clustered standard errors are not supported with federation enabled.")
+                group_ids, inverse_idx = np.unique(groups, return_inverse=True)
+                n_groups = len(group_ids)
+                k = WX.shape[1]
+
+                S_local = np.einsum('ni,nj->nij', WX, X)  # (N, k, k)
+                S_flat = S_local.reshape(S_local.shape[0], -1)  # (N, k*k)
+                group_S_flat = np.zeros((n_groups, k * k))
+                np.add.at(group_S_flat, inverse_idx, S_flat)
+                group_S = group_S_flat.reshape(n_groups, k, k)  # (G, k, k)
+
+                y2d = y.reshape(-1, 1) if y.ndim < 2 else y  # (N, p)
+                TY_local = y2d[:, :, None] * WX[:, None, :]  # (N, p, k)
+                TY_flat = TY_local.reshape(TY_local.shape[0], -1)  # (N, p*k)
+                group_T_flat = np.zeros((n_groups, y2d.shape[1] * k))
+                np.add.at(group_T_flat, inverse_idx, TY_flat)
+                group_t = group_T_flat.reshape(n_groups, y2d.shape[1], k).transpose(1, 0, 2)  # (p, G, k)
+
+                TT = np.einsum('ygk,ygl->ykl', group_t, group_t)  # (p, k, k)
+                ST = np.einsum('gvw,ygx->yvwx', group_S, group_t)  # (p, k, k, k)
+                SS = np.einsum('gvu,gwx->vuwx', group_S, group_S)  # (k, k, k, k)
+
+                self.CL_TT = TT
+                self.CL_ST = ST
+                self.CL_SS = SS
+                self._n_groups = n_groups
 
         sigma_inv = np.linalg.pinv(self.XX)
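What the new branch stores per fitted model, in words: with S_g the sum over rows i in cluster g of WX_i x_i^T and t_g the sum over i in g of y_i WX_i, the attributes CL_TT, CL_ST and CL_SS hold sum_g t_g t_g^T, sum_g S_g outer t_g and sum_g S_g outer S_g. Those three moments are enough to rebuild the clustered "meat" matrix sum_g (t_g - S_g b)(t_g - S_g b)^T for any coefficient vector b, which is what aggregate does further down. A standalone check of that identity for a single output (not part of the commit; every name below is local to the sketch):

import numpy as np

rng = np.random.default_rng(0)
G, k = 5, 3
S_g = rng.normal(size=(G, k, k))   # per-cluster sum of WX_i x_i^T
t_g = rng.normal(size=(G, k))      # per-cluster sum of y_i * WX_i
beta = rng.normal(size=k)

# Direct computation of the clustered meat matrix
u = t_g - np.einsum('gvw,w->gv', S_g, beta)      # per-cluster score t_g - S_g @ beta
meat_direct = np.einsum('gv,gw->vw', u, u)

# Rebuilt from pooled moments only, mirroring what aggregate() does across models
TT = np.einsum('gk,gl->kl', t_g, t_g)
ST = np.einsum('gvw,gx->vwx', S_g, t_g)
SS = np.einsum('gvu,gwx->vuwx', S_g, S_g)
cross = np.einsum('vwu,w->vu', ST, beta)             # sum_g (S_g @ beta) t_g^T
quad = np.einsum('uwvx,w,x->uv', SS, beta, beta)     # sum_g (S_g @ beta)(S_g @ beta)^T
meat_pooled = TT - cross.T - cross + quad

assert np.allclose(meat_direct, meat_pooled)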

@@ -1878,7 +1902,14 @@ def fit(self, X, y, sample_weight=None, freq_weight=None, sample_var=None, group
                     weighted_sigma = np.matmul(WX.T, WX * var_i[:, [j]])
                     self._var.append(correction * np.matmul(sigma_inv, np.matmul(weighted_sigma, sigma_inv)))
             elif (self.cov_type == 'clustered'):
-                self._var = self._compute_clustered_variance_linear(WX, y - np.matmul(X, param), sigma_inv, groups)
+                f_weight = np.sqrt(freq_weight) if y.ndim < 2 else np.sqrt(freq_weight).reshape(-1, 1)
+                centered_y = y - np.matmul(X, param)
+                self._var = self._compute_clustered_variance_linear(
+                    WX,
+                    centered_y * f_weight,
+                    sigma_inv,
+                    groups
+                )
             else:
                 raise AttributeError("Unsupported cov_type. Must be one of nonrobust, HC0, HC1, clustered.")
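This hunk is the summarized-data fix from the commit message: WX appears to already carry the square root of the combined sample and freq weight, so the residual passed to _compute_clustered_variance_linear must pick up sqrt(freq_weight) as well, otherwise a summarized row enters its cluster score with weight sqrt(f) instead of f. That reading is an assumption from the surrounding code, not spelled out in the diff. A tiny standalone check of the bookkeeping with made-up numbers:

import numpy as np

x = np.array([1.0, 2.0, -0.5])   # one summarized design row
e = 0.7                          # its residual y - x @ beta
f = 4                            # freq_weight: the row stands for 4 identical micro rows

micro_score = sum(x * e for _ in range(f))    # the 4 micro rows contribute 4 * x * e
wx = np.sqrt(f) * x                           # WX row, assuming it already carries sqrt(f)
summarized_score = wx * (np.sqrt(f) * e)      # residual scaled by sqrt(f), as in the fix
assert np.allclose(micro_score, summarized_score)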

@@ -1917,11 +1948,6 @@ def aggregate(models: List[StatsModelsLinearRegression]):
 
         XX = np.sum([model.XX for model in models], axis=0)
         Xy = np.sum([model.Xy for model in models], axis=0)
-        XXyy = np.sum([model.XXyy for model in models], axis=0)
-        XXXy = np.sum([model.XXXy for model in models], axis=0)
-        XXXX = np.sum([model.XXXX for model in models], axis=0)
-
-        sample_var = np.sum([model.sample_var for model in models], axis=0)
         n_obs = np.sum([model._n_obs for model in models], axis=0)
 
         sigma_inv = np.linalg.pinv(XX)
@@ -1938,27 +1964,66 @@ def aggregate(models: List[StatsModelsLinearRegression]):
         else:  # both HC1 and nonrobust use the same correction factor
             correction = (n_obs / (n_obs - df))
 
-        if agg_model.cov_type in ['HC0', 'HC1']:
-            weighted_sigma = XXyy - 2 * np.einsum('yvwx,vy->ywx', XXXy, param) + \
-                np.einsum('uvwx,uy,vy->ywx', XXXX, param, param) + sample_var
+        (agg_model.XX, agg_model.Xy, agg_model._n_obs) = (XX, Xy, n_obs)
+
+        if agg_model.cov_type == 'clustered':
+            TT = np.sum([m.CL_TT for m in models], axis=0)  # (p, k, k)
+            ST = np.sum([m.CL_ST for m in models], axis=0)  # (p, k, k, k)
+            SS = np.sum([m.CL_SS for m in models], axis=0)  # (k, k, k, k)
+            G = int(np.sum([m._n_groups for m in models]))  # total clusters
+
+            (agg_model.CL_TT, agg_model.CL_ST, agg_model.CL_SS, agg_model._n_groups) = (TT, ST, SS, G)
+
+            if G <= 1:
+                warnings.warn("Number of clusters <= 1. Using biased clustered variance calculation!")
+                group_correction = 1.0
+            else:
+                group_correction = (G / (G - 1))
+
+            param_T = param.T  # (p, k)
+            # subtract cross terms of t_g and S_g @ beta
+            cross_tmp = np.einsum('yvwu,yw->yvu', ST, param_T)  # (p, k, k) with axes (y, v, u)
+            cross_left = np.swapaxes(cross_tmp, 1, 2)  # (p, k, k) with axes (y, u, v)
+            cross_right = np.transpose(cross_left, (0, 2, 1))  # (p, k, k)
+            # add quadratic term for (S_g @ beta)(S_g @ beta)^T
+            quad = np.einsum('uvwx,yw,yx->yuv',
+                             np.transpose(SS, (0, 2, 1, 3)),
+                             param_T,
+                             param_T)
+            S = TT - cross_left - cross_right + quad  # (p, k, k)
+
             if agg_model._n_out == 0:
-                agg_model._var = correction * np.matmul(sigma_inv, np.matmul(weighted_sigma.squeeze(0), sigma_inv))
+                V = group_correction * (sigma_inv @ S.squeeze(0) @ sigma_inv)
+                agg_model._var = V
             else:
-                agg_model._var = [correction * np.matmul(sigma_inv, np.matmul(ws, sigma_inv)) for ws in weighted_sigma]
+                agg_model._var = [group_correction * (sigma_inv @ S[j] @ sigma_inv) for j in range(S.shape[0])]
+            agg_model._param_var = np.array(agg_model._var)
         else:
-            assert agg_model.cov_type == 'nonrobust' or agg_model.cov_type is None
-            sigma = XXyy - 2 * np.einsum('yx,xy->y', XXXy, param) + np.einsum('wx,wy,xy->y', XXXX, param, param)
-            var_i = (sample_var + sigma) / n_obs
+            assert agg_model.cov_type in ['HC0', 'HC1', 'nonrobust', None]
+            XXyy = np.sum([model.XXyy for model in models], axis=0)
+            XXXy = np.sum([model.XXXy for model in models], axis=0)
+            XXXX = np.sum([model.XXXX for model in models], axis=0)
+            sample_var = np.sum([model.sample_var for model in models], axis=0)
+
+            (agg_model.sample_var, agg_model.XXyy, agg_model.XXXy, agg_model.XXXX) = sample_var, XXyy, XXXy, XXXX
+
+            if agg_model.cov_type in ['HC0', 'HC1']:
+                weighted_sigma = XXyy - 2 * np.einsum('yvwx,vy->ywx', XXXy, param) + \
+                    np.einsum('uvwx,uy,vy->ywx', XXXX, param, param) + sample_var
+                matrices = [weighted_sigma.squeeze(0)] if agg_model._n_out == 0 else list(weighted_sigma)
+                agg_model._var = [correction * np.matmul(sigma_inv, np.matmul(ws, sigma_inv))
+                                  for ws in matrices]
+            else:  # non-robust
+                sigma = XXyy - 2 * np.einsum('yx,xy->y', XXXy, param) + np.einsum('wx,wy,xy->y', XXXX, param, param)
+                var_i = (sample_var + sigma) / n_obs
+                matrices = [var_i] if agg_model._n_out == 0 else list(var_i)
+                agg_model._var = [correction * var * sigma_inv for var in matrices]
+
             if agg_model._n_out == 0:
-                agg_model._var = correction * var_i * sigma_inv
-            else:
-                agg_model._var = [correction * var * sigma_inv for var in var_i]
+                agg_model._var = agg_model._var[0]
 
             agg_model._param_var = np.array(agg_model._var)
 
-        (agg_model.XX, agg_model.Xy, agg_model.XXyy, agg_model.XXXy, agg_model.XXXX,
-         agg_model.sample_var, agg_model._n_obs) = XX, Xy, XXyy, XXXy, XXXX, sample_var, n_obs
-
         return agg_model
 
     def _compute_clustered_variance_linear(self, WX, eps_i, sigma_inv, groups):
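Reading the clustered branch of aggregate above as a formula: with $\hat\Sigma = \sum_m \mathrm{XX}_m$ the pooled design moment, $\hat\beta_y$ the pooled coefficients for output $y$, and $S_g$, $t_g^{(y)}$ the per-cluster sums stored by fit, the assigned variance is

$$\hat V_y = \frac{G}{G-1}\,\hat\Sigma^{-1}\Big[\sum_{g=1}^{G}\big(t_g^{(y)} - S_g\hat\beta_y\big)\big(t_g^{(y)} - S_g\hat\beta_y\big)^{\top}\Big]\hat\Sigma^{-1},$$

where the bracketed sum expands to exactly TT - cross_left - cross_right + quad, so it can be assembled from each model's CL_TT, CL_ST and CL_SS without any row-level data. When only one cluster is present, the G/(G-1) small-cluster correction is skipped and a warning is emitted, matching the group_correction logic.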

econml/tests/test_clustered_se.py

Lines changed: 77 additions & 0 deletions
@@ -178,3 +178,80 @@ def test_clustered_se_matches_statsmodels(self):
         relative_diff = abs(econml_se - adjusted_sm_se) / adjusted_sm_se
         self.assertLess(relative_diff, 1e-4,
                         f"EconML SE ({econml_se:.8f}) differs from adjusted statsmodels SE ({adjusted_sm_se:.8f})")
+
+    def test_clustered_micro_equals_aggregated(self):
+        """Test that clustered SE matches for summarized and non-summarized data."""
+
+        def _generate_micro_and_aggregated(rng, *, n_groups=12, cells_per_group=6, d=4, p=1):
+            """Build a micro dataset and aggregated counterpart with many freq > 1."""
+            G = n_groups
+            K = cells_per_group
+            N = G * K
+
+            # Design
+            X = rng.normal(size=(N, d))
+            # True coefficients used just to generate data; intercept will be fit by the model
+            beta_true = rng.normal(size=(d + 1, p))
+
+            # Positive sample weights and integer freq weights with many freq > 1
+            sw = np.exp(rng.normal(scale=0.3, size=N))
+            freq = rng.integers(1, 6, size=N)  # values in {1,2,3,4,5}
+
+            # Group labels
+            groups = np.repeat(np.arange(G), K)
+
+            # Build micro outcomes y_{ij}
+            ybar = np.zeros((N, p), dtype=float)
+            svar = np.zeros((N, p), dtype=float)
+
+            X_micro, y_micro, sw_micro, groups_micro = [], [], [], []
+
+            for i in range(N):
+                f = int(freq[i])
+                x_i = X[i]
+                mu_i = np.concatenate(([1.0], x_i)) @ beta_true  # shape (p,)
+                eps = rng.normal(scale=1.0, size=(f, p))
+                y_ij = mu_i + eps  # shape (f, p)
+
+                X_micro.append(np.repeat(x_i[None, :], f, axis=0))
+                y_micro.append(y_ij)
+                sw_micro.append(np.repeat(sw[i], f))
+                groups_micro.append(np.repeat(groups[i], f))
+
+                ybar[i, :] = y_ij.mean(axis=0)
+                svar[i, :] = y_ij.var(axis=0, ddof=0)
+
+            X_micro = np.vstack(X_micro)
+            y_micro = np.vstack(y_micro)
+            sw_micro = np.concatenate(sw_micro)
+            groups_micro = np.concatenate(groups_micro)
+
+            if p == 1:
+                ybar = ybar.ravel()
+                svar = svar.ravel()
+                y_micro = y_micro.ravel()
+
+            return (X, ybar, sw, freq, svar, groups), (X_micro, y_micro, sw_micro, groups_micro)
+
+        rng = np.random.default_rng(7)
+        for p in [1, 3]:
+            (X, ybar, sw, freq, svar, groups), (X_micro, y_micro, sw_micro, groups_micro) = \
+                _generate_micro_and_aggregated(rng, n_groups=10, cells_per_group=7, d=5, p=p)
+
+            m_agg = StatsModelsLinearRegression(fit_intercept=True, cov_type="clustered", enable_federation=False)
+            m_agg.fit(X, ybar, sample_weight=sw, freq_weight=freq, sample_var=svar, groups=groups)
+
+            m_micro = StatsModelsLinearRegression(fit_intercept=True, cov_type="clustered",
+                                                  enable_federation=False)
+            m_micro.fit(
+                X_micro,
+                y_micro,
+                sample_weight=sw_micro,
+                freq_weight=None,
+                sample_var=None,
+                groups=groups_micro
+            )
+
+            np.testing.assert_allclose(m_agg._param, m_micro._param, rtol=1e-12, atol=1e-12)
+            np.testing.assert_allclose(np.array(m_agg._param_var), np.array(m_micro._param_var),
+                                       rtol=1e-10, atol=1e-12)
