30 changes: 3 additions & 27 deletions causaltune/optimiser.py
@@ -20,7 +20,7 @@
from joblib import Parallel, delayed

from causaltune.search.params import SimpleParamService
from causaltune.score.scoring import Scorer
from causaltune.score.scoring import Scorer, metrics_to_minimize
from causaltune.utils import treatment_is_multivalue
from causaltune.models.monkey_patches import (
AutoML,
@@ -514,19 +514,7 @@ def fit(
evaluated_rewards=(
[] if len(self.resume_scores) == 0 else self.resume_scores
),
mode=(
"min"
if self.metric
in [
"energy_distance",
"psw_energy_distance",
"frobenius_norm",
"psw_frobenius_norm",
"codec",
"policy_risk",
]
else "max"
),
mode=("min" if self.metric in metrics_to_minimize() else "max"),
low_cost_partial_config={},
**self._settings["tuner"],
)
@@ -547,19 +535,7 @@ def fit(
evaluated_rewards=(
[] if len(self.resume_scores) == 0 else self.resume_scores
),
mode=(
"min"
if self.metric
in [
"energy_distance",
"psw_energy_distance",
"frobenius_norm",
"psw_frobenius_norm",
"codec",
"policy_risk",
]
else "max"
),
mode=("min" if self.metric in metrics_to_minimize() else "max"),
low_cost_partial_config={},
**self._settings["tuner"],
)
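Both tuner invocations in fit() now derive the optimization direction from a single helper, so adding a new lower-is-better metric only requires touching metrics_to_minimize(). A minimal sketch of the new pattern (the metric value is a made-up example):

```python
from causaltune.score.scoring import metrics_to_minimize

metric = "energy_distance"  # hypothetical user-chosen metric
# Lower-is-better metrics flip the tuner to "min"; score-like metrics
# (erupt, qini, auc, ...) are maximized.
mode = "min" if metric in metrics_to_minimize() else "max"
assert mode == "min"
```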
4 changes: 2 additions & 2 deletions causaltune/score/erupt_core.py
@@ -50,8 +50,8 @@ def erupt_with_std(
]
mean += np.mean(means)
std += np.std(means) / np.sqrt(num_splits) # Standard error of the mean

return mean / resamples, std / resamples
# 1.5 is an empirical factor to make the confidence interval wider
return mean / resamples, 1.5 * std / resamples


def erupt(
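The 1.5x factor widens the reported standard error, and with it any interval built from the returned std, by 50%. A toy calculation (numbers invented) shows the effect on the usual mean ± 2·std interval:

```python
# Purely illustrative numbers, not taken from the library or notebook
mean, std = 1.25, 0.010             # point estimate and unwidened standard error
std_wide = 1.5 * std                # the empirical widening applied above
lo, hi = mean - 2 * std_wide, mean + 2 * std_wide
print(f"[{lo:.3f}, {hi:.3f}]")      # [1.220, 1.280] rather than [1.230, 1.270]
```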
21 changes: 18 additions & 3 deletions causaltune/score/scoring.py
@@ -37,7 +37,7 @@ def const_marginal_effect(self, X):
return self.cate_estimate


def supported_metrics(problem: str, multivalue: bool, scores_only: bool) -> List[str]:
def supported_metrics(problem: str, multivalue: bool, scores_only: bool, constant_ptt: bool=False) -> List[str]:
if problem == "iv":
metrics = ["energy_distance", "frobenius_norm", "codec"]
if not scores_only:
@@ -52,12 +52,12 @@ def supported_metrics(problem: str, multivalue: bool, scores_only: bool) -> List
metrics = [
"erupt",
"norm_erupt",
"greedy_erupt", # regular erupt was made probabilistic, no need for a separate one
# "greedy_erupt", # regular erupt was made probabilistic, no need for a separate one
"policy_risk", # NEW
"qini",
"auc",
# "r_scorer",
"energy_distance", # is broken without propensity weighting
"energy_distance", # should only be used in iv problems
"psw_energy_distance",
"frobenius_norm", # NEW
"codec", # NEW
@@ -68,6 +68,17 @@ def supported_metrics(problem: str, multivalue: bool, scores_only: bool) -> List
return metrics


def metrics_to_minimize():
return [
"energy_distance",
"psw_energy_distance",
"codec",
"frobenius_norm",
"psw_frobenius_norm",
"policy_risk",
]


class Scorer:
def __init__(
self,
@@ -90,6 +101,10 @@ def __init__(
self.identified_estimand = causal_model.identify_effect(
proceed_when_unidentifiable=True
)
if "Dummy" in propensity_model.__class__.__name__:
self.constant_ptt = True
else:
self.constant_ptt = False

if problem == "backdoor":
print(
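With metrics_to_minimize() exposed next to supported_metrics(), callers can enumerate the available metrics and their optimization direction in one pass; the new constant_ptt flag is inferred from whether the propensity model's class name contains "Dummy". A small sketch assuming the signatures shown above:

```python
from causaltune.score.scoring import supported_metrics, metrics_to_minimize

to_min = set(metrics_to_minimize())
for m in supported_metrics(problem="backdoor", multivalue=False, scores_only=True):
    print(m, "minimize" if m in to_min else "maximize")
```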
21 changes: 20 additions & 1 deletion causaltune/search/component.py
@@ -61,7 +61,7 @@ def joint_config(data_size: Tuple[int, int], estimator_list=None):
cfg, init_params, low_cost_init_params = flaml_config_to_tune_config(
cls.search_space(data_size=data_size, task=task)
)

cfg, init_params = tweak_config(cfg, init_params, name)
# Test if the estimator instantiates fine
try:
cls(task=task, **init_params)
@@ -76,6 +76,25 @@ def joint_config(data_size: Tuple[int, int], estimator_list=None):
return tune.choice(joint_cfg), joint_init_params, joint_low_cost_init_params


def tweak_config(cfg: dict, init_params: dict, estimator_name: str):
"""
Tweak built-in FLAML search spaces to limit the number of estimators
:param cfg:
:param estimator_name:
:return:
"""
out = copy.deepcopy(cfg)
if "xgboost" in estimator_name or estimator_name in [
"random_forest",
"extra_trees",
"lgbm",
"catboost",
]:
out["n_estimators"] = tune.lograndint(4, 1000)
init_params["n_estimators"] = 100
return out, init_params


def model_from_cfg(cfg: dict):
cfg = copy.deepcopy(cfg)
model_name = cfg.pop("estimator_name")
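tweak_config() only rewrites tree-ensemble search spaces; any other estimator_name passes through unchanged. A sketch of the intended effect, with a made-up starting space in FLAML's tune format:

```python
from flaml import tune
from causaltune.search.component import tweak_config

# Hypothetical FLAML-style search space, for illustration only
cfg = {
    "n_estimators": tune.lograndint(4, 32768),
    "learning_rate": tune.loguniform(1e-3, 1.0),
}
init_params = {"n_estimators": 4}

cfg, init_params = tweak_config(cfg, init_params, "lgbm")
# cfg["n_estimators"] is now tune.lograndint(4, 1000) and the initial
# value is 100 trees; a non-ensemble name would leave both unchanged.
```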
130 changes: 65 additions & 65 deletions notebooks/ERUPT basics.ipynb
@@ -103,45 +103,45 @@
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>0.452636</td>\n",
" <td>0</td>\n",
" <td>1.684484</td>\n",
" <td>0.898227</td>\n",
" <td>1</td>\n",
" <td>1.288637</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>0.380215</td>\n",
" <td>0.462092</td>\n",
" <td>0</td>\n",
" <td>0.745268</td>\n",
" <td>0.771976</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>0.584036</td>\n",
" <td>1</td>\n",
" <td>0.762300</td>\n",
" <td>0.858974</td>\n",
" <td>0</td>\n",
" <td>1.881019</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>0.505191</td>\n",
" <td>0</td>\n",
" <td>1.425354</td>\n",
" <td>0.228084</td>\n",
" <td>1</td>\n",
" <td>0.357797</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>0.384110</td>\n",
" <td>0.962512</td>\n",
" <td>1</td>\n",
" <td>1.834628</td>\n",
" <td>1.066413</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" X T1 Y1\n",
"0 0.452636 0 1.684484\n",
"1 0.380215 0 0.745268\n",
"2 0.584036 1 0.762300\n",
"3 0.505191 0 1.425354\n",
"4 0.384110 1 1.834628"
"0 0.898227 1 1.288637\n",
"1 0.462092 0 0.771976\n",
"2 0.858974 0 1.881019\n",
"3 0.228084 1 0.357797\n",
"4 0.962512 1 1.066413"
]
},
"execution_count": 2,
@@ -216,65 +216,65 @@
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>0.452636</td>\n",
" <td>0</td>\n",
" <td>1.684484</td>\n",
" <td>0.726318</td>\n",
" <td>0</td>\n",
" <td>0.273682</td>\n",
" <td>0.904259</td>\n",
" <td>0.898227</td>\n",
" <td>1</td>\n",
" <td>1.288637</td>\n",
" <td>0.949114</td>\n",
" <td>1</td>\n",
" <td>0.949114</td>\n",
" <td>2.229118</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>0.380215</td>\n",
" <td>0.462092</td>\n",
" <td>0</td>\n",
" <td>0.745268</td>\n",
" <td>0.690108</td>\n",
" <td>1</td>\n",
" <td>0.690108</td>\n",
" <td>1.930383</td>\n",
" <td>0.771976</td>\n",
" <td>0.731046</td>\n",
" <td>0</td>\n",
" <td>0.268954</td>\n",
" <td>0.572308</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>0.584036</td>\n",
" <td>1</td>\n",
" <td>0.762300</td>\n",
" <td>0.792018</td>\n",
" <td>0.858974</td>\n",
" <td>0</td>\n",
" <td>1.881019</td>\n",
" <td>0.929487</td>\n",
" <td>1</td>\n",
" <td>0.792018</td>\n",
" <td>0.959608</td>\n",
" <td>0.929487</td>\n",
" <td>2.601592</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>0.505191</td>\n",
" <td>0</td>\n",
" <td>1.425354</td>\n",
" <td>0.752596</td>\n",
" <td>0.228084</td>\n",
" <td>1</td>\n",
" <td>0.357797</td>\n",
" <td>0.614042</td>\n",
" <td>1</td>\n",
" <td>0.752596</td>\n",
" <td>1.017777</td>\n",
" <td>0.614042</td>\n",
" <td>0.542638</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>0.384110</td>\n",
" <td>0.962512</td>\n",
" <td>1</td>\n",
" <td>1.834628</td>\n",
" <td>0.692055</td>\n",
" <td>1.066413</td>\n",
" <td>0.981256</td>\n",
" <td>1</td>\n",
" <td>0.692055</td>\n",
" <td>2.374030</td>\n",
" <td>0.981256</td>\n",
" <td>2.401383</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" X T1 Y1 p T2 p_of_actual Y2\n",
"0 0.452636 0 1.684484 0.726318 0 0.273682 0.904259\n",
"1 0.380215 0 0.745268 0.690108 1 0.690108 1.930383\n",
"2 0.584036 1 0.762300 0.792018 1 0.792018 0.959608\n",
"3 0.505191 0 1.425354 0.752596 1 0.752596 1.017777\n",
"4 0.384110 1 1.834628 0.692055 1 0.692055 2.374030"
"0 0.898227 1 1.288637 0.949114 1 0.949114 2.229118\n",
"1 0.462092 0 0.771976 0.731046 0 0.268954 0.572308\n",
"2 0.858974 0 1.881019 0.929487 1 0.929487 2.601592\n",
"3 0.228084 1 0.357797 0.614042 1 0.614042 0.542638\n",
"4 0.962512 1 1.066413 0.981256 1 0.981256 2.401383"
]
},
"execution_count": 3,
@@ -319,10 +319,10 @@
"name": "stdout",
"output_type": "stream",
"text": [
"Average outcome of the actual biased assignment experiment: 1.411675477573636\n",
"Estimated outcome of random assignment: 1.251567372523789\n",
"95% confidence interval for estimated outcome: 1.2311928820519622 1.2719418629956158\n",
"Average outcome of the actual random assignment experiment: 1.2559621877416332\n"
"Average outcome of the actual biased assignment experiment: 1.4064676444383317\n",
"Estimated outcome of random assignment: 1.2594221770638483\n",
"95% confidence interval for estimated outcome: 1.230204391668238 1.2886399624594587\n",
"Average outcome of the actual random assignment experiment: 1.2461659092712785\n"
]
}
],
@@ -360,17 +360,17 @@
"name": "stdout",
"output_type": "stream",
"text": [
"Average outcome of the actual random assignment experiment: 1.2559621877416332\n",
"Estimated outcome of biased assignment: 1.4147647990746988\n",
"Confidence interval for estimated outcome: 1.398423601541284 1.4311059966081134\n",
"Average outcome of the actual biased assignment experiment: 1.411675477573636\n"
"Average outcome of the actual random assignment experiment: 1.2461659092712785\n",
"Estimated outcome of biased assignment: 1.405112521603215\n",
"95% confidence interval for estimated outcome: 1.3814865905561569 1.428738452650273\n",
"Average outcome of the actual biased assignment experiment: 1.4064676444383317\n"
]
}
],
"source": [
"# Conversely, we can take the outcome of the fully random test and use it to estimate what the outcome of the biased assignment would have been\n",
"# Conversely, we can take the outcome of the fully random test and use it \n",
"# to estimate what the outcome of the biased assignment would have been\n",
"\n",
"# Let's use data from biased assignment experiment to estimate the average effect of fully random assignment\n",
"hypothetical_policy = df[\"T2\"]\n",
"est, std = erupt_with_std(actual_propensity=0.5*pd.Series(np.ones(len(df))), \n",
" actual_treatment=df[\"T1\"],\n",
Expand All @@ -379,7 +379,7 @@
"\n",
"print(\"Average outcome of the actual random assignment experiment:\", df[\"Y1\"].mean())\n",
"print(\"Estimated outcome of biased assignment:\", est)\n",
"print(\"Confidence interval for estimated outcome:\", est-2*std, est + 2*std)\n",
"print(\"95% confidence interval for estimated outcome:\", est-2*std, est + 2*std)\n",
"print(\"Average outcome of the actual biased assignment experiment:\", df[\"Y2\"].mean())"
]
},
@@ -388,7 +388,7 @@
"id": "f724dbc3",
"metadata": {},
"source": [
"As you can see, the actual outcome is well within the confidence interval estimated by ERUPT"
"As you can see, the actual outcome is within the confidence interval estimated by ERUPT"
]
},
{
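For readers skimming the notebook diff: ERUPT estimates the mean outcome a hypothetical assignment policy would have produced by keeping the units whose actual treatment matches the policy and reweighting them by the inverse propensity of the treatment they received. A self-contained sketch of that idea, using one common (Hajek) normalization; this is not the library implementation, which also resamples to attach a standard error:

```python
import numpy as np
import pandas as pd

def erupt_sketch(propensity: pd.Series, treatment: pd.Series,
                 outcome: pd.Series, policy: pd.Series) -> float:
    """Inverse-propensity-weighted mean outcome over units whose actual
    treatment agrees with the hypothetical policy."""
    match = (treatment == policy).astype(float)
    w = match / propensity  # propensity of the *received* treatment
    return float((w * outcome).sum() / w.sum())

# With the notebook's dataframe this would look like:
# est = erupt_sketch(0.5 * pd.Series(np.ones(len(df))), df["T1"], df["Y1"], df["T2"])
```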