
Commit ce2f2b5

some small fixes to the debiased lasso (#358)
* some small fixes to the debiased lasso
* added parallelism across rows of the design matrix so each LassoCV runs in parallel; added n_jobs param to DebiasedLasso and to SparseLinearDML
* added n_jobs to MultiOutputDebiasedLasso
* added separate alpha options for the covariance matrix estimation
* added the extra alpha options to SparseLinearDRLearner and SparseLinearDML
1 parent bb042d5 commit ce2f2b5

5 files changed: +218 -128 lines changed

econml/dml/dml.py

Lines changed: 20 additions & 0 deletions
@@ -653,6 +653,18 @@ class SparseLinearDML(DebiasedLassoCateEstimatorMixin, DML):
         CATE L1 regularization applied through the debiased lasso in the final model.
         'auto' corresponds to a CV form of the :class:`MultiOutputDebiasedLasso`.
 
+    n_alphas : int, optional, default 100
+        How many alphas to try if alpha='auto'.
+
+    alpha_cov : string | float, optional, default 'auto'
+        The regularization alpha used when constructing the pseudo inverse of
+        the covariance matrix Theta, used for correcting the final stage lasso coefficients
+        in the debiased lasso. Each such regression corresponds to the regression of one feature
+        on the remainder of the features.
+
+    n_alphas_cov : int, optional, default 10
+        How many alpha_cov to try if alpha_cov='auto'.
+
     max_iter : int, optional, default=1000
         The maximum number of iterations in the Debiased Lasso
 
@@ -707,8 +719,12 @@ class SparseLinearDML(DebiasedLassoCateEstimatorMixin, DML):
     def __init__(self,
                  model_y='auto', model_t='auto',
                  alpha='auto',
+                 n_alphas=100,
+                 alpha_cov='auto',
+                 n_alphas_cov=10,
                  max_iter=1000,
                  tol=1e-4,
+                 n_jobs=None,
                  featurizer=None,
                  fit_cate_intercept=True,
                  linear_first_stages=True,
@@ -718,9 +734,13 @@ def __init__(self,
                  random_state=None):
         model_final = MultiOutputDebiasedLasso(
             alpha=alpha,
+            n_alphas=n_alphas,
+            alpha_cov=alpha_cov,
+            n_alphas_cov=n_alphas_cov,
             fit_intercept=False,
             max_iter=max_iter,
             tol=tol,
+            n_jobs=n_jobs,
             random_state=random_state)
         super().__init__(model_y=model_y,
                          model_t=model_t,
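
For context, a minimal sketch of how the new knobs surface on the estimator; the synthetic data and parameter values are illustrative only:

    import numpy as np
    from econml.dml import SparseLinearDML

    # Synthetic data, purely for illustration
    X = np.random.normal(size=(500, 10))
    T = np.random.normal(size=(500,))
    y = 2 * T * X[:, 0] + np.random.normal(size=(500,))

    # n_alphas tunes the CV grid of the final-stage debiased lasso;
    # alpha_cov/n_alphas_cov tune the nodewise covariance regressions;
    # n_jobs parallelizes those regressions via joblib.
    est = SparseLinearDML(alpha='auto', n_alphas=100,
                          alpha_cov='auto', n_alphas_cov=10,
                          n_jobs=-1, random_state=123)
    est.fit(y, T, X=X)
    print(est.effect(X[:3]))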

econml/drlearner.py

Lines changed: 28 additions & 8 deletions
@@ -853,6 +853,18 @@ class SparseLinearDRLearner(DebiasedLassoCateEstimatorDiscreteMixin, DRLearner):
         CATE L1 regularization applied through the debiased lasso in the final model.
         'auto' corresponds to a CV form of the :class:`DebiasedLasso`.
 
+    n_alphas : int, optional, default 100
+        How many alphas to try if alpha='auto'.
+
+    alpha_cov : string | float, optional, default 'auto'
+        The regularization alpha used when constructing the pseudo inverse of
+        the covariance matrix Theta, used for correcting the final stage lasso coefficients
+        in the debiased lasso. Each such regression corresponds to the regression of one feature
+        on the remainder of the features.
+
+    n_alphas_cov : int, optional, default 10
+        How many alpha_cov to try if alpha_cov='auto'.
+
     max_iter : int, optional, default 1000
         The maximum number of iterations in the Debiased Lasso
 
@@ -910,17 +922,17 @@ class SparseLinearDRLearner(DebiasedLassoCateEstimatorDiscreteMixin, DRLearner):
     est.fit(y, T, X=X, W=None)
 
     >>> est.effect(X[:3])
-    array([ 0.418400..., 0.306400..., -0.130733...])
+    array([ 0.41..., 0.31..., -0.12...])
     >>> est.effect_interval(X[:3])
-    (array([ 0.056783..., -0.206438..., -0.739296...]), array([0.780017..., 0.819239..., 0.477828...]))
+    (array([ 0.04..., -0.19..., -0.73...]), array([0.77..., 0.82..., 0.47...]))
     >>> est.coef_(T=1)
-    array([0.449779..., 0.004807..., 0.061954...])
+    array([ 0.45..., -0.00..., 0.06...])
     >>> est.coef__interval(T=1)
-    (array([ 0.242194..., -0.190825..., -0.139646...]), array([0.657365..., 0.200440..., 0.263556...]))
+    (array([ 0.24..., -0.19..., -0.13...]), array([0.65..., 0.19..., 0.26...]))
     >>> est.intercept_(T=1)
-    0.88436847...
+    0.88...
     >>> est.intercept__interval(T=1)
-    (0.68683788..., 1.08189907...)
+    (0.68..., 1.08...)
 
     Attributes
     ----------
@@ -942,17 +954,25 @@ def __init__(self,
                  featurizer=None,
                  fit_cate_intercept=True,
                  alpha='auto',
+                 n_alphas=100,
+                 alpha_cov='auto',
+                 n_alphas_cov=10,
                  max_iter=1000,
                  tol=1e-4,
                  min_propensity=1e-6,
                  categories='auto',
-                 n_splits=2, random_state=None):
+                 n_splits=2,
+                 random_state=None):
         self.fit_cate_intercept = fit_cate_intercept
         model_final = DebiasedLasso(
             alpha=alpha,
+            n_alphas=n_alphas,
+            alpha_cov=alpha_cov,
+            n_alphas_cov=n_alphas_cov,
             fit_intercept=fit_cate_intercept,
             max_iter=max_iter,
-            tol=tol)
+            tol=tol,
+            random_state=random_state)
         super().__init__(model_propensity=model_propensity,
                          model_regression=model_regression,
                          model_final=model_final,
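
A similar hedged usage sketch for the DR learner (synthetic data; values illustrative). Note that random_state now also reaches the final-stage DebiasedLasso, which is what makes runs like the doctest above reproducible:

    import numpy as np
    from econml.drlearner import SparseLinearDRLearner

    X = np.random.normal(size=(1000, 3))
    T = np.random.binomial(1, 0.5, size=(1000,))  # binary treatment
    y = (1 + 0.5 * X[:, 0]) * T + np.random.normal(size=(1000,))

    est = SparseLinearDRLearner(alpha='auto', n_alphas=50,
                                alpha_cov='auto', n_alphas_cov=5,
                                random_state=123)
    est.fit(y, T, X=X, W=None)
    print(est.effect(X[:3]))
    print(est.coef_(T=1), est.intercept_(T=1))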

econml/sklearn_extensions/linear_model.py

Lines changed: 89 additions & 40 deletions
@@ -35,6 +35,7 @@
 from statsmodels.tools.tools import add_constant
 from statsmodels.api import RLM
 import statsmodels
+from joblib import Parallel, delayed
 
 
 def _weighted_check_cv(cv=5, y=None, classifier=False):
@@ -539,6 +540,34 @@ def fit(self, X, y, sample_weight=None):
         return self
 
 
+def _get_theta_coefs_and_tau_sq(i, X, sample_weight, alpha_cov, n_alphas_cov, max_iter, tol, random_state):
+    n_samples, n_features = X.shape
+    y = X[:, i]
+    X_reduced = X[:, list(range(i)) + list(range(i + 1, n_features))]
+    # Call weighted lasso on reduced design matrix
+    if alpha_cov == 'auto':
+        local_wlasso = WeightedLassoCV(cv=3, n_alphas=n_alphas_cov,
+                                       fit_intercept=False,
+                                       max_iter=max_iter,
+                                       tol=tol, n_jobs=1,
+                                       random_state=random_state)
+    else:
+        local_wlasso = WeightedLasso(alpha=alpha_cov,
+                                     fit_intercept=False,
+                                     max_iter=max_iter,
+                                     tol=tol,
+                                     random_state=random_state)
+    local_wlasso.fit(X_reduced, y, sample_weight=sample_weight)
+    coefs = local_wlasso.coef_
+    # Weighted tau
+    if sample_weight is not None:
+        y_weighted = y * sample_weight / np.sum(sample_weight)
+    else:
+        y_weighted = y / n_samples
+    tausq = np.dot(y - local_wlasso.predict(X_reduced), y_weighted)
+    return coefs, tausq
+
+
 class DebiasedLasso(WeightedLasso):
     """Debiased Lasso model.
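
What each parallel task computes: the i-th nodewise regression of the debiased lasso regresses column i of the design matrix on the remaining columns; its coefficients fill row i of C_hat and the weighted residual inner product gives tau_i^2. A standalone sketch of the unweighted case, using plain scikit-learn Lasso as a stand-in for the weighted variants:

    import numpy as np
    from sklearn.linear_model import Lasso

    X = np.random.normal(size=(200, 5))
    i = 0
    y = X[:, i]
    X_reduced = np.delete(X, i, axis=1)  # drop column i

    # Stand-in for WeightedLasso/WeightedLassoCV when sample_weight is None
    m = Lasso(alpha=0.1, fit_intercept=False).fit(X_reduced, y)
    coefs = m.coef_                                        # row i of C_hat (negated off-diagonal)
    tausq = np.dot(y - m.predict(X_reduced), y) / len(y)   # tau_i^2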
@@ -555,6 +584,18 @@ class DebiasedLasso(WeightedLasso):
         reasons, using ``alpha = 0`` with the ``Lasso`` object is not advised.
         Given this, you should use the :class:`.LinearRegression` object.
 
+    n_alphas : int, optional, default 100
+        How many alphas to try if alpha='auto'.
+
+    alpha_cov : string | float, optional, default 'auto'
+        The regularization alpha used when constructing the pseudo inverse of
+        the covariance matrix Theta, used for correcting the lasso coefficients. Each
+        such regression corresponds to the regression of one feature on the remainder
+        of the features.
+
+    n_alphas_cov : int, optional, default 10
+        How many alpha_cov to try if alpha_cov='auto'.
+
     fit_intercept : boolean, optional, default True
         Whether to calculate the intercept for this model. If set
         to False, no intercept will be used in calculations
@@ -597,6 +638,9 @@ class DebiasedLasso(WeightedLasso):
         (setting to 'random') often leads to significantly faster convergence,
         especially when tol is higher than 1e-4.
 
+    n_jobs : int or None, default None
+        How many jobs to use whenever parallelism is invoked.
+
     Attributes
     ----------
     coef_ : array, shape (n_features,)
@@ -620,10 +664,14 @@ class DebiasedLasso(WeightedLasso):
     """
 
-    def __init__(self, alpha='auto', fit_intercept=True,
-                 precompute=False, copy_X=True, max_iter=1000,
+    def __init__(self, alpha='auto', n_alphas=100, alpha_cov='auto', n_alphas_cov=10,
+                 fit_intercept=True, precompute=False, copy_X=True, max_iter=1000,
                  tol=1e-4, warm_start=False,
-                 random_state=None, selection='cyclic'):
+                 random_state=None, selection='cyclic', n_jobs=None):
+        self.n_jobs = n_jobs
+        self.n_alphas = n_alphas
+        self.alpha_cov = alpha_cov
+        self.n_alphas_cov = n_alphas_cov
         super().__init__(
             alpha=alpha, fit_intercept=fit_intercept,
             precompute=precompute, copy_X=copy_X,
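
A minimal sketch of the updated constructor in use (synthetic data; the interval call relies on predict_interval as refactored below, which I assume returns a (lower, upper) pair):

    import numpy as np
    from econml.sklearn_extensions.linear_model import DebiasedLasso

    X = np.random.normal(size=(300, 20))
    y = X[:, 0] - 2 * X[:, 3] + np.random.normal(size=(300,))

    # alpha/n_alphas drive the main lasso CV; alpha_cov/n_alphas_cov drive
    # the nodewise covariance regressions; n_jobs fans the latter out via joblib.
    dl = DebiasedLasso(alpha='auto', n_alphas=100,
                       alpha_cov='auto', n_alphas_cov=10,
                       n_jobs=-1, random_state=0)
    dl.fit(X, y)
    y_pred = dl.predict(X[:5])
    y_lower, y_upper = dl.predict_interval(X[:5], alpha=0.1)  # 90% intervals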
@@ -747,18 +795,8 @@ def predict_interval(self, X, alpha=0.1):
         lower = alpha / 2
         upper = 1 - alpha / 2
         y_pred = self.predict(X)
-        y_lower = np.empty(y_pred.shape)
-        y_upper = np.empty(y_pred.shape)
-        # Note that in the case of no intercept, X_offset is 0
-        if self.fit_intercept:
-            X = X - self._X_offset
-        # Calculate the variance of the predictions
-        var_pred = np.sum(np.matmul(X, self._coef_variance) * X, axis=1)
-        if self.fit_intercept:
-            var_pred += self._mean_error_variance
-
         # Calculate prediction confidence intervals
-        sd_pred = np.sqrt(var_pred)
+        sd_pred = self.prediction_stderr(X)
         y_lower = y_pred + \
             np.apply_along_axis(lambda s: norm.ppf(
                 lower, scale=s), 0, sd_pred)
@@ -810,20 +848,25 @@ def intercept__interval(self, alpha=0.1):
 
     def _get_coef_correction(self, X, y, y_pred, sample_weight, theta_hat):
         # Assumes flattened y
-        n_samples, n_features = X.shape
+        n_samples, _ = X.shape
         y_res = np.ndarray.flatten(y) - y_pred
         # Compute weighted residuals
         if sample_weight is not None:
             y_res_scaled = y_res * sample_weight / np.sum(sample_weight)
         else:
             y_res_scaled = y_res / n_samples
         delta_coef = np.matmul(
-            np.matmul(theta_hat, X.T), y_res_scaled)
+            theta_hat, np.matmul(X.T, y_res_scaled))
         return delta_coef
 
     def _get_optimal_alpha(self, X, y, sample_weight):
         # To be done once per target. Assumes y can be flattened.
-        cv_estimator = WeightedLassoCV(cv=5, fit_intercept=self.fit_intercept)
+        cv_estimator = WeightedLassoCV(cv=5, n_alphas=self.n_alphas, fit_intercept=self.fit_intercept,
+                                       precompute=self.precompute, copy_X=True,
+                                       max_iter=self.max_iter, tol=self.tol,
+                                       random_state=self.random_state,
+                                       selection=self.selection,
+                                       n_jobs=self.n_jobs)
         cv_estimator.fit(X, y.flatten(), sample_weight=sample_weight)
         return cv_estimator.alpha_
@@ -835,27 +878,15 @@ def _get_theta_hat(self, X, sample_weight):
             C_hat = np.ones((1, 1))
             tausq = (X.T @ X / n_samples).flatten()
             return np.diag(1 / tausq) @ C_hat
-        coefs = np.empty((n_features, n_features - 1))
-        tausq = np.empty(n_features)
         # Compute Lasso coefficients for the columns of the design matrix
-        for i in range(n_features):
-            y = X[:, i]
-            X_reduced = X[:, list(range(i)) + list(range(i + 1, n_features))]
-            # Call weighted lasso on reduced design matrix
-            # Inherit some parameters from the parent
-            local_wlasso = WeightedLasso(
-                alpha=self.alpha,
-                fit_intercept=False,
-                max_iter=self.max_iter,
-                tol=self.tol
-            ).fit(X_reduced, y, sample_weight=sample_weight)
-            coefs[i] = local_wlasso.coef_
-            # Weighted tau
-            if sample_weight is not None:
-                y_weighted = y * sample_weight / np.sum(sample_weight)
-            else:
-                y_weighted = y / n_samples
-            tausq[i] = np.dot(y - local_wlasso.predict(X_reduced), y_weighted)
+        results = Parallel(n_jobs=self.n_jobs)(
+            delayed(_get_theta_coefs_and_tau_sq)(i, X, sample_weight,
+                                                 self.alpha_cov, self.n_alphas_cov,
+                                                 self.max_iter, self.tol, self.random_state)
+            for i in range(n_features))
+        coefs, tausq = zip(*results)
+        coefs = np.array(coefs)
+        tausq = np.array(tausq)
         # Compute C_hat
         C_hat = np.diag(np.ones(n_features))
         C_hat[0][1:] = -coefs[0]
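
The joblib fan-out pattern used here, in isolation (illustrative task function; n_jobs=None keeps execution sequential under joblib's defaults, n_jobs=-1 uses all cores):

    from joblib import Parallel, delayed

    def task(i):
        # Stand-in for _get_theta_coefs_and_tau_sq(i, ...)
        return i, i * i

    # Results come back in input order, so the unzip below is safe.
    results = Parallel(n_jobs=-1)(delayed(task)(i) for i in range(8))
    idx, squares = zip(*results)  # same unzip as in _get_theta_hat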
@@ -893,6 +924,18 @@ class MultiOutputDebiasedLasso(MultiOutputRegressor):
         reasons, using ``alpha = 0`` with the ``Lasso`` object is not advised.
         Given this, you should use the :class:`LinearRegression` object.
 
+    n_alphas : int, optional, default 100
+        How many alphas to try if alpha='auto'.
+
+    alpha_cov : string | float, optional, default 'auto'
+        The regularization alpha used when constructing the pseudo inverse of
+        the covariance matrix Theta, used for correcting the lasso coefficients. Each
+        such regression corresponds to the regression of one feature on the remainder
+        of the features.
+
+    n_alphas_cov : int, optional, default 10
+        How many alpha_cov to try if alpha_cov='auto'.
+
     fit_intercept : boolean, optional, default True
         Whether to calculate the intercept for this model. If set
         to False, no intercept will be used in calculations
@@ -935,6 +978,9 @@ class MultiOutputDebiasedLasso(MultiOutputRegressor):
         (setting to 'random') often leads to significantly faster convergence,
         especially when tol is higher than 1e-4.
 
+    n_jobs : int or None, default None
+        How many jobs to use whenever parallelism is invoked.
+
     Attributes
     ----------
     coef_ : array, shape (n_targets, n_features) or (n_features,)
@@ -954,14 +1000,17 @@ class MultiOutputDebiasedLasso(MultiOutputRegressor):
     """
 
-    def __init__(self, alpha='auto', fit_intercept=True,
+    def __init__(self, alpha='auto', n_alphas=100, alpha_cov='auto', n_alphas_cov=10,
+                 fit_intercept=True,
                  precompute=False, copy_X=True, max_iter=1000,
                  tol=1e-4, warm_start=False,
                  random_state=None, selection='cyclic', n_jobs=None):
-        self.estimator = DebiasedLasso(alpha=alpha, fit_intercept=fit_intercept,
+        self.estimator = DebiasedLasso(alpha=alpha, n_alphas=n_alphas, alpha_cov=alpha_cov, n_alphas_cov=n_alphas_cov,
+                                       fit_intercept=fit_intercept,
                                        precompute=precompute, copy_X=copy_X, max_iter=max_iter,
                                        tol=tol, warm_start=warm_start,
-                                       random_state=random_state, selection=selection)
+                                       random_state=random_state, selection=selection,
+                                       n_jobs=n_jobs)
         super().__init__(estimator=self.estimator, n_jobs=n_jobs)
 
     def fit(self, X, y, sample_weight=None):
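
A sketch of the multi-output wrapper with the new arguments (synthetic data). As the diff shows, n_jobs is used twice: by MultiOutputRegressor across targets and by each inner DebiasedLasso across its nodewise regressions:

    import numpy as np
    from econml.sklearn_extensions.linear_model import MultiOutputDebiasedLasso

    X = np.random.normal(size=(300, 10))
    Y = np.column_stack([X[:, 0], -X[:, 1]]) + np.random.normal(size=(300, 2))

    model = MultiOutputDebiasedLasso(alpha='auto', n_alphas=100,
                                     alpha_cov='auto', n_alphas_cov=10,
                                     n_jobs=2)
    model.fit(X, Y)
    print(model.coef_.shape)  # (n_targets, n_features)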

econml/tests/test_dml.py

Lines changed: 5 additions & 4 deletions
@@ -967,9 +967,10 @@ def test_categories(self):
         dmls = [LinearDML, SparseLinearDML]
         for ctor in dmls:
             dml1 = ctor(LinearRegression(), LogisticRegression(C=1000),
-                        fit_cate_intercept=False, discrete_treatment=True)
+                        fit_cate_intercept=False, discrete_treatment=True, random_state=123)
             dml2 = ctor(LinearRegression(), LogisticRegression(C=1000),
-                        fit_cate_intercept=False, discrete_treatment=True, categories=['c', 'b', 'a'])
+                        fit_cate_intercept=False, discrete_treatment=True, categories=['c', 'b', 'a'],
+                        random_state=123)
 
             # create a simple artificial setup where effect of moving from treatment
             # a -> b is 2,
@@ -1003,9 +1004,9 @@ def test_categories(self):
             # but const_marginal_effect should be reordered based on the explicit categories
             cme1 = dml1.const_marginal_effect(np.ones((1, 1))).reshape(-1)
             cme2 = dml2.const_marginal_effect(np.ones((1, 1))).reshape(-1)
-            self.assertAlmostEqual(cme1[1], -cme2[1], places=4)  # 1->3 in original ordering; 3->1 in new ordering
+            self.assertAlmostEqual(cme1[1], -cme2[1], places=3)  # 1->3 in original ordering; 3->1 in new ordering
             # 1-> 2 in original ordering; combination of 3->1 and 3->2
-            self.assertAlmostEqual(cme1[0], -cme2[1] + cme2[0], places=4)
+            self.assertAlmostEqual(cme1[0], -cme2[1] + cme2[0], places=3)
 
     def test_groups(self):
         groups = [1, 2, 3, 4, 5, 6] * 10
