Commit c4fa2ef

case study v3

1 parent 17ec9ed commit c4fa2ef

File tree

1 file changed: +384 -0 lines changed

book/cate_and_policy/policy_learning.ipynb

Lines changed: 384 additions & 0 deletions

@@ -965,6 +965,390 @@
"print(\"ANALYSIS COMPLETE\")\n",
"print(\"=\" * 70)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\"\"\"\n",
"================================================================================\n",
"R vs Python implementation differences\n",
"================================================================================\n",
"\n",
"1. AIPW policy value (Value Estimate)\n",
"- R code (V2): 0.3459015997 (about 34.6%) -> R's estimate comes out closer to the simple mean\n",
"- Python code (V2): 0.3376925438 (about 33.8%)\n",
"\n",
"2. AIPW policy comparison (Difference Estimate)\n",
"- R code (V2): 0.0806035592 (about 8.1 %p)\n",
"- Python code (V2): 0.0721973740 (about 7.2 %p)\n",
"\n",
"Why the estimates differ:\n",
"----------------------------\n",
"1. Algorithmic differences\n",
"   - R (grf): honest splitting, debiasing, and a tree algorithm built specifically for CATE estimation\n",
"   - Python (scikit-learn): generic RandomForest with a T-learner, no debiasing\n",
"\n",
"2. Package scope\n",
"   - grf (R): a package dedicated to causal inference\n",
"   - scikit-learn / econml (Python): general-purpose ML tooling with different internals\n",
"\n",
"================================================================================\n",
"\"\"\""
]
},
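{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# A minimal, self-contained sketch (an addition to the notebook, not part of the\n",
"# original analysis): econml's CausalForestDML is a closer Python analogue to\n",
"# grf's causal_forest than the T-learner used below, since it supports honest\n",
"# splitting and orthogonalization. It assumes econml is installed and runs on\n",
"# synthetic data rather than data_framing.csv.\n",
"try:\n",
"    import numpy as np\n",
"    from econml.dml import CausalForestDML\n",
"\n",
"    rng = np.random.default_rng(42)\n",
"    X_demo = rng.normal(size=(500, 5))                     # synthetic covariates\n",
"    T_demo = rng.integers(0, 2, size=500)                  # synthetic binary treatment\n",
"    Y_demo = X_demo[:, 0] * T_demo + rng.normal(size=500)  # synthetic outcome\n",
"\n",
"    cf = CausalForestDML(discrete_treatment=True, n_estimators=500,\n",
"                         min_samples_leaf=5, random_state=42)\n",
"    cf.fit(Y_demo, T_demo, X=X_demo)\n",
"    tau_demo = cf.effect(X_demo)  # CATE estimates, analogous to grf predictions\n",
"    print(f\"econml CATE mean on synthetic data: {tau_demo.mean():.4f}\")\n",
"except ImportError:\n",
"    print(\"econml not installed; skipping the CausalForestDML sketch.\")"
]
},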
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"import warnings\n",
"warnings.filterwarnings('ignore')\n",
"\n",
"# Set random seed for reproducibility\n",
"np.random.seed(42)\n",
"\n",
"print(\"=\" * 70)\n",
"print(\"FRAMING RCT POLICY EVALUATION\")\n",
"print(\"=\" * 70)\n",
"print()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# ==============================================================================\n",
"# STEP 1: LOAD AND PREPARE DATA\n",
"# ==============================================================================\n",
"\n",
"print(\"Loading data...\")\n",
"# Read in data - file path\n",
"data = pd.read_csv(\"C:/Pythwd/data_framing.csv\")  # actual file name\n",
"n = len(data)\n",
"\n",
"# Variable names\n",
"treatment = 'group'  # actual treatment column name\n",
"outcome = 'wta'  # actual outcome column name\n",
"\n",
"# Covariate list\n",
"covariates = ['gender', 'age', 'income', 'eco', 'norm', 'edu', 'family']  # actual column names\n",
"\n",
"print(f\"Data loaded: {n} observations\")\n",
"print()"
]
},
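{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional sanity check (an addition, not part of the original script): fail early\n",
"# with a clear message if data_framing.csv lacks any of the expected columns.\n",
"missing = [c for c in [treatment, outcome] + covariates if c not in data.columns]\n",
"if missing:\n",
"    raise KeyError(f\"Missing columns in data_framing.csv: {missing}\")\n",
"print(\"All expected columns present.\")"
]
},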
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# ==============================================================================\n",
"# STEP 2: SIMPLE MEAN-BASED ESTIMATION (Only valid in randomized setting)\n",
"# ==============================================================================\n",
"\n",
"print(\"=\" * 70)\n",
"print(\"SIMPLE MEAN-BASED ESTIMATION (RCT only)\")\n",
"print(\"=\" * 70)\n",
"\n",
"# Extract variables\n",
"X = data[covariates]\n",
"Y = data[outcome].values\n",
"W = data[treatment].values\n",
"\n",
"# Policy definition (who receives Loss Framing):\n",
"# apply Loss Framing to respondents aged 40 or older AND with family size 3 or more\n",
"pi = (data['age'] >= 40) & (data['family'] >= 3)\n",
"A = pi.values == 1\n",
"\n",
"# Calculate value estimate\n",
"value_estimate = np.mean(Y[A & (W==1)]) * np.mean(A) + \\\n",
"                 np.mean(Y[~A & (W==0)]) * np.mean(~A)\n",
"\n",
"# Calculate standard error\n",
"value_stderr = np.sqrt(\n",
"    np.var(Y[A & (W==1)]) / np.sum(A & (W==1)) * np.mean(A)**2 +\n",
"    np.var(Y[~A & (W==0)]) / np.sum(~A & (W==0)) * np.mean(~A)**2\n",
")\n",
"\n",
"print(f\"Value estimate: {value_estimate:.10f} Std. Error: {value_stderr:.10f}\")\n",
"print()"
]
},
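{
"cell_type": "markdown",
"metadata": {},
"source": [
"For reference, the estimator above written out (a restatement of what the code computes, not a new method): with targeted set $A = \\{i : \\pi(X_i) = 1\\}$ and randomized assignment $W_i$, the mean-based policy value is\n",
"\n",
"$$\\hat{V}(\\pi) = \\bar{Y}_{A,\\,W=1} \\cdot \\hat{P}(A) + \\bar{Y}_{A^c,\\,W=0} \\cdot \\hat{P}(A^c),$$\n",
"\n",
"where $\\bar{Y}_{A,\\,W=1}$ is the mean outcome among units the policy would treat that were actually treated, and $\\bar{Y}_{A^c,\\,W=0}$ is the analogue for untreated units. Because randomization makes $W$ independent of the potential outcomes, each sample mean is unbiased for its counterfactual mean; in observational data this estimator would be confounded, which is why the AIPW approach below is used for the main results.\n"
]
},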
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# ==============================================================================\n",
"# STEP 3: CAUSAL FOREST WITH AIPW\n",
"# ==============================================================================\n",
"\n",
"print(\"=\" * 70)\n",
"print(\"CAUSAL FOREST WITH AIPW\")\n",
"print(\"=\" * 70)\n",
"\n",
"# Create model matrix (design matrix with intercept)\n",
"X_design = pd.get_dummies(data[covariates], drop_first=False)\n",
"# Add intercept\n",
"X_design.insert(0, 'intercept', 1)\n",
"X_design = X_design.values\n",
"\n",
"Y = data[outcome].values\n",
"W = data[treatment].values\n",
"\n",
"# Try to use sklearn if available\n",
"try:\n",
"    from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier\n",
"    use_sklearn = True\n",
"    print(\"Using scikit-learn for Random Forest\")\n",
"except ImportError:\n",
"    use_sklearn = False\n",
"    print(\"scikit-learn not found. Using simplified implementation.\")\n",
"    print(\"For results closer to R's, install scikit-learn: pip install scikit-learn\")\n",
"\n",
"# Causal Forest Implementation\n",
"if use_sklearn:\n",
"    class CausalForest:\n",
"        \"\"\"T-learner random-forest approximation of grf's causal forest (no honest splitting or debiasing)\"\"\"\n",
"\n",
"        def __init__(self, n_estimators=2000, max_features=None, min_samples_leaf=5,\n",
"                     W_hat=None, honest=True):\n",
"            self.n_estimators = n_estimators\n",
"            self.max_features = max_features if max_features else 'sqrt'\n",
"            self.min_samples_leaf = min_samples_leaf\n",
"            self.W_hat_fixed = W_hat\n",
"            self.honest = honest\n",
"\n",
"        def fit(self, X, Y, W):\n",
"            n = len(Y)\n",
"\n",
"            # If W.hat is provided (randomized setting), use it\n",
"            if self.W_hat_fixed is not None:\n",
"                self.W_hat = np.full(n, self.W_hat_fixed)\n",
"            else:\n",
"                # Estimate propensity score\n",
"                ps_model = RandomForestClassifier(\n",
"                    n_estimators=500,\n",
"                    max_features=self.max_features,\n",
"                    min_samples_leaf=self.min_samples_leaf,\n",
"                    random_state=42,\n",
"                    n_jobs=-1\n",
"                )\n",
"                ps_model.fit(X, W)\n",
"                self.W_hat = ps_model.predict_proba(X)[:, 1]\n",
"                self.W_hat = np.clip(self.W_hat, 0.01, 0.99)\n",
"\n",
"            # Estimate outcome model\n",
"            outcome_model = RandomForestRegressor(\n",
"                n_estimators=500,\n",
"                max_features=self.max_features,\n",
"                min_samples_leaf=self.min_samples_leaf,\n",
"                random_state=42,\n",
"                n_jobs=-1\n",
"            )\n",
"            outcome_model.fit(X, Y)\n",
"            self.Y_hat = outcome_model.predict(X)\n",
"\n",
"            # T-learner for treatment effects\n",
"            model_1 = RandomForestRegressor(\n",
"                n_estimators=1000,\n",
"                max_features=self.max_features,\n",
"                min_samples_leaf=self.min_samples_leaf,\n",
"                random_state=42,\n",
"                n_jobs=-1\n",
"            )\n",
"            model_0 = RandomForestRegressor(\n",
"                n_estimators=1000,\n",
"                max_features=self.max_features,\n",
"                min_samples_leaf=self.min_samples_leaf,\n",
"                random_state=42,\n",
"                n_jobs=-1\n",
"            )\n",
"\n",
"            # Fit separate models for treated and control\n",
"            if np.sum(W == 1) > 0:\n",
"                model_1.fit(X[W == 1], Y[W == 1])\n",
"                self.mu_1 = model_1.predict(X)\n",
"            else:\n",
"                self.mu_1 = np.zeros(n)\n",
"\n",
"            if np.sum(W == 0) > 0:\n",
"                model_0.fit(X[W == 0], Y[W == 0])\n",
"                self.mu_0 = model_0.predict(X)\n",
"            else:\n",
"                self.mu_0 = np.zeros(n)\n",
"\n",
"            # Treatment effect\n",
"            self.tau_hat = self.mu_1 - self.mu_0\n",
"\n",
"            return self\n",
"\n",
"        def predict(self):\n",
"            return {'predictions': self.tau_hat}\n",
"else:\n",
"    # Simplified implementation without sklearn\n",
"    class CausalForest:\n",
"        def __init__(self, n_estimators=100, W_hat=None, **kwargs):\n",
"            self.n_estimators = min(n_estimators, 100)\n",
"            self.W_hat_fixed = W_hat\n",
"\n",
"        def fit(self, X, Y, W):\n",
"            n = len(Y)\n",
"\n",
"            if self.W_hat_fixed is not None:\n",
"                self.W_hat = np.full(n, self.W_hat_fixed)\n",
"            else:\n",
"                self.W_hat = np.full(n, np.mean(W))\n",
"\n",
"            self.Y_hat = np.full(n, np.mean(Y))\n",
"\n",
"            if np.sum(W == 1) > 0:\n",
"                self.mu_1 = np.full(n, np.mean(Y[W == 1]))\n",
"            else:\n",
"                self.mu_1 = np.full(n, np.mean(Y))\n",
"\n",
"            if np.sum(W == 0) > 0:\n",
"                self.mu_0 = np.full(n, np.mean(Y[W == 0]))\n",
"            else:\n",
"                self.mu_0 = np.full(n, np.mean(Y))\n",
"\n",
"            self.tau_hat = self.mu_1 - self.mu_0\n",
"\n",
"            return self\n",
"\n",
"        def predict(self):\n",
"            return {'predictions': self.tau_hat}\n",
"\n",
"# Estimate a causal forest\n",
"print(\"\\nFitting causal forest (randomized setting with W.hat=0.5)...\")\n",
"forest = CausalForest(n_estimators=2000 if use_sklearn else 100, W_hat=0.5)\n",
"forest.fit(X_design, Y, W)\n",
"\n",
"# Get predictions\n",
"tau_hat = forest.predict()['predictions']\n",
"\n",
"# Estimate outcome models for treated and control\n",
"mu_hat_1 = forest.Y_hat + (1 - forest.W_hat) * tau_hat  # E[Y|X,W=1]\n",
"mu_hat_0 = forest.Y_hat - forest.W_hat * tau_hat  # E[Y|X,W=0]\n",
"\n",
"# Compute AIPW scores\n",
"gamma_hat_1 = mu_hat_1 + W / forest.W_hat * (Y - mu_hat_1)\n",
"gamma_hat_0 = mu_hat_0 + (1 - W) / (1 - forest.W_hat) * (Y - mu_hat_0)\n",
"\n",
"print(\"Causal forest fitted successfully.\")\n",
"print()"
]
},
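{
"cell_type": "markdown",
"metadata": {},
"source": [
"For reference, the identities the cell above implements (standard AIPW algebra, restated rather than introduced here). With overall outcome model $\\hat{m}(x)$ (`Y_hat`), propensity $\\hat{e}(x)$ (`W_hat`), and CATE $\\hat{\\tau}(x)$ (`tau_hat`), the arm-specific outcome models are recovered as\n",
"\n",
"$$\\hat{\\mu}_1(x) = \\hat{m}(x) + (1 - \\hat{e}(x))\\,\\hat{\\tau}(x), \\qquad \\hat{\\mu}_0(x) = \\hat{m}(x) - \\hat{e}(x)\\,\\hat{\\tau}(x),$$\n",
"\n",
"and the AIPW (doubly robust) scores are\n",
"\n",
"$$\\hat{\\Gamma}_{i,1} = \\hat{\\mu}_1(X_i) + \\frac{W_i}{\\hat{e}(X_i)}\\,\\bigl(Y_i - \\hat{\\mu}_1(X_i)\\bigr), \\qquad \\hat{\\Gamma}_{i,0} = \\hat{\\mu}_0(X_i) + \\frac{1 - W_i}{1 - \\hat{e}(X_i)}\\,\\bigl(Y_i - \\hat{\\mu}_0(X_i)\\bigr).$$\n",
"\n",
"Each score is unbiased for its counterfactual mean if either the outcome model or the propensity model is correct, which is the doubly robust property exploited in the next step.\n"
]
},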
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# ==============================================================================\n",
"# STEP 4: POLICY EVALUATION WITH AIPW\n",
"# ==============================================================================\n",
"\n",
"print(\"=\" * 70)\n",
"print(\"POLICY EVALUATION WITH AIPW\")\n",
"print(\"=\" * 70)\n",
"\n",
"# Same policy definition as in Step 2\n",
"pi = (data['age'] >= 40) & (data['family'] >= 3)\n",
"pi = pi.values\n",
"\n",
"# AIPW value estimation\n",
"gamma_hat_pi = pi * gamma_hat_1 + (1 - pi) * gamma_hat_0\n",
"value_estimate = np.mean(gamma_hat_pi)\n",
"value_stderr = np.std(gamma_hat_pi) / np.sqrt(len(gamma_hat_pi))\n",
"\n",
"print(f\"Value estimate: {value_estimate:.10f} Std. Error: {value_stderr:.10f}\")\n",
"print()"
]
},
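{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# A small follow-up (an addition, not in the original script): wrap the AIPW\n",
"# point estimate above in a normal-approximation 95% confidence interval.\n",
"ci_lo = value_estimate - 1.96 * value_stderr\n",
"ci_hi = value_estimate + 1.96 * value_stderr\n",
"print(f\"95% CI for the policy value: [{ci_lo:.4f}, {ci_hi:.4f}]\")"
]
},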
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# ==============================================================================\n",
"# STEP 5: POLICY COMPARISON\n",
"# ==============================================================================\n",
"\n",
"print(\"=\" * 70)\n",
"print(\"POLICY COMPARISON\")\n",
"print(\"=\" * 70)\n",
"\n",
"# Comparison policy: assign Loss Framing to a random 50%\n",
"pi_2 = 0.5\n",
"\n",
"# Use the same targeted policy definition as before\n",
"pi = (data['age'] >= 40) & (data['family'] >= 3)\n",
"pi = pi.values\n",
"\n",
"gamma_hat_pi_1 = pi * gamma_hat_1 + (1 - pi) * gamma_hat_0  # targeted policy\n",
"gamma_hat_pi_2 = pi_2 * gamma_hat_1 + (1 - pi_2) * gamma_hat_0  # 50% random\n",
"\n",
"gamma_hat_pi_diff = gamma_hat_pi_1 - gamma_hat_pi_2\n",
"diff_estimate = np.mean(gamma_hat_pi_diff)\n",
"diff_stderr = np.std(gamma_hat_pi_diff) / np.sqrt(len(gamma_hat_pi_diff))\n",
"\n",
"print(f\"Difference estimate: {diff_estimate:.10f} Std. Error: {diff_stderr:.10f}\")\n",
"print()"
]
},
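{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# A small follow-up (an addition, not in the original script; assumes scipy is\n",
"# available): normal-approximation z-test of whether the targeted policy beats\n",
"# 50% random assignment.\n",
"from scipy import stats\n",
"\n",
"z = diff_estimate / diff_stderr\n",
"p_value = 2 * (1 - stats.norm.cdf(abs(z)))\n",
"print(f\"z = {z:.3f}, two-sided p-value = {p_value:.4f}\")"
]
},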
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# ==============================================================================\n",
"# STEP 6: ADDITIONAL SUMMARY STATISTICS\n",
"# ==============================================================================\n",
"\n",
"print(\"=\" * 70)\n",
"print(\"ADDITIONAL INFORMATION\")\n",
"print(\"=\" * 70)\n",
"\n",
"print(f\"\\nSample size: {n}\")\n",
"print(f\"Treatment rate: {np.mean(W):.3f}\")\n",
"print(f\"Outcome rate (overall): {np.mean(Y):.3f}\")\n",
"print(f\"Outcome rate (Loss Framing): {np.mean(Y[W==1]):.3f}\")\n",
"print(f\"Outcome rate (Gain Framing): {np.mean(Y[W==0]):.3f}\")\n",
"\n",
"print(\"\\nPolicy characteristics:\")\n",
"print(f\"Proportion assigned to Loss Framing by policy: {np.mean(pi):.3f}\")\n",
"print(f\"Number assigned to Loss Framing: {np.sum(pi)}\")\n",
"print(f\"Number assigned to Gain Framing: {np.sum(~pi)}\")\n",
"\n",
"# Framing effect heterogeneity\n",
"print(\"\\nFraming effects by policy group:\")\n",
"if np.sum(pi & (W==1)) > 0 and np.sum(pi & (W==0)) > 0:\n",
"    te_policy = np.mean(Y[pi & (W==1)]) - np.mean(Y[pi & (W==0)])\n",
"    print(f\"Framing effect in Loss-recommended group: {te_policy:.4f}\")\n",
"if np.sum(~pi & (W==1)) > 0 and np.sum(~pi & (W==0)) > 0:\n",
"    te_no_policy = np.mean(Y[~pi & (W==1)]) - np.mean(Y[~pi & (W==0)])\n",
"    print(f\"Framing effect in Gain-recommended group: {te_no_policy:.4f}\")\n",
"\n",
"overall_te = np.mean(Y[W==1]) - np.mean(Y[W==0])\n",
"print(f\"Overall framing effect (Loss - Gain): {overall_te:.4f}\")\n",
"\n",
"print(\"\\n\" + \"=\" * 70)\n",
"print(\"ANALYSIS COMPLETE\")\n",
"print(\"=\" * 70)"
]
}
],
"metadata": {
