Skip to content

Commit 03a572f

Browse files
Merge pull request #386 from DoubleML/o-blp-multirep
Enable CATEs and GATEs for multiple repetitions
2 parents cb963d0 + f332a3a commit 03a572f

File tree

12 files changed

+336
-99
lines changed

12 files changed

+336
-99
lines changed

doubleml/irm/apo.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -579,11 +579,9 @@ def capo(self, basis, is_gate=False, **kwargs):
579579
if self.score not in valid_score:
580580
raise ValueError("Invalid score " + self.score + ". " + "Valid score " + " or ".join(valid_score) + ".")
581581

582-
if self.n_rep != 1:
583-
raise NotImplementedError("Only implemented for one repetition. " + f"Number of repetitions is {str(self.n_rep)}.")
584-
585582
# define the orthogonal signal
586-
orth_signal = self.psi_elements["psi_b"].reshape(-1)
583+
orth_signal = np.squeeze(self.psi_elements["psi_b"], axis=2)
584+
587585
# fit the best linear predictor
588586
model = DoubleMLBLP(orth_signal, basis=basis, is_gate=is_gate)
589587
model.fit(**kwargs)

doubleml/irm/irm.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -587,11 +587,9 @@ def cate(self, basis, is_gate=False, **kwargs):
587587
if self.score not in valid_score:
588588
raise ValueError("Invalid score " + self.score + ". " + "Valid score " + " or ".join(valid_score) + ".")
589589

590-
if self.n_rep != 1:
591-
raise NotImplementedError("Only implemented for one repetition. " + f"Number of repetitions is {str(self.n_rep)}.")
592-
593590
# define the orthogonal signal
594-
orth_signal = self.psi_elements["psi_b"].reshape(-1)
591+
orth_signal = np.squeeze(self.psi_elements["psi_b"], axis=2)
592+
595593
# fit the best linear predictor
596594
model = DoubleMLBLP(orth_signal, basis=basis, is_gate=is_gate)
597595
model.fit(**kwargs)

doubleml/irm/tests/test_apo.py

Lines changed: 48 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -257,7 +257,7 @@ def test_dml_apo_capo_gapo(treatment_level, cov_type):
257257
capo = dml_obj.capo(random_basis, cov_type=cov_type)
258258
assert isinstance(capo, dml.utils.blp.DoubleMLBLP)
259259
assert isinstance(capo.confint(), pd.DataFrame)
260-
assert capo.blp_model.cov_type == cov_type
260+
assert capo.blp_model[0].cov_type == cov_type
261261

