Skip to content

Commit 8f7fb9a

Browse files
committed
policy aipw v3
1 parent 1a6025b commit 8f7fb9a

File tree

1 file changed

+218
-0
lines changed

1 file changed

+218
-0
lines changed

book/cate_and_policy/policy_learning.ipynb

Lines changed: 218 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -282,6 +282,224 @@
282282
"print(f\"Mean outcome (untreated): {np.mean(Y[W == 0]):.6f}\")\n",
283283
"print(f\"Overall treatment effect: {np.mean(Y[W == 1]) - np.mean(Y[W == 0]):.6f}\")"
284284
]
285+
},
286+
{
287+
"cell_type": "code",
288+
"execution_count": null,
289+
"metadata": {},
290+
"outputs": [],
291+
"source": [
292+
"# Generate observational data\n",
293+
"np.random.seed(123)\n",
294+
"n = 1000\n",
295+
"p = 4\n",
296+
"X = np.random.uniform(0, 1, (n, p))\n",
297+
"e = 1 / (1 + np.exp(-2*(X[:, 0] - 0.5) - 2*(X[:, 1] - 0.5))) # not observed by analyst\n",
298+
"W = np.random.binomial(1, e, n)\n",
299+
"Y = 0.5 * (X[:, 0] - 0.5) + (X[:, 1] - 0.5) * W + 0.1 * np.random.randn(n)"
300+
]
301+
},
302+
{
303+
"cell_type": "code",
304+
"execution_count": null,
305+
"metadata": {},
306+
"outputs": [],
307+
"source": [
308+
"y_norm = (Y - Y.min()) / (Y.max() - Y.min())\n",
309+
"\n",
310+
"# Plot by treatment status\n",
311+
"fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))\n",
312+
"\n",
313+
"# Untreated\n",
314+
"untreated_idx = W == 0\n",
315+
"for i in np.where(untreated_idx)[0]:\n",
316+
" ax1.scatter(X[i, 0], X[i, 1], marker='D', s=80, \n",
317+
" c=[y_norm[i]], cmap='gray', vmin=0, vmax=1,\n",
318+
" edgecolors='black', linewidths=1)\n",
319+
"ax1.set_xlabel('X1')\n",
320+
"ax1.set_ylabel('X2')\n",
321+
"ax1.set_title('Untreated')\n",
322+
"\n",
323+
"# Treated\n",
324+
"treated_idx = W == 1\n",
325+
"for i in np.where(treated_idx)[0]:\n",
326+
" ax2.scatter(X[i, 0], X[i, 1], marker='o', s=100, \n",
327+
" c=[y_norm[i]], cmap='gray', vmin=0, vmax=1,\n",
328+
" edgecolors='black', linewidths=1)\n",
329+
"ax2.set_xlabel('X1')\n",
330+
"ax2.set_ylabel('X2')\n",
331+
"ax2.set_title('Treated')\n",
332+
"plt.show()"
333+
]
334+
},
335+
{
336+
"cell_type": "code",
337+
"execution_count": null,
338+
"metadata": {},
339+
"outputs": [],
340+
"source": [
341+
"from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier\n",
342+
"from sklearn.model_selection import KFold"
343+
]
344+
},
345+
{
346+
"cell_type": "code",
347+
"execution_count": null,
348+
"metadata": {},
349+
"outputs": [],
350+
"source": [
351+
"class CausalForest:\n",
352+
" \"\"\"\n",
353+
" Simplified Causal Forest implementation to match grf package behavior\n",
354+
" \"\"\"\n",
355+
" def __init__(self, n_estimators=2000, max_features='sqrt', min_samples_leaf=5, \n",
356+
" honest=True, W_hat=None):\n",
357+
" self.n_estimators = n_estimators\n",
358+
" self.max_features = max_features\n",
359+
" self.min_samples_leaf = min_samples_leaf\n",
360+
" self.honest = honest\n",
361+
" self.W_hat_fixed = W_hat\n",
362+
" \n",
363+
" def fit(self, X, Y, W):\n",
364+
" self.X = X\n",
365+
" self.Y = Y\n",
366+
" self.W = W\n",
367+
" n = len(Y)\n",
368+
" \n",
369+
" # If W.hat is provided (randomized setting), use it\n",
370+
" if self.W_hat_fixed is not None:\n",
371+
" self.W_hat = np.full(n, self.W_hat_fixed)\n",
372+
" else:\n",
373+
" # Estimate propensity score\n",
374+
" ps_model = RandomForestClassifier(\n",
375+
" n_estimators=self.n_estimators//2,\n",
376+
" max_features=self.max_features,\n",
377+
" min_samples_leaf=self.min_samples_leaf,\n",
378+
" random_state=42\n",
379+
" )\n",
380+
" ps_model.fit(X, W)\n",
381+
" self.W_hat = ps_model.predict_proba(X)[:, 1]\n",
382+
" # Clip to avoid division issues\n",
383+
" self.W_hat = np.clip(self.W_hat, 0.01, 0.99)\n",
384+
" \n",
385+
" # Estimate outcome model\n",
386+
" outcome_model = RandomForestRegressor(\n",
387+
" n_estimators=self.n_estimators//2,\n",
388+
" max_features=self.max_features,\n",
389+
" min_samples_leaf=self.min_samples_leaf,\n",
390+
" random_state=42\n",
391+
" )\n",
392+
" outcome_model.fit(X, Y)\n",
393+
" self.Y_hat = outcome_model.predict(X)\n",
394+
" \n",
395+
" # Estimate treatment effects using T-learner\n",
396+
" # Model for treated\n",
397+
" model_1 = RandomForestRegressor(\n",
398+
" n_estimators=self.n_estimators//2,\n",
399+
" max_features=self.max_features,\n",
400+
" min_samples_leaf=self.min_samples_leaf,\n",
401+
" random_state=42\n",
402+
" )\n",
403+
" if np.sum(W == 1) > 0:\n",
404+
" model_1.fit(X[W == 1], Y[W == 1])\n",
405+
" self.mu_1 = model_1.predict(X)\n",
406+
" else:\n",
407+
" self.mu_1 = np.zeros(n)\n",
408+
" \n",
409+
" # Model for control\n",
410+
" model_0 = RandomForestRegressor(\n",
411+
" n_estimators=self.n_estimators//2,\n",
412+
" max_features=self.max_features,\n",
413+
" min_samples_leaf=self.min_samples_leaf,\n",
414+
" random_state=42\n",
415+
" )\n",
416+
" if np.sum(W == 0) > 0:\n",
417+
" model_0.fit(X[W == 0], Y[W == 0])\n",
418+
" self.mu_0 = model_0.predict(X)\n",
419+
" else:\n",
420+
" self.mu_0 = np.zeros(n)\n",
421+
" \n",
422+
" # Treatment effect\n",
423+
" self.tau_hat = self.mu_1 - self.mu_0\n",
424+
" \n",
425+
" return self\n",
426+
" \n",
427+
" def predict(self):\n",
428+
" return {'predictions': self.tau_hat}"
429+
]
430+
},
431+
{
432+
"cell_type": "code",
433+
"execution_count": null,
434+
"metadata": {},
435+
"outputs": [],
436+
"source": [
437+
"# Fit causal forest\n",
438+
"print(\"\\nFitting causal forest...\")\n",
439+
"forest = CausalForest()\n",
440+
"forest.fit(X, Y, W)\n",
441+
"\n",
442+
"# Get predictions\n",
443+
"tau_hat = forest.predict()['predictions']\n",
444+
"\n",
445+
"# Estimate outcome models\n",
446+
"mu_hat_1 = forest.Y_hat + (1 - forest.W_hat) * tau_hat\n",
447+
"mu_hat_0 = forest.Y_hat - forest.W_hat * tau_hat\n",
448+
"\n",
449+
"# Compute AIPW scores\n",
450+
"gamma_hat_1 = mu_hat_1 + W/forest.W_hat * (Y - mu_hat_1)\n",
451+
"gamma_hat_0 = mu_hat_0 + (1-W)/(1-forest.W_hat) * (Y - mu_hat_0)\n",
452+
"\n",
453+
"print(\"Causal forest fitted successfully.\")"
454+
]
455+
},
456+
{
457+
"cell_type": "code",
458+
"execution_count": null,
459+
"metadata": {},
460+
"outputs": [],
461+
"source": [
462+
"# POLICY EVALUATION WITH AIPW\n",
463+
"print(\"\\n--- Policy A (X1 > 0.5 & X2 > 0.5) with AIPW ---\")\n",
464+
"pi = (X[:, 0] > 0.5) & (X[:, 1] > 0.5)\n",
465+
"gamma_hat_pi = pi * gamma_hat_1 + (1 - pi) * gamma_hat_0\n",
466+
"value_estimate = np.mean(gamma_hat_pi)\n",
467+
"value_stderr = np.std(gamma_hat_pi) / np.sqrt(len(gamma_hat_pi))\n",
468+
"print(f\"Value estimate: {value_estimate:.10f} Std. Error: {value_stderr:.10f}\")\n",
469+
"\n",
470+
"print(\"\\n--- Random Policy (p=0.75) with AIPW ---\")\n",
471+
"pi_random = 0.75\n",
472+
"gamma_hat_pi = pi_random * gamma_hat_1 + (1 - pi_random) * gamma_hat_0\n",
473+
"value_estimate = np.mean(gamma_hat_pi)\n",
474+
"value_stderr = np.std(gamma_hat_pi) / np.sqrt(len(gamma_hat_pi))\n",
475+
"print(f\"Value estimate: {value_estimate:.10f} Std. Error: {value_stderr:.10f}\")\n",
476+
"print(\"\\n--- Difference: Policy A vs Never Treat ---\")"
477+
]
478+
},
479+
{
480+
"cell_type": "code",
481+
"execution_count": null,
482+
"metadata": {},
483+
"outputs": [],
484+
"source": [
485+
"# AIPW scores for Policy A\n",
486+
"pi = (X[:, 0] > 0.5) & (X[:, 1] > 0.5)\n",
487+
"gamma_hat_pi = pi * gamma_hat_1 + (1 - pi) * gamma_hat_0\n",
488+
"\n",
489+
"# AIPW scores for Never Treat\n",
490+
"pi_never = 0\n",
491+
"gamma_hat_pi_never = pi_never * gamma_hat_1 + (1 - pi_never) * gamma_hat_0\n",
492+
"\n",
493+
"# Difference\n",
494+
"diff_scores = gamma_hat_pi - gamma_hat_pi_never\n",
495+
"diff_estimate = np.mean(diff_scores)\n",
496+
"diff_stderr = np.std(diff_scores) / np.sqrt(len(diff_scores))\n",
497+
"print(f\"diff estimate: {diff_estimate:.10f} Std. Error: {diff_stderr:.10f}\")\n",
498+
"\n",
499+
"print(\"\\n\" + \"=\" * 70) \n",
500+
"print(\"ANALYSIS COMPLETE\")\n",
501+
"print(\"=\" * 70)"
502+
]
285503
}
286504
],
287505
"metadata": {

0 commit comments

Comments
 (0)