Skip to content
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
359 changes: 354 additions & 5 deletions book/ate/propensity_score_and_dml.ipynb
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

고생많으셨습니다! 다음 주 작업도 파이팅입니다 :)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

리뷰 감사합니다!

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

이번 한 주도 고생 많으셨습니다!

덕분에 매칭 관련 공부 잘 하고 있어요 👍

Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,353 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"- Matching\n",
"- IPW, AIPW, Doubly Robust Estimator\n",
"- Double Machine Learning (비모수 버전의 Regression 처럼 활용 가능)"
"## Propensity Score Matching - Binary Treatment"
]
},
{
"cell_type": "code",
"execution_count": 165,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>schoolid</th>\n",
" <th>intervention</th>\n",
" <th>achievement_score</th>\n",
" <th>success_expect</th>\n",
" <th>ethnicity</th>\n",
" <th>gender</th>\n",
" <th>frst_in_family</th>\n",
" <th>school_urbanicity</th>\n",
" <th>school_mindset</th>\n",
" <th>school_achievement</th>\n",
" <th>school_ethnic_minority</th>\n",
" <th>school_poverty</th>\n",
" <th>school_size</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>76</td>\n",
" <td>1</td>\n",
" <td>0.277359</td>\n",
" <td>6</td>\n",
" <td>4</td>\n",
" <td>2</td>\n",
" <td>1</td>\n",
" <td>4</td>\n",
" <td>0.334544</td>\n",
" <td>0.648586</td>\n",
" <td>-1.310927</td>\n",
" <td>0.224077</td>\n",
" <td>-0.426757</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>76</td>\n",
" <td>1</td>\n",
" <td>-0.449646</td>\n",
" <td>4</td>\n",
" <td>12</td>\n",
" <td>2</td>\n",
" <td>1</td>\n",
" <td>4</td>\n",
" <td>0.334544</td>\n",
" <td>0.648586</td>\n",
" <td>-1.310927</td>\n",
" <td>0.224077</td>\n",
" <td>-0.426757</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>76</td>\n",
" <td>1</td>\n",
" <td>0.769703</td>\n",
" <td>6</td>\n",
" <td>4</td>\n",
" <td>2</td>\n",
" <td>0</td>\n",
" <td>4</td>\n",
" <td>0.334544</td>\n",
" <td>0.648586</td>\n",
" <td>-1.310927</td>\n",
" <td>0.224077</td>\n",
" <td>-0.426757</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>76</td>\n",
" <td>1</td>\n",
" <td>-0.121763</td>\n",
" <td>6</td>\n",
" <td>4</td>\n",
" <td>2</td>\n",
" <td>0</td>\n",
" <td>4</td>\n",
" <td>0.334544</td>\n",
" <td>0.648586</td>\n",
" <td>-1.310927</td>\n",
" <td>0.224077</td>\n",
" <td>-0.426757</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>76</td>\n",
" <td>1</td>\n",
" <td>1.526147</td>\n",
" <td>6</td>\n",
" <td>4</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>4</td>\n",
" <td>0.334544</td>\n",
" <td>0.648586</td>\n",
" <td>-1.310927</td>\n",
" <td>0.224077</td>\n",
" <td>-0.426757</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" schoolid intervention achievement_score success_expect ethnicity \\\n",
"0 76 1 0.277359 6 4 \n",
"1 76 1 -0.449646 4 12 \n",
"2 76 1 0.769703 6 4 \n",
"3 76 1 -0.121763 6 4 \n",
"4 76 1 1.526147 6 4 \n",
"\n",
" gender frst_in_family school_urbanicity school_mindset \\\n",
"0 2 1 4 0.334544 \n",
"1 2 1 4 0.334544 \n",
"2 2 0 4 0.334544 \n",
"3 2 0 4 0.334544 \n",
"4 1 0 4 0.334544 \n",
"\n",
" school_achievement school_ethnic_minority school_poverty school_size \n",
"0 0.648586 -1.310927 0.224077 -0.426757 \n",
"1 0.648586 -1.310927 0.224077 -0.426757 \n",
"2 0.648586 -1.310927 0.224077 -0.426757 \n",
"3 0.648586 -1.310927 0.224077 -0.426757 \n",
"4 0.648586 -1.310927 0.224077 -0.426757 "
]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

각 변수에 대한 간략한 설명은 없어도 될까요?

