@@ -315,7 +315,7 @@ def test_dml_plr_cate_gate(score, cov_type):
315315 cate = dml_plr_obj .cate (random_basis , cov_type = cov_type )
316316 assert isinstance (cate , dml .DoubleMLBLP )
317317 assert isinstance (cate .confint (), pd .DataFrame )
318- assert cate .blp_model .cov_type == cov_type
318+ assert cate .blp_model [ 0 ] .cov_type == cov_type
319319
320320 groups_1 = pd .DataFrame (
321321 np .column_stack ([obj_dml_data .data ["X1" ] <= 0 , obj_dml_data .data ["X1" ] > 0.2 ]), columns = ["Group 1" , "Group 2" ]
@@ -326,7 +326,7 @@ def test_dml_plr_cate_gate(score, cov_type):
326326 assert isinstance (gate_1 , dml .utils .blp .DoubleMLBLP )
327327 assert isinstance (gate_1 .confint (), pd .DataFrame )
328328 assert all (gate_1 .confint ().index == groups_1 .columns .tolist ())
329- assert gate_1 .blp_model .cov_type == cov_type
329+ assert gate_1 .blp_model [ 0 ] .cov_type == cov_type
330330
331331 np .random .seed (42 )
332332 groups_2 = pd .DataFrame (np .random .choice (["1" , "2" ], n ))
@@ -336,4 +336,46 @@ def test_dml_plr_cate_gate(score, cov_type):
336336 assert isinstance (gate_2 , dml .utils .blp .DoubleMLBLP )
337337 assert isinstance (gate_2 .confint (), pd .DataFrame )
338338 assert all (gate_2 .confint ().index == ["Group_1" , "Group_2" ])
339- assert gate_2 .blp_model .cov_type == cov_type
339+ assert gate_2 .blp_model [0 ].cov_type == cov_type
340+
341+
342+ @pytest .mark .ci
343+ def test_dml_plr_cate_gate_multiple_rep (score , cov_type ):
344+ n = 120
345+
346+ np .random .seed (42 )
347+ obj_dml_data = dml .plm .datasets .make_plr_CCDDHNR2018 (n_obs = n )
348+ ml_l = LinearRegression ()
349+ ml_g = LinearRegression ()
350+ ml_m = LinearRegression ()
351+
352+ if score == "partialling out" :
353+ dml_plr_obj = dml .DoubleMLPLR (obj_dml_data , ml_l = ml_l , ml_m = ml_m , n_folds = 3 , n_rep = 2 , score = score )
354+ else :
355+ assert score == "IV-type"
356+ dml_plr_obj = dml .DoubleMLPLR (obj_dml_data , ml_l = ml_l , ml_m = ml_m , ml_g = ml_g , n_folds = 3 , n_rep = 2 , score = score )
357+
358+ dml_plr_obj .fit ()
359+
360+ random_basis = pd .DataFrame (np .random .normal (0 , 1 , size = (n , 2 )))
361+ cate = dml_plr_obj .cate (random_basis , cov_type = cov_type )
362+ assert isinstance (cate , dml .DoubleMLBLP )
363+ assert cate .n_rep == 2
364+ assert isinstance (cate .blp_model , list )
365+ assert len (cate .blp_model ) == 2
366+ assert cate .blp_model [0 ].cov_type == cov_type
367+ assert cate .blp_model [1 ].cov_type == cov_type
368+ assert cate .all_coef .shape == (random_basis .shape [1 ], 2 )
369+ assert cate .all_se .shape == (random_basis .shape [1 ], 2 )
370+ assert isinstance (cate .confint (), pd .DataFrame )
371+ assert isinstance (cate .summary , pd .DataFrame )
372+
373+ x1 = obj_dml_data .data ["X1" ]
374+ groups = pd .DataFrame ({"Group 1" : x1 <= x1 .median (), "Group 2" : x1 > x1 .median ()})
375+ gate = dml_plr_obj .gate (groups , cov_type = cov_type )
376+ assert isinstance (gate , dml .DoubleMLBLP )
377+ assert gate .n_rep == 2
378+ assert gate .all_coef .shape == (groups .shape [1 ], 2 )
379+ assert gate .all_se .shape == (groups .shape [1 ], 2 )
380+ assert isinstance (gate .confint (), pd .DataFrame )
381+ assert all (gate .confint ().index == groups .columns .tolist ())
0 commit comments