262262
groups_1 = pd.DataFrame(
263263
np.column_stack([obj_dml_data.data["X1"] <= -1.0, obj_dml_data.data["X1"] > 0.2]), columns=["Group 1", "Group 2"]
@@ -268,7 +268,7 @@ def test_dml_apo_capo_gapo(treatment_level, cov_type):
268268
assert isinstance(gapo_1, dml.utils.blp.DoubleMLBLP)
269269
assert isinstance(gapo_1.confint(), pd.DataFrame)
270270
assert all(gapo_1.confint().index == groups_1.columns.to_list())
271-
assert gapo_1.blp_model.cov_type == cov_type
271+
assert gapo_1.blp_model[0].cov_type == cov_type
272272

273273
np.random.seed(42)
274274
groups_2 = pd.DataFrame(np.random.choice(["1", "2"], n, p=[0.1, 0.9]))
@@ -278,4 +278,49 @@ def test_dml_apo_capo_gapo(treatment_level, cov_type):
278278
assert isinstance(gapo_2, dml.utils.blp.DoubleMLBLP)
279279
assert isinstance(gapo_2.confint(), pd.DataFrame)
280280
assert all(gapo_2.confint().index == ["Group_1", "Group_2"])
281-
assert gapo_2.blp_model.cov_type == cov_type
281+
assert gapo_2.blp_model[0].cov_type == cov_type
282+
283+
284+
@pytest.mark.ci
285+
def test_dml_apo_capo_gapo_multiple_rep(treatment_level, cov_type):
286+
n = 120
287+
np.random.seed(42)
288+
obj_dml_data = make_irm_data(n_obs=n, dim_x=2)
289+
290+
ml_g = RandomForestRegressor(n_estimators=10, random_state=42)
291+
ml_m = RandomForestClassifier(n_estimators=10, random_state=42)
292+
293+
dml_obj = dml.DoubleMLAPO(
294+
obj_dml_data,
295+
ml_m=ml_m,
296+
ml_g=ml_g,
297+
treatment_level=treatment_level,
298+
ps_processor_config=PSProcessorConfig(clipping_threshold=0.05),
299+
n_folds=3,
300+
n_rep=2,
301+
)
302+
303+
dml_obj.fit()
304+
305+
random_basis = pd.DataFrame(np.random.normal(0, 1, size=(n, 5)))
306+
capo = dml_obj.capo(random_basis, cov_type=cov_type)
307+
assert isinstance(capo, dml.utils.blp.DoubleMLBLP)
308+
assert capo.n_rep == 2
309+
assert isinstance(capo.blp_model, list)
310+
assert len(capo.blp_model) == 2
311+
assert capo.blp_model[0].cov_type == cov_type
312+
assert capo.blp_model[1].cov_type == cov_type
313+
assert capo.all_coef.shape == (random_basis.shape[1], 2)
314+
assert capo.all_se.shape == (random_basis.shape[1], 2)
315+
assert isinstance(capo.confint(), pd.DataFrame)
316+
assert isinstance(capo.summary, pd.DataFrame)
317+
318+
x1 = obj_dml_data.data["X1"]
319+
groups = pd.DataFrame({"Group 1": x1 <= x1.median(), "Group 2": x1 > x1.median()})
320+
gapo = dml_obj.gapo(groups, cov_type=cov_type)
321+
assert isinstance(gapo, dml.utils.blp.DoubleMLBLP)
322+
assert gapo.n_rep == 2
323+
assert gapo.all_coef.shape == (groups.shape[1], 2)
324+
assert gapo.all_se.shape == (groups.shape[1], 2)
325+
assert isinstance(gapo.confint(), pd.DataFrame)
326+
assert all(gapo.confint().index == groups.columns.to_list())

doubleml/irm/tests/test_apo_exceptions.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -202,13 +202,6 @@ def test_apo_exception_capo_gapo():
202202
# reset the score
203203
dml_obj._score = "APO"
204204

205-
msg = "Only implemented for one repetition. Number of repetitions is 2."
206-
with pytest.raises(NotImplementedError, match=msg):
207-
dml_obj._n_rep = 2
208-
dml_obj.capo(random_basis)
209-
# reset the number of repetitions
210-
dml_obj._n_rep = 1
211-
212205
msg = "Groups must be of DataFrame type. Groups of type <class 'int'> was passed."
213206
with pytest.raises(TypeError, match=msg):
214207
_ = dml_obj.gapo(1)

doubleml/irm/tests/test_irm.py

Lines changed: 46 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,7 @@ def test_dml_irm_cate_gate(cov_type):
246246
cate = dml_irm_obj.cate(random_basis, cov_type=cov_type)
247247
assert isinstance(cate, dml.utils.blp.DoubleMLBLP)
248248
assert isinstance(cate.confint(), pd.DataFrame)
249-
assert cate.blp_model.cov_type == cov_type
249+
assert cate.blp_model[0].cov_type == cov_type
250250

251251
groups_1 = pd.DataFrame(
252252
np.column_stack([obj_dml_data.data["X1"] <= 0, obj_dml_data.data["X1"] > 0.2]), columns=["Group 1", "Group 2"]
@@ -257,7 +257,7 @@ def test_dml_irm_cate_gate(cov_type):
257257
assert isinstance(gate_1, dml.utils.blp.DoubleMLBLP)
258258
assert isinstance(gate_1.confint(), pd.DataFrame)
259259
assert all(gate_1.confint().index == groups_1.columns.to_list())
260-
assert gate_1.blp_model.cov_type == cov_type
260+
assert gate_1.blp_model[0].cov_type == cov_type
261261

262262
np.random.seed(42)
263263
groups_2 = pd.DataFrame(np.random.choice(["1", "2"], n))
@@ -267,7 +267,50 @@ def test_dml_irm_cate_gate(cov_type):
267267
assert isinstance(gate_2, dml.utils.blp.DoubleMLBLP)
268268
assert isinstance(gate_2.confint(), pd.DataFrame)
269269
assert all(gate_2.confint().index == ["Group_1", "Group_2"])
270-
assert gate_2.blp_model.cov_type == cov_type
270+
assert gate_2.blp_model[0].cov_type == cov_type
271+
272+
273+
@pytest.mark.ci
274+
def test_dml_irm_cate_gate_multiple_rep(cov_type):
275+
n = 120
276+
np.random.seed(42)
277+
obj_dml_data = make_irm_data(n_obs=n, dim_x=2)
278+
279+
ml_g = RandomForestRegressor(n_estimators=10, random_state=42)
280+
ml_m = RandomForestClassifier(n_estimators=10, random_state=42)
281+
ps_processor_config = PSProcessorConfig(clipping_threshold=0.05)
282+
dml_irm_obj = dml.DoubleMLIRM(
283+
obj_dml_data,
284+
ml_m=ml_m,
285+
ml_g=ml_g,
286+
ps_processor_config=ps_processor_config,
287+
n_folds=3,
288+
n_rep=2,
289+
)
290+
291+
dml_irm_obj.fit()
292+
random_basis = pd.DataFrame(np.random.normal(0, 1, size=(n, 5)))
293+
cate = dml_irm_obj.cate(random_basis, cov_type=cov_type)
294+
assert isinstance(cate, dml.utils.blp.DoubleMLBLP)
295+
assert cate.n_rep == 2
296+
assert isinstance(cate.blp_model, list)
297+
assert len(cate.blp_model) == 2
298+
assert cate.blp_model[0].cov_type == cov_type
299+
assert cate.blp_model[1].cov_type == cov_type
300+
assert cate.all_coef.shape == (random_basis.shape[1], 2)
301+
assert cate.all_se.shape == (random_basis.shape[1], 2)
302+
assert isinstance(cate.confint(), pd.DataFrame)
303+
assert isinstance(cate.summary, pd.DataFrame)
304+
305+
x1 = obj_dml_data.data["X1"]
306+
groups = pd.DataFrame({"Group 1": x1 <= x1.median(), "Group 2": x1 > x1.median()})
307+
gate = dml_irm_obj.gate(groups, cov_type=cov_type)
308+
assert isinstance(gate, dml.utils.blp.DoubleMLBLP)
309+
assert gate.n_rep == 2
310+
assert gate.all_coef.shape == (groups.shape[1], 2)
311+
assert gate.all_se.shape == (groups.shape[1], 2)
312+
assert isinstance(gate.confint(), pd.DataFrame)
313+
assert all(gate.confint().index == groups.columns.to_list())
271314

272315

273316
@pytest.fixture(scope="module", params=[1, 3])

doubleml/plm/plr.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -470,14 +470,12 @@ def cate(self, basis, is_gate=False, **kwargs):
470470
raise NotImplementedError(
471471
"Only implemented for single treatment. " + f"Number of treatments is {str(self._dml_data.n_treat)}."
472472
)
473-
if self.n_rep != 1:
474-
raise NotImplementedError("Only implemented for one repetition. " + f"Number of repetitions is {str(self.n_rep)}.")
475473

476474
Y_tilde, D_tilde = self._partial_out()
477475

478476
D_basis = basis * D_tilde
479477
model = DoubleMLBLP(
480-
orth_signal=Y_tilde.reshape(-1),
478+
orth_signal=Y_tilde,
481479
basis=D_basis,
482480
is_gate=is_gate,
483481
)

doubleml/plm/tests/test_plr.py

Lines changed: 45 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -315,7 +315,7 @@ def test_dml_plr_cate_gate(score, cov_type):
315315
cate = dml_plr_obj.cate(random_basis, cov_type=cov_type)
316316
assert isinstance(cate, dml.DoubleMLBLP)
317317
assert isinstance(cate.confint(), pd.DataFrame)
318-
assert cate.blp_model.cov_type == cov_type
318+
assert cate.blp_model[0].cov_type == cov_type
319319

320320
groups_1 = pd.DataFrame(
321321
np.column_stack([obj_dml_data.data["X1"] <= 0, obj_dml_data.data["X1"] > 0.2]), columns=["Group 1", "Group 2"]
@@ -326,7 +326,7 @@ def test_dml_plr_cate_gate(score, cov_type):
326326
assert isinstance(gate_1, dml.utils.blp.DoubleMLBLP)
327327
assert isinstance(gate_1.confint(), pd.DataFrame)
328328
assert all(gate_1.confint().index == groups_1.columns.tolist())
329-
assert gate_1.blp_model.cov_type == cov_type
329+
assert gate_1.blp_model[0].cov_type == cov_type
330330

331331
np.random.seed(42)
332332
groups_2 = pd.DataFrame(np.random.choice(["1", "2"], n))
@@ -336,4 +336,46 @@ def test_dml_plr_cate_gate(score, cov_type):
336336
assert isinstance(gate_2, dml.utils.blp.DoubleMLBLP)
337337
assert isinstance(gate_2.confint(), pd.DataFrame)
338338
assert all(gate_2.confint().index == ["Group_1", "Group_2"])
339-
assert gate_2.blp_model.cov_type == cov_type
339+
assert gate_2.blp_model[0].cov_type == cov_type
340+
341+
342+
@pytest.mark.ci
343+
def test_dml_plr_cate_gate_multiple_rep(score, cov_type):
344+
n = 120
345+
346+
np.random.seed(42)
347+
obj_dml_data = dml.plm.datasets.make_plr_CCDDHNR2018(n_obs=n)
348+
ml_l = LinearRegression()
349+
ml_g = LinearRegression()
350+
ml_m = LinearRegression()
351+
352+
if score == "partialling out":
353+
dml_plr_obj = dml.DoubleMLPLR(obj_dml_data, ml_l=ml_l, ml_m=ml_m, n_folds=3, n_rep=2, score=score)
354+
else:
355+
assert score == "IV-type"
356+
dml_plr_obj = dml.DoubleMLPLR(obj_dml_data, ml_l=ml_l, ml_m=ml_m, ml_g=ml_g, n_folds=3, n_rep=2, score=score)
357+
358+
dml_plr_obj.fit()
359+
360+
random_basis = pd.DataFrame(np.random.normal(0, 1, size=(n, 2)))
361+
cate = dml_plr_obj.cate(random_basis, cov_type=cov_type)
362+
assert isinstance(cate, dml.DoubleMLBLP)
363+
assert cate.n_rep == 2
364+
assert isinstance(cate.blp_model, list)
365+
assert len(cate.blp_model) == 2
366+
assert cate.blp_model[0].cov_type == cov_type
367+
assert cate.blp_model[1].cov_type == cov_type
368+
assert cate.all_coef.shape == (random_basis.shape[1], 2)
369+
assert cate.all_se.shape == (random_basis.shape[1], 2)
370+
assert isinstance(cate.confint(), pd.DataFrame)
371+
assert isinstance(cate.summary, pd.DataFrame)
372+
373+
x1 = obj_dml_data.data["X1"]
374+
groups = pd.DataFrame({"Group 1": x1 <= x1.median(), "Group 2": x1 > x1.median()})
375+
gate = dml_plr_obj.gate(groups, cov_type=cov_type)
376+
assert isinstance(gate, dml.DoubleMLBLP)
377+
assert gate.n_rep == 2
378+
assert gate.all_coef.shape == (groups.shape[1], 2)
379+
assert gate.all_se.shape == (groups.shape[1], 2)
380+
assert isinstance(gate.confint(), pd.DataFrame)
381+
assert all(gate.confint().index == groups.columns.tolist())

doubleml/plm/tests/test_plr_binary_outcome.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,7 @@ def test_dml_plr_binary_cate_gate(score, cov_type, generate_binary_data):
231231
cate = dml_plr_obj.cate(random_basis, cov_type=cov_type)
232232
assert isinstance(cate, dml.DoubleMLBLP)
233233
assert isinstance(cate.confint(), pd.DataFrame)
234-
assert cate.blp_model.cov_type == cov_type
234+
assert cate.blp_model[0].cov_type == cov_type
235235

236236
groups_1 = pd.DataFrame(np.column_stack([data["X1"] <= 0, data["X1"] > 0.2]), columns=["Group 1", "Group 2"])
237237
msg = "At least one group effect is estimated with less than 6 observations."
@@ -240,4 +240,4 @@ def test_dml_plr_binary_cate_gate(score, cov_type, generate_binary_data):
240240
assert isinstance(gate_1, dml.utils.blp.DoubleMLBLP)
241241
assert isinstance(gate_1.confint(), pd.DataFrame)
242242
assert all(gate_1.confint().index == groups_1.columns.tolist())
243-
assert gate_1.blp_model.cov_type == cov_type
243+
assert gate_1.blp_model[0].cov_type == cov_type

doubleml/tests/test_exceptions.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1387,10 +1387,9 @@ def test_doubleml_exception_gate():
13871387
n_rep=2,
13881388
)
13891389
dml_irm_obj.fit()
1390-
1391-
msg = "Only implemented for one repetition. Number of repetitions is 2."
1392-
with pytest.raises(NotImplementedError, match=msg):
1393-
dml_irm_obj.gate(groups=groups)
1390+
msg = "Groups must be of DataFrame type. Groups of type <class 'int'> was passed."
1391+
with pytest.raises(TypeError, match=msg):
1392+
dml_irm_obj.gate(groups=2)
13941393

13951394

13961395
@pytest.mark.ci
@@ -1419,17 +1418,17 @@ def test_doubleml_exception_cate():
14191418
n_rep=2,
14201419
)
14211420
dml_irm_obj.fit()
1422-
msg = "Only implemented for one repetition. Number of repetitions is 2."
1423-
with pytest.raises(NotImplementedError, match=msg):
1421+
msg = "The basis must be of DataFrame type. Basis of type <class 'int'> was passed."
1422+
with pytest.raises(TypeError, match=msg):
14241423
dml_irm_obj.cate(basis=2)
14251424

14261425

14271426
@pytest.mark.ci
14281427
def test_doubleml_exception_plr_cate():
14291428
dml_plr_obj = DoubleMLPLR(dml_data, ml_l=Lasso(), ml_m=Lasso(), n_folds=2, n_rep=2)
14301429
dml_plr_obj.fit()
1431-
msg = "Only implemented for one repetition. Number of repetitions is 2."
1432-
with pytest.raises(NotImplementedError, match=msg):
1430+
msg = "The basis must be of DataFrame type. Basis of type <class 'numpy.ndarray'> was passed."
1431+
with pytest.raises(TypeError, match=msg):
14331432
dml_plr_obj.cate(basis=2)
14341433

14351434
dml_plr_obj = DoubleMLPLR(dml_data, ml_l=Lasso(), ml_m=Lasso(), n_folds=2)
@@ -1460,6 +1459,12 @@ def test_doubleml_exception_plr_gate():
14601459
with pytest.raises(TypeError, match=msg):
14611460
dml_plr_obj.gate(groups=pd.DataFrame(np.random.normal(0, 1, size=(dml_data.n_obs, 3))))
14621461

1462+
dml_plr_obj = DoubleMLPLR(dml_data, ml_l=Lasso(), ml_m=Lasso(), n_folds=2, n_rep=2)
1463+
dml_plr_obj.fit()
1464+
msg = "Groups must be of DataFrame type. Groups of type <class 'int'> was passed."
1465+
with pytest.raises(TypeError, match=msg):
1466+
dml_plr_obj.gate(groups=2)
1467+
14631468

14641469
@pytest.mark.ci
14651470
def test_double_ml_exception_evaluate_learner():

0 commit comments

Comments
 (0)