@@ -319,7 +319,7 @@ treatment_effects = est.effect(X_test)
319319```
320320</details >
321321
322- See the <a href =" #references " >References</a > section for more details.
322+ See the <a href =" #references " >References</a > section for more details.
323323
324324### Interpretability
325325<details >
@@ -370,6 +370,54 @@ treatment_effects = est.effect(X_test)
370370
371371</details >
372372
373+
374+ ### Causal Model Selection and Cross-Validation
375+
376+
377+ <details >
378+ <summary >Causal model selection with the `RScorer` (click to expand)</summary >
379+
380+ ``` Python
381+ from econml.score import RScorer
382+
383+ # split data in train-validation
384+ X_train, X_val, T_train, T_val, Y_train, Y_val = train_test_split(X, T, Y, test_size = .4 )
385+
386+ # define list of CATE estimators to select among
387+ reg = lambda : RandomForestRegressor(min_samples_leaf = 20 )
388+ clf = lambda : RandomForestClassifier(min_samples_leaf = 20 )
389+ models = [(' ldml' , LinearDML(model_y = reg(), model_t = clf(), discrete_treatment = True ,
390+ linear_first_stages = False , n_splits = 3 )),
391+ (' xlearner' , XLearner(models = reg(), cate_models = reg(), propensity_model = clf())),
392+ (' dalearner' , DomainAdaptationLearner(models = reg(), final_models = reg(), propensity_model = clf())),
393+ (' slearner' , SLearner(overall_model = reg())),
394+ (' drlearner' , DRLearner(model_propensity = clf(), model_regression = reg(),
395+ model_final = reg(), n_splits = 3 )),
396+ (' rlearner' , NonParamDML(model_y = reg(), model_t = clf(), model_final = reg(),
397+ discrete_treatment = True , n_splits = 3 )),
398+ (' dml3dlasso' , DML(model_y = reg(), model_t = clf(),
399+ model_final = LassoCV(cv = 3 , fit_intercept = False ),
400+ discrete_treatment = True ,
401+ featurizer = PolynomialFeatures(degree = 3 ),
402+ linear_first_stages = False , n_splits = 3 ))
403+ ]
404+
405+ # fit cate models on train data
406+ models = [(name, mdl.fit(Y_train, T_train, X = X_train)) for name, mdl in models]
407+
408+ # score cate models on validation data
409+ scorer = RScorer(model_y = reg(), model_t = clf(),
410+ discrete_treatment = True , n_splits = 3 , mc_iters = 2 , mc_agg = ' median' )
411+ scorer.fit(Y_val, T_val, X = X_val)
412+ rscore = [scorer.score(mdl) for _, mdl in models]
413+ # select the best model
414+ mdl, _ = scorer.best_model([mdl for _, mdl in models])
415+ # create weighted ensemble model based on score performance
416+ mdl, _ = scorer.ensemble([mdl for _, mdl in models])
417+ ```
418+
419+ </details >
420+
373421### Inference
374422
375423Whenever inference is enabled, then one can get a more structured ` InferenceResults ` object with more elaborate inference information, such
0 commit comments