Skip to content

Commit 1205430

Browse files
Merge pull request #330 from EgorKraevTransferwise/experiment_plotting
Experiment plotting improvements
2 parents 346aa84 + 5627d9a commit 1205430

File tree

6 files changed

+302
-226
lines changed

6 files changed

+302
-226
lines changed

causaltune/optimiser.py

Lines changed: 3 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from joblib import Parallel, delayed
2121

2222
from causaltune.search.params import SimpleParamService
23-
from causaltune.score.scoring import Scorer
23+
from causaltune.score.scoring import Scorer, metrics_to_minimize
2424
from causaltune.utils import treatment_is_multivalue
2525
from causaltune.models.monkey_patches import (
2626
AutoML,
@@ -514,19 +514,7 @@ def fit(
514514
evaluated_rewards=(
515515
[] if len(self.resume_scores) == 0 else self.resume_scores
516516
),
517-
mode=(
518-
"min"
519-
if self.metric
520-
in [
521-
"energy_distance",
522-
"psw_energy_distance",
523-
"frobenius_norm",
524-
"psw_frobenius_norm",
525-
"codec",
526-
"policy_risk",
527-
]
528-
else "max"
529-
),
517+
mode=("min" if self.metric in metrics_to_minimize() else "max"),
530518
low_cost_partial_config={},
531519
**self._settings["tuner"],
532520
)
@@ -547,19 +535,7 @@ def fit(
547535
evaluated_rewards=(
548536
[] if len(self.resume_scores) == 0 else self.resume_scores
549537
),
550-
mode=(
551-
"min"
552-
if self.metric
553-
in [
554-
"energy_distance",
555-
"psw_energy_distance",
556-
"frobenius_norm",
557-
"psw_frobenius_norm",
558-
"codec",
559-
"policy_risk",
560-
]
561-
else "max"
562-
),
538+
mode=("min" if self.metric in metrics_to_minimize() else "max"),
563539
low_cost_partial_config={},
564540
**self._settings["tuner"],
565541
)

causaltune/score/erupt_core.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,8 @@ def erupt_with_std(
5050
]
5151
mean += np.mean(means)
5252
std += np.std(means) / np.sqrt(num_splits) # Standard error of the mean
53-
54-
return mean / resamples, std / resamples
53+
# 1.5 is an empirical factor to make the confidence interval wider
54+
return mean / resamples, 1.5 * std / resamples
5555

5656

5757
def erupt(

causaltune/score/scoring.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def const_marginal_effect(self, X):
3737
return self.cate_estimate
3838

3939

40-
def supported_metrics(problem: str, multivalue: bool, scores_only: bool) -> List[str]:
40+
def supported_metrics(problem: str, multivalue: bool, scores_only: bool, constant_ptt: bool=False) -> List[str]:
4141
if problem == "iv":
4242
metrics = ["energy_distance", "frobenius_norm", "codec"]
4343
if not scores_only:
@@ -52,12 +52,12 @@ def supported_metrics(problem: str, multivalue: bool, scores_only: bool) -> List
5252
metrics = [
5353
"erupt",
5454
"norm_erupt",
55-
"greedy_erupt", # regular erupt was made probabilistic, no need for a separate one
55+
# "greedy_erupt", # regular erupt was made probabilistic, no need for a separate one
5656
"policy_risk", # NEW
5757
"qini",
5858
"auc",
5959
# "r_scorer",
60-
"energy_distance", # is broken without propensity weighting
60+
"energy_distance", # should only be used in iv problems
6161
"psw_energy_distance",
6262
"frobenius_norm", # NEW
6363
"codec", # NEW
@@ -68,6 +68,17 @@ def supported_metrics(problem: str, multivalue: bool, scores_only: bool) -> List
6868
return metrics
6969

7070

71+
def metrics_to_minimize():
72+
return [
73+
"energy_distance",
74+
"psw_energy_distance",
75+
"codec",
76+
"frobenius_norm",
77+
"psw_frobenius_norm",
78+
"policy_risk",
79+
]
80+
81+
7182
class Scorer:
7283
def __init__(
7384
self,
@@ -90,6 +101,10 @@ def __init__(
90101
self.identified_estimand = causal_model.identify_effect(
91102
proceed_when_unidentifiable=True
92103
)
104+
if "Dummy" in propensity_model.__class__.__name__:
105+
self.constant_ptt = True
106+
else:
107+
self.constant_ptt = False
93108

94109
if problem == "backdoor":
95110
print(

causaltune/search/component.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def joint_config(data_size: Tuple[int, int], estimator_list=None):
6161
cfg, init_params, low_cost_init_params = flaml_config_to_tune_config(
6262
cls.search_space(data_size=data_size, task=task)
6363
)
64-
64+
cfg, init_params = tweak_config(cfg, init_params, name)
6565
# Test if the estimator instantiates fine
6666
try:
6767
cls(task=task, **init_params)
@@ -76,6 +76,25 @@ def joint_config(data_size: Tuple[int, int], estimator_list=None):
7676
return tune.choice(joint_cfg), joint_init_params, joint_low_cost_init_params
7777

7878

79+
def tweak_config(cfg: dict, init_params: dict, estimator_name: str):
80+
"""
81+
Tweak built-in FLAML search spaces to limit the number of estimators
82+
:param cfg:
83+
:param estimator_name:
84+
:return:
85+
"""
86+
out = copy.deepcopy(cfg)
87+
if "xgboost" in estimator_name or estimator_name in [
88+
"random_forest",
89+
"extra_trees",
90+
"lgbm",
91+
"catboost",
92+
]:
93+
out["n_estimators"] = tune.lograndint(4, 1000)
94+
init_params["n_estimators"] = 100
95+
return out, init_params
96+
97+
7998
def model_from_cfg(cfg: dict):
8099
cfg = copy.deepcopy(cfg)
81100
model_name = cfg.pop("estimator_name")

notebooks/ERUPT basics.ipynb

Lines changed: 65 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -103,45 +103,45 @@
103103
" <tbody>\n",
104104
" <tr>\n",
105105
" <th>0</th>\n",
106-
" <td>0.452636</td>\n",
107-
" <td>0</td>\n",
108-
" <td>1.684484</td>\n",
106+
" <td>0.898227</td>\n",
107+
" <td>1</td>\n",
108+
" <td>1.288637</td>\n",
109109
" </tr>\n",
110110
" <tr>\n",
111111
" <th>1</th>\n",
112-
" <td>0.380215</td>\n",
112+
" <td>0.462092</td>\n",
113113
" <td>0</td>\n",
114-
" <td>0.745268</td>\n",
114+
" <td>0.771976</td>\n",
115115
" </tr>\n",
116116
" <tr>\n",
117117
" <th>2</th>\n",
118-
" <td>0.584036</td>\n",
119-
" <td>1</td>\n",
120-
" <td>0.762300</td>\n",
118+
" <td>0.858974</td>\n",
119+
" <td>0</td>\n",
120+
" <td>1.881019</td>\n",
121121
" </tr>\n",
122122
" <tr>\n",
123123
" <th>3</th>\n",
124-
" <td>0.505191</td>\n",
125-
" <td>0</td>\n",
126-
" <td>1.425354</td>\n",
124+
" <td>0.228084</td>\n",
125+
" <td>1</td>\n",
126+
" <td>0.357797</td>\n",
127127
" </tr>\n",
128128
" <tr>\n",
129129
" <th>4</th>\n",
130-
" <td>0.384110</td>\n",
130+
" <td>0.962512</td>\n",
131131
" <td>1</td>\n",
132-
" <td>1.834628</td>\n",
132+
" <td>1.066413</td>\n",
133133
" </tr>\n",
134134
" </tbody>\n",
135135
"</table>\n",
136136
"</div>"
137137
],
138138
"text/plain": [
139139
" X T1 Y1\n",
140-
"0 0.452636 0 1.684484\n",
141-
"1 0.380215 0 0.745268\n",
142-
"2 0.584036 1 0.762300\n",
143-
"3 0.505191 0 1.425354\n",
144-
"4 0.384110 1 1.834628"
140+
"0 0.898227 1 1.288637\n",
141+
"1 0.462092 0 0.771976\n",
142+
"2 0.858974 0 1.881019\n",
143+
"3 0.228084 1 0.357797\n",
144+
"4 0.962512 1 1.066413"
145145
]
146146
},
147147
"execution_count": 2,
@@ -216,65 +216,65 @@
216216
" <tbody>\n",
217217
" <tr>\n",
218218
" <th>0</th>\n",
219-
" <td>0.452636</td>\n",
220-
" <td>0</td>\n",
221-
" <td>1.684484</td>\n",
222-
" <td>0.726318</td>\n",
223-
" <td>0</td>\n",
224-
" <td>0.273682</td>\n",
225-
" <td>0.904259</td>\n",
219+
" <td>0.898227</td>\n",
220+
" <td>1</td>\n",
221+
" <td>1.288637</td>\n",
222+
" <td>0.949114</td>\n",
223+
" <td>1</td>\n",
224+
" <td>0.949114</td>\n",
225+
" <td>2.229118</td>\n",
226226
" </tr>\n",
227227
" <tr>\n",
228228
" <th>1</th>\n",
229-
" <td>0.380215</td>\n",
229+
" <td>0.462092</td>\n",
230230
" <td>0</td>\n",
231-
" <td>0.745268</td>\n",
232-
" <td>0.690108</td>\n",
233-
" <td>1</td>\n",
234-
" <td>0.690108</td>\n",
235-
" <td>1.930383</td>\n",
231+
" <td>0.771976</td>\n",
232+
" <td>0.731046</td>\n",
233+
" <td>0</td>\n",
234+
" <td>0.268954</td>\n",
235+
" <td>0.572308</td>\n",
236236
" </tr>\n",
237237
" <tr>\n",
238238
" <th>2</th>\n",
239-
" <td>0.584036</td>\n",
240-
" <td>1</td>\n",
241-
" <td>0.762300</td>\n",
242-
" <td>0.792018</td>\n",
239+
" <td>0.858974</td>\n",
240+
" <td>0</td>\n",
241+
" <td>1.881019</td>\n",
242+
" <td>0.929487</td>\n",
243243
" <td>1</td>\n",
244-
" <td>0.792018</td>\n",
245-
" <td>0.959608</td>\n",
244+
" <td>0.929487</td>\n",
245+
" <td>2.601592</td>\n",
246246
" </tr>\n",
247247
" <tr>\n",
248248
" <th>3</th>\n",
249-
" <td>0.505191</td>\n",
250-
" <td>0</td>\n",
251-
" <td>1.425354</td>\n",
252-
" <td>0.752596</td>\n",
249+
" <td>0.228084</td>\n",
250+
" <td>1</td>\n",
251+
" <td>0.357797</td>\n",
252+
" <td>0.614042</td>\n",
253253
" <td>1</td>\n",
254-
" <td>0.752596</td>\n",
255-
" <td>1.017777</td>\n",
254+
" <td>0.614042</td>\n",
255+
" <td>0.542638</td>\n",
256256
" </tr>\n",
257257
" <tr>\n",
258258
" <th>4</th>\n",
259-
" <td>0.384110</td>\n",
259+
" <td>0.962512</td>\n",
260260
" <td>1</td>\n",
261-
" <td>1.834628</td>\n",
262-
" <td>0.692055</td>\n",
261+
" <td>1.066413</td>\n",
262+
" <td>0.981256</td>\n",
263263
" <td>1</td>\n",
264-
" <td>0.692055</td>\n",
265-
" <td>2.374030</td>\n",
264+
" <td>0.981256</td>\n",
265+
" <td>2.401383</td>\n",
266266
" </tr>\n",
267267
" </tbody>\n",
268268
"</table>\n",
269269
"</div>"
270270
],
271271
"text/plain": [
272272
" X T1 Y1 p T2 p_of_actual Y2\n",
273-
"0 0.452636 0 1.684484 0.726318 0 0.273682 0.904259\n",
274-
"1 0.380215 0 0.745268 0.690108 1 0.690108 1.930383\n",
275-
"2 0.584036 1 0.762300 0.792018 1 0.792018 0.959608\n",
276-
"3 0.505191 0 1.425354 0.752596 1 0.752596 1.017777\n",
277-
"4 0.384110 1 1.834628 0.692055 1 0.692055 2.374030"
273+
"0 0.898227 1 1.288637 0.949114 1 0.949114 2.229118\n",
274+
"1 0.462092 0 0.771976 0.731046 0 0.268954 0.572308\n",
275+
"2 0.858974 0 1.881019 0.929487 1 0.929487 2.601592\n",
276+
"3 0.228084 1 0.357797 0.614042 1 0.614042 0.542638\n",
277+
"4 0.962512 1 1.066413 0.981256 1 0.981256 2.401383"
278278
]
279279
},
280280
"execution_count": 3,
@@ -319,10 +319,10 @@
319319
"name": "stdout",
320320
"output_type": "stream",
321321
"text": [
322-
"Average outcome of the actual biased assignment experiment: 1.411675477573636\n",
323-
"Estimated outcome of random assignment: 1.251567372523789\n",
324-
"95% confidence interval for estimated outcome: 1.2311928820519622 1.2719418629956158\n",
325-
"Average outcome of the actual random assignment experiment: 1.2559621877416332\n"
322+
"Average outcome of the actual biased assignment experiment: 1.4064676444383317\n",
323+
"Estimated outcome of random assignment: 1.2594221770638483\n",
324+
"95% confidence interval for estimated outcome: 1.230204391668238 1.2886399624594587\n",
325+
"Average outcome of the actual random assignment experiment: 1.2461659092712785\n"
326326
]
327327
}
328328
],
@@ -360,17 +360,17 @@
360360
"name": "stdout",
361361
"output_type": "stream",
362362
"text": [
363-
"Average outcome of the actual random assignment experiment: 1.2559621877416332\n",
364-
"Estimated outcome of biased assignment: 1.4147647990746988\n",
365-
"Confidence interval for estimated outcome: 1.398423601541284 1.4311059966081134\n",
366-
"Average outcome of the actual biased assignment experiment: 1.411675477573636\n"
363+
"Average outcome of the actual random assignment experiment: 1.2461659092712785\n",
364+
"Estimated outcome of biased assignment: 1.405112521603215\n",
365+
"95% confidence interval for estimated outcome: 1.3814865905561569 1.428738452650273\n",
366+
"Average outcome of the actual biased assignment experiment: 1.4064676444383317\n"
367367
]
368368
}
369369
],
370370
"source": [
371-
"# Conversely, we can take the outcome of the fully random test and use it to estimate what the outcome of the biased assignment would have been\n",
371+
"# Conversely, we can take the outcome of the fully random test and use it \n",
372+
"# to estimate what the outcome of the biased assignment would have been\n",
372373
"\n",
373-
"# Let's use data from biased assignment experiment to estimate the average effect of fully random assignment\n",
374374
"hypothetical_policy = df[\"T2\"]\n",
375375
"est, std = erupt_with_std(actual_propensity=0.5*pd.Series(np.ones(len(df))), \n",
376376
" actual_treatment=df[\"T1\"],\n",
@@ -379,7 +379,7 @@
379379
"\n",
380380
"print(\"Average outcome of the actual random assignment experiment:\", df[\"Y1\"].mean())\n",
381381
"print(\"Estimated outcome of biased assignment:\", est)\n",
382-
"print(\"Confidence interval for estimated outcome:\", est-2*std, est + 2*std)\n",
382+
"print(\"95% confidence interval for estimated outcome:\", est-2*std, est + 2*std)\n",
383383
"print(\"Average outcome of the actual biased assignment experiment:\", df[\"Y2\"].mean())"
384384
]
385385
},
@@ -388,7 +388,7 @@
388388
"id": "f724dbc3",
389389
"metadata": {},
390390
"source": [
391-
"As you can see, the actual outcome is well within the confidence interval estimated by ERUPT"
391+
"As you can see, the actual outcome is within the confidence interval estimated by ERUPT"
392392
]
393393
},
394394
{

0 commit comments

Comments (0)