|
282 | 282 | "print(f\"Mean outcome (untreated): {np.mean(Y[W == 0]):.6f}\")\n", |
283 | 283 | "print(f\"Overall treatment effect: {np.mean(Y[W == 1]) - np.mean(Y[W == 0]):.6f}\")" |
284 | 284 | ] |
| 285 | + }, |
| 286 | + { |
| 287 | + "cell_type": "code", |
| 288 | + "execution_count": null, |
| 289 | + "metadata": {}, |
| 290 | + "outputs": [], |
| 291 | + "source": [ |
| 292 | + "# Generate observational data\n", |
| 293 | + "np.random.seed(123)\n", |
| 294 | + "n = 1000\n", |
| 295 | + "p = 4\n", |
| 296 | + "X = np.random.uniform(0, 1, (n, p))\n", |
| 297 | + "e = 1 / (1 + np.exp(-2*(X[:, 0] - 0.5) - 2*(X[:, 1] - 0.5))) # not observed by analyst\n", |
| 298 | + "W = np.random.binomial(1, e, n)\n", |
| 299 | + "Y = 0.5 * (X[:, 0] - 0.5) + (X[:, 1] - 0.5) * W + 0.1 * np.random.randn(n)" |
| 300 | + ] |
| 301 | + }, |
| 302 | + { |
| 303 | + "cell_type": "code", |
| 304 | + "execution_count": null, |
| 305 | + "metadata": {}, |
| 306 | + "outputs": [], |
| 307 | + "source": [ |
| 308 | + "y_norm = (Y - Y.min()) / (Y.max() - Y.min())\n", |
| 309 | + "\n", |
| 310 | + "# Plot by treatment status\n", |
| 311 | + "fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))\n", |
| 312 | + "\n", |
| 313 | + "# Untreated\n", |
| 314 | + "untreated_idx = W == 0\n", |
| 315 | + "for i in np.where(untreated_idx)[0]:\n", |
| 316 | + " ax1.scatter(X[i, 0], X[i, 1], marker='D', s=80, \n", |
| 317 | + " c=[y_norm[i]], cmap='gray', vmin=0, vmax=1,\n", |
| 318 | + " edgecolors='black', linewidths=1)\n", |
| 319 | + "ax1.set_xlabel('X1')\n", |
| 320 | + "ax1.set_ylabel('X2')\n", |
| 321 | + "ax1.set_title('Untreated')\n", |
| 322 | + "\n", |
| 323 | + "# Treated\n", |
| 324 | + "treated_idx = W == 1\n", |
| 325 | + "for i in np.where(treated_idx)[0]:\n", |
| 326 | + " ax2.scatter(X[i, 0], X[i, 1], marker='o', s=100, \n", |
| 327 | + " c=[y_norm[i]], cmap='gray', vmin=0, vmax=1,\n", |
| 328 | + " edgecolors='black', linewidths=1)\n", |
| 329 | + "ax2.set_xlabel('X1')\n", |
| 330 | + "ax2.set_ylabel('X2')\n", |
| 331 | + "ax2.set_title('Treated')\n", |
| 332 | + "plt.show()" |
| 333 | + ] |
| 334 | + }, |
| 335 | + { |
| 336 | + "cell_type": "code", |
| 337 | + "execution_count": null, |
| 338 | + "metadata": {}, |
| 339 | + "outputs": [], |
| 340 | + "source": [ |
| 341 | + "from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier\n", |
| 342 | + "from sklearn.model_selection import KFold" |
| 343 | + ] |
| 344 | + }, |
| 345 | + { |
| 346 | + "cell_type": "code", |
| 347 | + "execution_count": null, |
| 348 | + "metadata": {}, |
| 349 | + "outputs": [], |
| 350 | + "source": [ |
| 351 | + "class CausalForest:\n", |
| 352 | + " \"\"\"\n", |
| 353 | + " Simplified stand-in for grf's causal forest: a T-learner built from scikit-learn random forests (the honest flag is stored but not used)\n", |
| 354 | + " \"\"\"\n", |
| 355 | + " def __init__(self, n_estimators=2000, max_features='sqrt', min_samples_leaf=5, \n", |
| 356 | + " honest=True, W_hat=None):\n", |
| 357 | + " self.n_estimators = n_estimators\n", |
| 358 | + " self.max_features = max_features\n", |
| 359 | + " self.min_samples_leaf = min_samples_leaf\n", |
| 360 | + " self.honest = honest\n", |
| 361 | + " self.W_hat_fixed = W_hat\n", |
| 362 | + " \n", |
| 363 | + " def fit(self, X, Y, W):\n", |
| 364 | + " self.X = X\n", |
| 365 | + " self.Y = Y\n", |
| 366 | + " self.W = W\n", |
| 367 | + " n = len(Y)\n", |
| 368 | + " \n", |
| 369 | + " # If a fixed propensity W_hat was supplied (e.g. a randomized experiment), use that value for every unit\n", |
| 370 | + " if self.W_hat_fixed is not None:\n", |
| 371 | + " self.W_hat = np.full(n, self.W_hat_fixed)\n", |
| 372 | + " else:\n", |
| 373 | + " # Estimate propensity score\n", |
| 374 | + " ps_model = RandomForestClassifier(\n", |
| 375 | + " n_estimators=self.n_estimators//2,\n", |
| 376 | + " max_features=self.max_features,\n", |
| 377 | + " min_samples_leaf=self.min_samples_leaf,\n", |
| 378 | + " random_state=42\n", |
| 379 | + " )\n", |
| 380 | + " ps_model.fit(X, W)\n", |
| 381 | + " self.W_hat = ps_model.predict_proba(X)[:, 1]\n", |
| 382 | + " # Clip to avoid division issues\n", |
| 383 | + " self.W_hat = np.clip(self.W_hat, 0.01, 0.99)\n", |
| 384 | + " \n", |
| 385 | + " # Estimate outcome model\n", |
| 386 | + " outcome_model = RandomForestRegressor(\n", |
| 387 | + " n_estimators=self.n_estimators//2,\n", |
| 388 | + " max_features=self.max_features,\n", |
| 389 | + " min_samples_leaf=self.min_samples_leaf,\n", |
| 390 | + " random_state=42\n", |
| 391 | + " )\n", |
| 392 | + " outcome_model.fit(X, Y)\n", |
| 393 | + " self.Y_hat = outcome_model.predict(X)\n", |
| 394 | + " \n", |
| 395 | + " # Estimate treatment effects using T-learner\n", |
| 396 | + " # Model for treated\n", |
| 397 | + " model_1 = RandomForestRegressor(\n", |
| 398 | + " n_estimators=self.n_estimators//2,\n", |
| 399 | + " max_features=self.max_features,\n", |
| 400 | + " min_samples_leaf=self.min_samples_leaf,\n", |
| 401 | + " random_state=42\n", |
| 402 | + " )\n", |
| 403 | + " if np.sum(W == 1) > 0:\n", |
| 404 | + " model_1.fit(X[W == 1], Y[W == 1])\n", |
| 405 | + " self.mu_1 = model_1.predict(X)\n", |
| 406 | + " else:\n", |
| 407 | + " self.mu_1 = np.zeros(n)\n", |
| 408 | + " \n", |
| 409 | + " # Model for control\n", |
| 410 | + " model_0 = RandomForestRegressor(\n", |
| 411 | + " n_estimators=self.n_estimators//2,\n", |
| 412 | + " max_features=self.max_features,\n", |
| 413 | + " min_samples_leaf=self.min_samples_leaf,\n", |
| 414 | + " random_state=42\n", |
| 415 | + " )\n", |
| 416 | + " if np.sum(W == 0) > 0:\n", |
| 417 | + " model_0.fit(X[W == 0], Y[W == 0])\n", |
| 418 | + " self.mu_0 = model_0.predict(X)\n", |
| 419 | + " else:\n", |
| 420 | + " self.mu_0 = np.zeros(n)\n", |
| 421 | + " \n", |
| 422 | + " # Treatment effect\n", |
| 423 | + " self.tau_hat = self.mu_1 - self.mu_0\n", |
| 424 | + " \n", |
| 425 | + " return self\n", |
| 426 | + " \n", |
| 427 | + " def predict(self):\n", |
| 428 | + " return {'predictions': self.tau_hat}" |
| 429 | + ] |
| 430 | + }, |
| 431 | + { |
| 432 | + "cell_type": "code", |
| 433 | + "execution_count": null, |
| 434 | + "metadata": {}, |
| 435 | + "outputs": [], |
| 436 | + "source": [ |
| 437 | + "# Fit causal forest\n", |
| 438 | + "print(\"\\nFitting causal forest...\")\n", |
| 439 | + "forest = CausalForest()\n", |
| 440 | + "forest.fit(X, Y, W)\n", |
| 441 | + "\n", |
| 442 | + "# Get predictions\n", |
| 443 | + "tau_hat = forest.predict()['predictions']\n", |
| 444 | + "\n", |
| 445 | + "# Recover conditional mean outcomes under treatment (mu_hat_1) and control (mu_hat_0) from Y_hat, W_hat and tau_hat\n", |
| 446 | + "mu_hat_1 = forest.Y_hat + (1 - forest.W_hat) * tau_hat\n", |
| 447 | + "mu_hat_0 = forest.Y_hat - forest.W_hat * tau_hat\n", |
| 448 | + "\n", |
| 449 | + "# Compute AIPW scores\n", |
| 450 | + "gamma_hat_1 = mu_hat_1 + W/forest.W_hat * (Y - mu_hat_1)\n", |
| 451 | + "gamma_hat_0 = mu_hat_0 + (1-W)/(1-forest.W_hat) * (Y - mu_hat_0)\n", |
| 452 | + "\n", |
| 453 | + "print(\"Causal forest fitted successfully.\")" |
| 454 | + ] |
| 455 | + }, |
| 456 | + { |
| 457 | + "cell_type": "code", |
| 458 | + "execution_count": null, |
| 459 | + "metadata": {}, |
| 460 | + "outputs": [], |
| 461 | + "source": [ |
| 462 | + "# POLICY EVALUATION WITH AIPW\n", |
| 463 | + "print(\"\\n--- Policy A (X1 > 0.5 & X2 > 0.5) with AIPW ---\")\n", |
| 464 | + "pi = (X[:, 0] > 0.5) & (X[:, 1] > 0.5)\n", |
| 465 | + "gamma_hat_pi = pi * gamma_hat_1 + (1 - pi) * gamma_hat_0\n", |
| 466 | + "value_estimate = np.mean(gamma_hat_pi)\n", |
| 467 | + "value_stderr = np.std(gamma_hat_pi) / np.sqrt(len(gamma_hat_pi))\n", |
| 468 | + "print(f\"Value estimate: {value_estimate:.10f} Std. Error: {value_stderr:.10f}\")\n", |
| 469 | + "\n", |
| 470 | + "print(\"\\n--- Random Policy (p=0.75) with AIPW ---\")\n", |
| 471 | + "pi_random = 0.75\n", |
| 472 | + "gamma_hat_pi = pi_random * gamma_hat_1 + (1 - pi_random) * gamma_hat_0\n", |
| 473 | + "value_estimate = np.mean(gamma_hat_pi)\n", |
| 474 | + "value_stderr = np.std(gamma_hat_pi) / np.sqrt(len(gamma_hat_pi))\n", |
| 475 | + "print(f\"Value estimate: {value_estimate:.10f} Std. Error: {value_stderr:.10f}\")\n", |
| 476 | + "print(\"\\n--- Difference: Policy A vs Never Treat ---\")" |
| 477 | + ] |
| 478 | + }, |
| 479 | + { |
| 480 | + "cell_type": "code", |
| 481 | + "execution_count": null, |
| 482 | + "metadata": {}, |
| 483 | + "outputs": [], |
| 484 | + "source": [ |
| 485 | + "# AIPW scores for Policy A\n", |
| 486 | + "pi = (X[:, 0] > 0.5) & (X[:, 1] > 0.5)\n", |
| 487 | + "gamma_hat_pi = pi * gamma_hat_1 + (1 - pi) * gamma_hat_0\n", |
| 488 | + "\n", |
| 489 | + "# AIPW scores for Never Treat\n", |
| 490 | + "pi_never = 0\n", |
| 491 | + "gamma_hat_pi_never = pi_never * gamma_hat_1 + (1 - pi_never) * gamma_hat_0\n", |
| 492 | + "\n", |
| 493 | + "# Difference\n", |
| 494 | + "diff_scores = gamma_hat_pi - gamma_hat_pi_never\n", |
| 495 | + "diff_estimate = np.mean(diff_scores)\n", |
| 496 | + "diff_stderr = np.std(diff_scores) / np.sqrt(len(diff_scores))\n", |
| 497 | + "print(f\"diff estimate: {diff_estimate:.10f} Std. Error: {diff_stderr:.10f}\")\n", |
| 498 | + "\n", |
| 499 | + "print(\"\\n\" + \"=\" * 70) \n", |
| 500 | + "print(\"ANALYSIS COMPLETE\")\n", |
| 501 | + "print(\"=\" * 70)" |
| 502 | + ] |
285 | 503 | } |
286 | 504 | ], |
287 | 505 | "metadata": { |
|
0 commit comments