저도 같은 데이터셋을 사용해서 해창님 부분에 각 컬럼에 대한 설명이 있으면 이 데이터셋을 처음 보는 사람들이 결과를 해석하기 더 편할 듯 합니다!

},
"execution_count": 165,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import pandas as pd\n",
"\n",
"data = pd.read_csv(\"../data/matheus_data/learning_mindset.csv\")\n",
"\n",
"data.head()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Propensity Score Estimation"
]
},
{
"cell_type": "code",
"execution_count": 205,
"metadata": {},
"outputs": [],
"source": [
"from scipy.special import logit\n",
"import numpy as np\n",
"from causalml.propensity import ElasticNetPropensityModel\n",
"\n",
"categ = [\"ethnicity\",\"gender\",\"school_urbanicity\"]\n",
"cont = [\"school_mindset\",\"school_achievement\",\"school_ethnic_minority\",\"school_poverty\",\"school_size\"]\n",
"X = pd.get_dummies(data[categ + cont], columns=categ, drop_first=True)\n",
"\n",
"pm = ElasticNetPropensityModel(\n",
" random_state=42,\n",
" max_iter=5000\n",
")\n",
"ps = pm.fit_predict(X.values, data[\"intervention\"].values)\n",
"logit_ps = logit(ps)\n",
"zlogit_ps = (logit_ps - logit_ps.mean()) / logit_ps.std(ddof=1)\n",
"\n",
"df = data.copy()\n",
"df[\"ps_logit_z\"] = zlogit_ps"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"- **logit(PS) 표준편차**: Propensity Score를 logit 변환 후 표준화하여, 매칭 시 caliper 거리 기준에 활용\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Matching - ATT"
]
},
{
"cell_type": "code",
"execution_count": 206,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"ATT: 0.4983\n"
]
}
],
"source": [
"from causalml.match import NearestNeighborMatch\n",
"\n",
"df_match = pd.concat([df[[\"intervention\", 'achievement_score', \"ps_logit_z\"]], X], axis=1)\n",
"\n",
"matcher_att = NearestNeighborMatch(\n",
" caliper=0.2,\n",
" replace=False,\n",
" ratio=1,\n",
" shuffle=False,\n",
" random_state=42,\n",
" treatment_to_control=True\n",
")\n",
"\n",
"matched_att = matcher_att.match(\n",
" data=df_match, treatment_col=\"intervention\", score_cols=[\"ps_logit_z\"]\n",
")\n",
"\n",
"ATT = (\n",
" matched_att.loc[matched_att[\"intervention\"]==1, 'achievement_score'].mean() \n",
" - matched_att.loc[matched_att[\"intervention\"]==0, 'achievement_score'].mean()\n",
")\n",
"print(\"ATT:\", round(ATT, 4))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"- **caliper=0.2**: logit(PS) 표준편차의 20% 이내에서만 매칭\n",
"- **replace=False**: 비복원 매칭으로 동일 대조군 중복 사용 방지"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Balance Check - ATT"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Matching - ATC"
]
},
{
"cell_type": "code",
"execution_count": 207,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"ATC: 0.3843\n"
]
}
],
"source": [
"from causalml.match import NearestNeighborMatch\n",
"\n",
"matcher_atc = NearestNeighborMatch(\n",
" caliper=0.2,\n",
" replace=True,\n",
" ratio=1,\n",
" shuffle=False,\n",
" random_state=42,\n",
" treatment_to_control=False\n",
")\n",
"\n",
"matched_atc = matcher_atc.match(\n",
" data=df_match, treatment_col=\"intervention\", score_cols=[\"ps_logit_z\"]\n",
")\n",
"\n",
"ATC = (\n",
" matched_atc.loc[matched_atc[\"intervention\"]==1, \"achievement_score\"].mean()\n",
" - matched_atc.loc[matched_atc[\"intervention\"]==0, \"achievement_score\"].mean() \n",
")\n",
"print(\"ATC:\", round(ATC, 4))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"- **replace=True**: 후보군 부족으로 복원 매칭 허용\n",
" - 매칭 실패 시 caliper 완화, 복원 매칭, k:1 매칭 확대로 대응"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

복원 매칭 허용된다는 뜻은 Propensity Score의 수렴 여부와는 상관이 없는 개념이 맞을까요?
만약 매칭 단계에서의 복원추출 허용이라고 한다면 하나의 변수가 과도하게 많은 매칭에 사용되는 등의 문제가 있을 듯 한데, 이런 문제가 발생했는지 여부를 어떤 방식으로 확인할 수 있는지 궁금합니다!

Copy link
Contributor Author

@Funbucket Funbucket Sep 25, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. replace와 PS 수렴
    맞습니다! caliper, replace, ratio는 PS를 이미 추정한 뒤 매칭 단계에서 사용하는 파라미터이고, PS 모형의 수렴 여부와는 무관합니다.

  2. 복원 매칭 시 과도한 재사용 진단
    말씀해주신 대로 항상 bias–variance trade-off를 고려해야 할 것 같습니다. 아래 두가지 방법이 가능할 것 같습니다!

    • 단순 확인: 대조군별 매칭 횟수 체크
    • 정량 확인: ESS/원표본 비율 평가 → 비율이 줄면 소수 유닛 과도하게 의존하는 것으로 해석
  3. 추가 답변
    가장 신뢰할 수 있는 평가는 신뢰구간인 것 같은데요, R의 Matching 패키지에는 Abadie–Imbens 표준오차 추정이 구현돼 있지만, Python에서는 부트스트랩이 유효하지 않아 관련 구현이 부족한 상황이라 향후 과제로 보입니다!

]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Balance Check - ATC"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### ATE"
]
},
{
"cell_type": "code",
"execution_count": 208,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"ATE: 0.4214\n"
]
}
],
"source": [
"p_t = float(df[\"intervention\"].mean())\n",
"ATE = ATT * p_t + ATC * (1 - p_t)\n",
"print(\"ATE:\", round(ATE, 4))"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

코드 구현 깔끔하게 정리해주셔서 감사드립니다!

코드 구현 후에 결과에 대한 짤막한 해석이 있으면 더욱 좋을 듯 합니다!!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

리뷰 꼼꼼하게 해주셔서 감사합니다! 해석 부분 좀 더 신경 쓰도록 하겠습니다!

]
},
{
Expand All @@ -29,11 +373,16 @@
" async>\n",
"</script>"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "base",
"display_name": ".venv312",
"language": "python",
"name": "python3"
},
Expand All @@ -47,7 +396,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.7"
"version": "3.12.11"
}
},
"nbformat": 4,
Expand Down