Skip to content

Commit 7576ff7

Browse files
committed
1. Add cov_options to allow toggling both group_correction and df_correction on/off for clustered covariance
2. Set both corrections on by default 3. Add new test for corrections and modify others with the new corrections defaults
1 parent b08b338 commit 7576ff7

File tree

4 files changed

+123
-41
lines changed

4 files changed

+123
-41
lines changed

econml/inference/_inference.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -465,22 +465,28 @@ class StatsModelsInference(LinearModelFinalInference):
465465
----------
466466
cov_type : str, default 'HC1'
467467
The type of covariance estimation method to use. Supported values are 'nonrobust',
468-
'HC0', 'HC1'.
468+
'HC0', 'HC1', 'clustered'.
469+
cov_options : dict, optional
470+
Additional options for covariance estimation. For clustered covariance, supports:
471+
- 'group_correction': bool, default True. Whether to apply N_G/(N_G-1) correction.
472+
- 'df_correction': bool, default True. Whether to apply (N-1)/(N-K) correction.
469473
"""
470474

471-
def __init__(self, cov_type='HC1'):
472-
if cov_type not in ['nonrobust', 'HC0', 'HC1']:
475+
def __init__(self, cov_type='HC1', cov_options=None):
476+
if cov_type not in ['nonrobust', 'HC0', 'HC1', 'clustered']:
473477
raise ValueError("Unsupported cov_type; "
474478
"must be one of 'nonrobust', "
475-
"'HC0', 'HC1'")
479+
"'HC0', 'HC1', 'clustered'")
476480

477481
self.cov_type = cov_type
482+
self.cov_options = cov_options if cov_options is not None else {}
478483

479484
def prefit(self, estimator, *args, **kwargs):
480485
super().prefit(estimator, *args, **kwargs)
481486
assert not (self.model_final.fit_intercept), ("Inference can only be performed on models linear in "
482487
"their features, but here fit_intercept is True")
483488
self.model_final.cov_type = self.cov_type
489+
self.model_final.cov_options = self.cov_options
484490

485491

486492
class GenericModelFinalInferenceDiscrete(Inference):
@@ -660,21 +666,27 @@ class StatsModelsInferenceDiscrete(LinearModelFinalInferenceDiscrete):
660666
----------
661667
cov_type : str, default 'HC1'
662668
The type of covariance estimation method to use. Supported values are 'nonrobust',
663-
'HC0', 'HC1'.
669+
'HC0', 'HC1', 'clustered'.
670+
cov_options : dict, optional
671+
Additional options for covariance estimation. For clustered covariance, supports:
672+
- 'group_correction': bool, default True. Whether to apply N_G/(N_G-1) correction.
673+
- 'df_correction': bool, default True. Whether to apply (N-1)/(N-K) correction.
664674
"""
665675

666-
def __init__(self, cov_type='HC1'):
667-
if cov_type not in ['nonrobust', 'HC0', 'HC1']:
676+
def __init__(self, cov_type='HC1', cov_options=None):
677+
if cov_type not in ['nonrobust', 'HC0', 'HC1', 'clustered']:
668678
raise ValueError("Unsupported cov_type; "
669679
"must be one of 'nonrobust', "
670-
"'HC0', 'HC1'")
680+
"'HC0', 'HC1', 'clustered'")
671681

672682
self.cov_type = cov_type
683+
self.cov_options = cov_options if cov_options is not None else {}
673684

674685
def prefit(self, estimator, *args, **kwargs):
675686
super().prefit(estimator, *args, **kwargs)
676687
# need to set the fit args before the estimator is fit
677688
self.model_final.cov_type = self.cov_type
689+
self.model_final.cov_options = self.cov_options
678690

679691

680692
class InferenceResults(metaclass=abc.ABCMeta):

econml/iv/dml/_dml.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -377,7 +377,8 @@ def __init__(self, *,
377377
mc_agg='mean',
378378
random_state=None,
379379
allow_missing=False,
380-
cov_type="HC0"):
380+
cov_type="HC0",
381+
cov_options=None):
381382
self.model_y_xw = clone(model_y_xw, safe=False)
382383
self.model_t_xw = clone(model_t_xw, safe=False)
383384
self.model_t_xwz = clone(model_t_xwz, safe=False)
@@ -386,6 +387,7 @@ def __init__(self, *,
386387
self.featurizer = clone(featurizer, safe=False)
387388
self.fit_cate_intercept = fit_cate_intercept
388389
self.cov_type = cov_type
390+
self.cov_options = cov_options if cov_options is not None else {}
389391

390392
super().__init__(discrete_outcome=discrete_outcome,
391393
discrete_instrument=discrete_instrument,
@@ -405,7 +407,7 @@ def _gen_featurizer(self):
405407
return clone(self.featurizer, safe=False)
406408

407409
def _gen_model_final(self):
408-
return StatsModels2SLS(cov_type=self.cov_type)
410+
return StatsModels2SLS(cov_type=self.cov_type, cov_options=self.cov_options)
409411

410412
def _gen_ortho_learner_model_final(self):
411413
return _OrthoIVModelFinal(self._gen_model_final(), self._gen_featurizer(), self.fit_cate_intercept)

econml/sklearn_extensions/linear_model.py

Lines changed: 30 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1694,13 +1694,21 @@ class StatsModelsLinearRegression(_StatsModelsWrapper):
16941694
Whether to fit an intercept in this model
16951695
cov_type : string, default "HC0"
16961696
The covariance approach to use. Supported values are "HC0", "HC1", "nonrobust", and "clustered".
1697+
cov_options : dict, optional
1698+
Additional options for covariance estimation. For clustered covariance, supports:
1699+
- 'group_correction': bool, default True. Whether to apply N_G/(N_G-1) correction.
1700+
- 'df_correction': bool, default True. Whether to apply (N-1)/(N-K) correction.
16971701
enable_federation : bool, default False
16981702
Whether to enable federation (aggregating this model's results with other models in a distributed setting).
16991703
This requires additional memory proportional to the number of columns in X to the fourth power.
17001704
"""
17011705

1702-
def __init__(self, fit_intercept=True, cov_type="HC0", *, enable_federation=False):
1706+
def __init__(self, fit_intercept=True, cov_type="HC0", cov_options=None, *, enable_federation=False):
17031707
self.cov_type = cov_type
1708+
self.cov_options = cov_options if cov_options is not None else {}
1709+
if cov_type == 'clustered':
1710+
self.cov_options.setdefault('group_correction', True)
1711+
self.cov_options.setdefault('df_correction', True)
17041712
self.fit_intercept = fit_intercept
17051713
self.enable_federation = enable_federation
17061714

@@ -2050,8 +2058,10 @@ def _compute_clustered_variance_linear(self, WX, eps_i, sigma_inv, groups):
20502058
group_ids, inverse_idx = np.unique(groups, return_inverse=True)
20512059
n_groups = len(group_ids)
20522060

2053-
# Group correction factor
2054-
group_correction = (n_groups / (n_groups - 1))
2061+
# Apply correction factors based on cov_options
2062+
group_correction = (n_groups / (n_groups - 1)) if self.cov_options['group_correction'] else 1.0
2063+
df_correction = ((n - 1) / (n - k)) if self.cov_options['df_correction'] else 1.0
2064+
correction = group_correction * df_correction
20552065

20562066
if eps_i.ndim < 2:
20572067
# Single outcome case
@@ -2060,7 +2070,7 @@ def _compute_clustered_variance_linear(self, WX, eps_i, sigma_inv, groups):
20602070
np.add.at(group_sums, inverse_idx, WX_e)
20612071
s = group_sums.T @ group_sums
20622072

2063-
return group_correction * np.matmul(sigma_inv, np.matmul(s, sigma_inv))
2073+
return correction * np.matmul(sigma_inv, np.matmul(s, sigma_inv))
20642074
else:
20652075
# Multiple outcome case
20662076
var_list = []
@@ -2070,7 +2080,7 @@ def _compute_clustered_variance_linear(self, WX, eps_i, sigma_inv, groups):
20702080
np.add.at(group_sums, inverse_idx, WX_e)
20712081
s = group_sums.T @ group_sums
20722082

2073-
var_list.append(group_correction * np.matmul(sigma_inv, np.matmul(s, sigma_inv)))
2083+
var_list.append(correction * np.matmul(sigma_inv, np.matmul(s, sigma_inv)))
20742084

20752085
return var_list
20762086

@@ -2162,11 +2172,19 @@ class StatsModels2SLS(_StatsModelsWrapper):
21622172
----------
21632173
cov_type : {'HC0', 'HC1', 'nonrobust', 'clustered', or None}, default 'HC0'
21642174
Indicates how the covariance matrix is estimated. 'clustered' requires groups to be provided in fit().
2175+
cov_options : dict, optional
2176+
Additional options for covariance estimation. For clustered covariance, supports:
2177+
- 'group_correction': bool, default True. Whether to apply N_G/(N_G-1) correction.
2178+
- 'df_correction': bool, default True. Whether to apply (N-1)/(N-K) correction.
21652179
"""
21662180

2167-
def __init__(self, cov_type="HC0"):
2181+
def __init__(self, cov_type="HC0", cov_options=None):
21682182
self.fit_intercept = False
21692183
self.cov_type = cov_type
2184+
self.cov_options = cov_options if cov_options is not None else {}
2185+
if cov_type == 'clustered':
2186+
self.cov_options.setdefault('group_correction', True)
2187+
self.cov_options.setdefault('df_correction', True)
21702188
return
21712189

21722190
def _check_input(self, Z, T, y, sample_weight, groups=None):
@@ -2322,8 +2340,10 @@ def _compute_clustered_variance(self, that, eps_i, thatT_that_inv, groups):
23222340
group_ids, inverse_idx = np.unique(groups, return_inverse=True)
23232341
n_groups = len(group_ids)
23242342

2325-
# Group correction factor
2326-
group_correction = (n_groups / (n_groups - 1))
2343+
# Apply correction factors based on cov_options
2344+
group_correction = (n_groups / (n_groups - 1)) if self.cov_options['group_correction'] else 1.0
2345+
df_correction = ((n - 1) / (n - k)) if self.cov_options['df_correction'] else 1.0
2346+
correction = group_correction * df_correction
23272347

23282348
if eps_i.ndim < 2:
23292349
# Single outcome case
@@ -2332,7 +2352,7 @@ def _compute_clustered_variance(self, that, eps_i, thatT_that_inv, groups):
23322352
np.add.at(group_sums, inverse_idx, that_e)
23332353
s = group_sums.T @ group_sums
23342354

2335-
return group_correction * np.matmul(thatT_that_inv, np.matmul(s, thatT_that_inv))
2355+
return correction * np.matmul(thatT_that_inv, np.matmul(s, thatT_that_inv))
23362356
else:
23372357
# Multiple outcome case
23382358
var_list = []
@@ -2342,6 +2362,6 @@ def _compute_clustered_variance(self, that, eps_i, thatT_that_inv, groups):
23422362
np.add.at(group_sums, inverse_idx, that_e)
23432363
s = group_sums.T @ group_sums
23442364

2345-
var_list.append(group_correction * np.matmul(thatT_that_inv, np.matmul(s, thatT_that_inv)))
2365+
var_list.append(correction * np.matmul(thatT_that_inv, np.matmul(s, thatT_that_inv)))
23462366

23472367
return var_list

econml/tests/test_clustered_se.py

Lines changed: 69 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ def test_clustered_se_without_groups_defaults_to_individual(self):
110110
T = np.random.binomial(1, 0.5, n)
111111
Y = np.random.normal(0, 1, n)
112112

113-
# Clustered SE without groups (defaults to individual groups)
113+
# Clustered SE with default corrections (both enabled)
114114
np.random.seed(123)
115115
est_clustered = DML(model_y=LassoCV(), model_t=LogisticRegression(),
116116
model_final=StatsModelsLinearRegression(fit_intercept=False, cov_type='clustered'),
@@ -129,16 +129,13 @@ def test_clustered_se_without_groups_defaults_to_individual(self):
129129
lb_clustered, ub_clustered = est_clustered.effect_interval(X_test, alpha=0.05)
130130
lb_hc0, ub_hc0 = est_hc0.effect_interval(X_test, alpha=0.05)
131131

132-
# Clustered SE should be HC0 SE * sqrt(n/(n-1)) when each obs is its own cluster
133-
# Width of confidence intervals should differ by the adjustment factor
134-
width_clustered = ub_clustered - lb_clustered
135-
width_hc0 = ub_hc0 - lb_hc0
136-
137-
# When each observation is its own cluster, clustered SE should equal HC0 * sqrt(n/(n-1))
138-
# due to the finite sample correction factor
139-
correction_factor = np.sqrt(n / (n - 1))
140-
expected_width = width_hc0 * correction_factor
141-
np.testing.assert_allclose(width_clustered, expected_width, rtol=1e-10)
132+
# With both corrections: sqrt(n/(n-1)) * sqrt((n-1)/(n-k)) = sqrt(n/(n-k))
133+
# Get k from the fitted model (includes treatment variable)
134+
k_params = est_clustered.model_final_.coef_.shape[0]
135+
correction_factor = np.sqrt(n / (n - k_params))
136+
expected_width = (ub_hc0 - lb_hc0) * correction_factor
137+
actual_width = ub_clustered - lb_clustered
138+
np.testing.assert_allclose(actual_width, expected_width, rtol=1e-10)
142139

143140
# Test basic functionality still works
144141
effects = est_clustered.effect(X_test)
@@ -169,15 +166,11 @@ def test_clustered_se_matches_statsmodels(self):
169166
sm_model = sm.OLS(Y, X_with_intercept).fit(cov_type='cluster', cov_kwds={'groups': groups})
170167
sm_se = sm_model.bse[1] # SE for X[:, 0] coefficient
171168

172-
# Account for statsmodels' additional n/(n-k) adjustment
173-
k = X_with_intercept.shape[1] # Number of parameters
174-
sm_adjustment = np.sqrt((n - 1) / (n - k))
175-
adjusted_sm_se = sm_se / sm_adjustment
176-
177-
# Should match very closely
178-
relative_diff = abs(econml_se - adjusted_sm_se) / adjusted_sm_se
169+
# Statsmodels applies both G/(G-1) and (N-1)/(N-K) corrections by default
170+
# Our implementation also applies both by default, so they should match
171+
relative_diff = abs(econml_se - sm_se) / sm_se
179172
self.assertLess(relative_diff, 1e-4,
180-
f"EconML SE ({econml_se:.8f}) differs from adjusted statsmodels SE ({adjusted_sm_se:.8f})")
173+
f"EconML SE ({econml_se:.8f}) differs from statsmodels SE ({sm_se:.8f})")
181174

182175
def test_clustered_micro_equals_aggregated(self):
183176
"""Test that clustered SE matches for summarized and non-summarized data."""
@@ -238,11 +231,14 @@ def _generate_micro_and_aggregated(rng, *, n_groups=12, cells_per_group=6, d=4,
238231
(X, ybar, sw, freq, svar, groups), (X_micro, y_micro, sw_micro, groups_micro) = \
239232
_generate_micro_and_aggregated(rng, n_groups=10, cells_per_group=7, d=5, p=p)
240233

241-
m_agg = StatsModelsLinearRegression(fit_intercept=True, cov_type="clustered", enable_federation=False)
234+
# Disable DF correction since n differs between aggregated and micro datasets
235+
cov_opts = {'group_correction': True, 'df_correction': False}
236+
m_agg = StatsModelsLinearRegression(fit_intercept=True, cov_type="clustered",
237+
cov_options=cov_opts, enable_federation=False)
242238
m_agg.fit(X, ybar, sample_weight=sw, freq_weight=freq, sample_var=svar, groups=groups)
243239

244240
m_micro = StatsModelsLinearRegression(fit_intercept=True, cov_type="clustered",
245-
enable_federation=False)
241+
cov_options=cov_opts, enable_federation=False)
246242
m_micro.fit(
247243
X_micro,
248244
y_micro,
@@ -255,3 +251,55 @@ def _generate_micro_and_aggregated(rng, *, n_groups=12, cells_per_group=6, d=4,
255251
np.testing.assert_allclose(m_agg._param, m_micro._param, rtol=1e-12, atol=1e-12)
256252
np.testing.assert_allclose(np.array(m_agg._param_var), np.array(m_micro._param_var),
257253
rtol=1e-10, atol=1e-12)
254+
255+
def test_clustered_correction_factors(self):
256+
"""Test that correction factors are applied correctly."""
257+
np.random.seed(42)
258+
n = 200
259+
n_groups = 20
260+
X = np.random.randn(n, 3)
261+
groups = np.repeat(np.arange(n_groups), n // n_groups)
262+
y = X[:, 0] + 0.5 * X[:, 1] + np.random.randn(n) * 0.5
263+
264+
# Fit models with different correction options
265+
m_none = StatsModelsLinearRegression(
266+
cov_type='clustered',
267+
cov_options={'group_correction': False, 'df_correction': False}
268+
).fit(X, y, groups=groups)
269+
270+
m_group = StatsModelsLinearRegression(
271+
cov_type='clustered',
272+
cov_options={'group_correction': True, 'df_correction': False}
273+
).fit(X, y, groups=groups)
274+
275+
m_df = StatsModelsLinearRegression(
276+
cov_type='clustered',
277+
cov_options={'group_correction': False, 'df_correction': True}
278+
).fit(X, y, groups=groups)
279+
280+
m_both = StatsModelsLinearRegression(
281+
cov_type='clustered',
282+
cov_options={'group_correction': True, 'df_correction': True}
283+
).fit(X, y, groups=groups)
284+
285+
# Get actual number of parameters
286+
k_params = len(m_none.coef_) + 1
287+
288+
# Verify group correction
289+
group_ratio = m_group.coef_stderr_ / m_none.coef_stderr_
290+
expected_group_ratio = np.sqrt(n_groups / (n_groups - 1))
291+
np.testing.assert_allclose(group_ratio, expected_group_ratio, rtol=1e-10)
292+
293+
# Verify DF correction
294+
df_ratio = m_df.coef_stderr_ / m_none.coef_stderr_
295+
expected_df_ratio = np.sqrt((n - 1) / (n - k_params))
296+
np.testing.assert_allclose(df_ratio, expected_df_ratio, rtol=1e-10)
297+
298+
# Verify combined correction
299+
combined_ratio = m_both.coef_stderr_ / m_none.coef_stderr_
300+
expected_combined_ratio = np.sqrt(n_groups / (n_groups - 1) * (n - 1) / (n - k_params))
301+
np.testing.assert_allclose(combined_ratio, expected_combined_ratio, rtol=1e-10)
302+
303+
# Verify multiplicative property
304+
both_from_components = m_group.coef_stderr_ * m_df.coef_stderr_ / m_none.coef_stderr_
305+
np.testing.assert_allclose(m_both.coef_stderr_, both_from_components, rtol=1e-10)

0 commit comments

Comments
 (0)