-
Notifications
You must be signed in to change notification settings - Fork 5
feat: add PSM example — ATT/ATC/ATE #18
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: 3-awesome-causal-inference-pythonmainatepropensity_score_and_dml
Are you sure you want to change the base?
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 이번 한 주도 고생 많으셨습니다! 덕분에 매칭 관련 공부 잘 하고 있어요 👍 |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -11,9 +11,353 @@ | |
| "cell_type": "markdown", | ||
| "metadata": {}, | ||
| "source": [ | ||
| "- Matching\n", | ||
| "- IPW, AIPW, Doubly Robust Estimator\n", | ||
| "- Double Machine Learning (비모수 버전의 Regression 처럼 활용 가능)" | ||
| "## Propensity Score Matching - Binary Treatment" | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "code", | ||
| "execution_count": 165, | ||
| "metadata": {}, | ||
| "outputs": [ | ||
| { | ||
| "data": { | ||
| "text/html": [ | ||
| "<div>\n", | ||
| "<style scoped>\n", | ||
| " .dataframe tbody tr th:only-of-type {\n", | ||
| " vertical-align: middle;\n", | ||
| " }\n", | ||
| "\n", | ||
| " .dataframe tbody tr th {\n", | ||
| " vertical-align: top;\n", | ||
| " }\n", | ||
| "\n", | ||
| " .dataframe thead th {\n", | ||
| " text-align: right;\n", | ||
| " }\n", | ||
| "</style>\n", | ||
| "<table border=\"1\" class=\"dataframe\">\n", | ||
| " <thead>\n", | ||
| " <tr style=\"text-align: right;\">\n", | ||
| " <th></th>\n", | ||
| " <th>schoolid</th>\n", | ||
| " <th>intervention</th>\n", | ||
| " <th>achievement_score</th>\n", | ||
| " <th>success_expect</th>\n", | ||
| " <th>ethnicity</th>\n", | ||
| " <th>gender</th>\n", | ||
| " <th>frst_in_family</th>\n", | ||
| " <th>school_urbanicity</th>\n", | ||
| " <th>school_mindset</th>\n", | ||
| " <th>school_achievement</th>\n", | ||
| " <th>school_ethnic_minority</th>\n", | ||
| " <th>school_poverty</th>\n", | ||
| " <th>school_size</th>\n", | ||
| " </tr>\n", | ||
| " </thead>\n", | ||
| " <tbody>\n", | ||
| " <tr>\n", | ||
| " <th>0</th>\n", | ||
| " <td>76</td>\n", | ||
| " <td>1</td>\n", | ||
| " <td>0.277359</td>\n", | ||
| " <td>6</td>\n", | ||
| " <td>4</td>\n", | ||
| " <td>2</td>\n", | ||
| " <td>1</td>\n", | ||
| " <td>4</td>\n", | ||
| " <td>0.334544</td>\n", | ||
| " <td>0.648586</td>\n", | ||
| " <td>-1.310927</td>\n", | ||
| " <td>0.224077</td>\n", | ||
| " <td>-0.426757</td>\n", | ||
| " </tr>\n", | ||
| " <tr>\n", | ||
| " <th>1</th>\n", | ||
| " <td>76</td>\n", | ||
| " <td>1</td>\n", | ||
| " <td>-0.449646</td>\n", | ||
| " <td>4</td>\n", | ||
| " <td>12</td>\n", | ||
| " <td>2</td>\n", | ||
| " <td>1</td>\n", | ||
| " <td>4</td>\n", | ||
| " <td>0.334544</td>\n", | ||
| " <td>0.648586</td>\n", | ||
| " <td>-1.310927</td>\n", | ||
| " <td>0.224077</td>\n", | ||
| " <td>-0.426757</td>\n", | ||
| " </tr>\n", | ||
| " <tr>\n", | ||
| " <th>2</th>\n", | ||
| " <td>76</td>\n", | ||
| " <td>1</td>\n", | ||
| " <td>0.769703</td>\n", | ||
| " <td>6</td>\n", | ||
| " <td>4</td>\n", | ||
| " <td>2</td>\n", | ||
| " <td>0</td>\n", | ||
| " <td>4</td>\n", | ||
| " <td>0.334544</td>\n", | ||
| " <td>0.648586</td>\n", | ||
| " <td>-1.310927</td>\n", | ||
| " <td>0.224077</td>\n", | ||
| " <td>-0.426757</td>\n", | ||
| " </tr>\n", | ||
| " <tr>\n", | ||
| " <th>3</th>\n", | ||
| " <td>76</td>\n", | ||
| " <td>1</td>\n", | ||
| " <td>-0.121763</td>\n", | ||
| " <td>6</td>\n", | ||
| " <td>4</td>\n", | ||
| " <td>2</td>\n", | ||
| " <td>0</td>\n", | ||
| " <td>4</td>\n", | ||
| " <td>0.334544</td>\n", | ||
| " <td>0.648586</td>\n", | ||
| " <td>-1.310927</td>\n", | ||
| " <td>0.224077</td>\n", | ||
| " <td>-0.426757</td>\n", | ||
| " </tr>\n", | ||
| " <tr>\n", | ||
| " <th>4</th>\n", | ||
| " <td>76</td>\n", | ||
| " <td>1</td>\n", | ||
| " <td>1.526147</td>\n", | ||
| " <td>6</td>\n", | ||
| " <td>4</td>\n", | ||
| " <td>1</td>\n", | ||
| " <td>0</td>\n", | ||
| " <td>4</td>\n", | ||
| " <td>0.334544</td>\n", | ||
| " <td>0.648586</td>\n", | ||
| " <td>-1.310927</td>\n", | ||
| " <td>0.224077</td>\n", | ||
| " <td>-0.426757</td>\n", | ||
| " </tr>\n", | ||
| " </tbody>\n", | ||
| "</table>\n", | ||
| "</div>" | ||
| ], | ||
| "text/plain": [ | ||
| " schoolid intervention achievement_score success_expect ethnicity \\\n", | ||
| "0 76 1 0.277359 6 4 \n", | ||
| "1 76 1 -0.449646 4 12 \n", | ||
| "2 76 1 0.769703 6 4 \n", | ||
| "3 76 1 -0.121763 6 4 \n", | ||
| "4 76 1 1.526147 6 4 \n", | ||
| "\n", | ||
| " gender frst_in_family school_urbanicity school_mindset \\\n", | ||
| "0 2 1 4 0.334544 \n", | ||
| "1 2 1 4 0.334544 \n", | ||
| "2 2 0 4 0.334544 \n", | ||
| "3 2 0 4 0.334544 \n", | ||
| "4 1 0 4 0.334544 \n", | ||
| "\n", | ||
| " school_achievement school_ethnic_minority school_poverty school_size \n", | ||
| "0 0.648586 -1.310927 0.224077 -0.426757 \n", | ||
| "1 0.648586 -1.310927 0.224077 -0.426757 \n", | ||
| "2 0.648586 -1.310927 0.224077 -0.426757 \n", | ||
| "3 0.648586 -1.310927 0.224077 -0.426757 \n", | ||
| "4 0.648586 -1.310927 0.224077 -0.426757 " | ||
| ] | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 각 변수에 대한 간략한 설명은 없어도 될까요? 저도 같은 데이터셋을 사용해서 해창님 부분에 각 컬럼에 대한 설명이 있으면 이 데이터셋을 처음 보는 사람들이 결과를 해석하기 더 편할 듯 합니다! |
||
| }, | ||
| "execution_count": 165, | ||
| "metadata": {}, | ||
| "output_type": "execute_result" | ||
| } | ||
| ], | ||
| "source": [ | ||
| "import pandas as pd\n", | ||
| "\n", | ||
| "data = pd.read_csv(\"../data/matheus_data/learning_mindset.csv\")\n", | ||
| "\n", | ||
| "data.head()" | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "markdown", | ||
| "metadata": {}, | ||
| "source": [ | ||
| "### Propensity Score Estimation" | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "code", | ||
| "execution_count": 205, | ||
| "metadata": {}, | ||
| "outputs": [], | ||
| "source": [ | ||
| "from scipy.special import logit\n", | ||
| "import numpy as np\n", | ||
| "from causalml.propensity import ElasticNetPropensityModel\n", | ||
| "\n", | ||
| "categ = [\"ethnicity\",\"gender\",\"school_urbanicity\"]\n", | ||
| "cont = [\"school_mindset\",\"school_achievement\",\"school_ethnic_minority\",\"school_poverty\",\"school_size\"]\n", | ||
| "X = pd.get_dummies(data[categ + cont], columns=categ, drop_first=True)\n", | ||
| "\n", | ||
| "pm = ElasticNetPropensityModel(\n", | ||
| " random_state=42,\n", | ||
| " max_iter=5000\n", | ||
| ")\n", | ||
| "ps = pm.fit_predict(X.values, data[\"intervention\"].values)\n", | ||
| "logit_ps = logit(ps)\n", | ||
| "zlogit_ps = (logit_ps - logit_ps.mean()) / logit_ps.std(ddof=1)\n", | ||
| "\n", | ||
| "df = data.copy()\n", | ||
| "df[\"ps_logit_z\"] = zlogit_ps" | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "markdown", | ||
| "metadata": {}, | ||
| "source": [ | ||
| "- **logit(PS) 표준편차**: Propensity Score를 logit 변환 후 표준화하여, 매칭 시 caliper 거리 기준에 활용\n" | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "markdown", | ||
| "metadata": {}, | ||
| "source": [ | ||
| "### Matching - ATT" | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "code", | ||
| "execution_count": 206, | ||
| "metadata": {}, | ||
| "outputs": [ | ||
| { | ||
| "name": "stdout", | ||
| "output_type": "stream", | ||
| "text": [ | ||
| "ATT: 0.4983\n" | ||
| ] | ||
| } | ||
| ], | ||
| "source": [ | ||
| "from causalml.match import NearestNeighborMatch\n", | ||
| "\n", | ||
| "df_match = pd.concat([df[[\"intervention\", 'achievement_score', \"ps_logit_z\"]], X], axis=1)\n", | ||
| "\n", | ||
| "matcher_att = NearestNeighborMatch(\n", | ||
| " caliper=0.2,\n", | ||
| " replace=False,\n", | ||
| " ratio=1,\n", | ||
| " shuffle=False,\n", | ||
| " random_state=42,\n", | ||
| " treatment_to_control=True\n", | ||
| ")\n", | ||
| "\n", | ||
| "matched_att = matcher_att.match(\n", | ||
| " data=df_match, treatment_col=\"intervention\", score_cols=[\"ps_logit_z\"]\n", | ||
| ")\n", | ||
| "\n", | ||
| "ATT = (\n", | ||
| " matched_att.loc[matched_att[\"intervention\"]==1, 'achievement_score'].mean() \n", | ||
| " - matched_att.loc[matched_att[\"intervention\"]==0, 'achievement_score'].mean()\n", | ||
| ")\n", | ||
| "print(\"ATT:\", round(ATT, 4))" | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "markdown", | ||
| "metadata": {}, | ||
| "source": [ | ||
| "- **caliper=0.2**: logit(PS) 표준편차의 20% 이내에서만 매칭\n", | ||
| "- **replace=False**: 비복원 매칭으로 동일 대조군 중복 사용 방지" | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "markdown", | ||
| "metadata": {}, | ||
| "source": [ | ||
| "### Balance Check - ATT" | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "markdown", | ||
| "metadata": {}, | ||
| "source": [ | ||
| "### Matching - ATC" | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "code", | ||
| "execution_count": 207, | ||
| "metadata": {}, | ||
| "outputs": [ | ||
| { | ||
| "name": "stdout", | ||
| "output_type": "stream", | ||
| "text": [ | ||
| "ATC: 0.3843\n" | ||
| ] | ||
| } | ||
| ], | ||
| "source": [ | ||
| "from causalml.match import NearestNeighborMatch\n", | ||
| "\n", | ||
| "matcher_atc = NearestNeighborMatch(\n", | ||
| " caliper=0.2,\n", | ||
| " replace=True,\n", | ||
| " ratio=1,\n", | ||
| " shuffle=False,\n", | ||
| " random_state=42,\n", | ||
| " treatment_to_control=False\n", | ||
| ")\n", | ||
| "\n", | ||
| "matched_atc = matcher_atc.match(\n", | ||
| " data=df_match, treatment_col=\"intervention\", score_cols=[\"ps_logit_z\"]\n", | ||
| ")\n", | ||
| "\n", | ||
| "ATC = (\n", | ||
| " matched_atc.loc[matched_atc[\"intervention\"]==1, \"achievement_score\"].mean()\n", | ||
| " - matched_atc.loc[matched_atc[\"intervention\"]==0, \"achievement_score\"].mean() \n", | ||
| ")\n", | ||
| "print(\"ATC:\", round(ATC, 4))" | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "markdown", | ||
| "metadata": {}, | ||
| "source": [ | ||
| "- **replace=True**: 후보군 부족으로 복원 매칭 허용\n", | ||
| " - 매칭 실패 시 caliper 완화, 복원 매칭, k:1 매칭 확대로 대응" | ||
|
||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "markdown", | ||
| "metadata": {}, | ||
| "source": [ | ||
| "### Balance Check - ATC" | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "markdown", | ||
| "metadata": {}, | ||
| "source": [ | ||
| "### ATE" | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "code", | ||
| "execution_count": 208, | ||
| "metadata": {}, | ||
| "outputs": [ | ||
| { | ||
| "name": "stdout", | ||
| "output_type": "stream", | ||
| "text": [ | ||
| "ATE: 0.4214\n" | ||
| ] | ||
| } | ||
| ], | ||
| "source": [ | ||
| "p_t = float(df[\"intervention\"].mean())\n", | ||
| "ATE = ATT * p_t + ATC * (1 - p_t)\n", | ||
| "print(\"ATE:\", round(ATE, 4))" | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 코드 구현 깔끔하게 정리해주셔서 감사드립니다! 코드 구현 후에 결과에 대한 짤막한 해석이 있으면 더욱 좋을 듯 합니다!! There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 리뷰 꼼꼼하게 해주셔서 감사합니다! 해석 부분 좀 더 신경 쓰도록 하겠습니다! |
||
| ] | ||
| }, | ||
| { | ||
|
|
@@ -29,11 +373,16 @@ | |
| " async>\n", | ||
| "</script>" | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "markdown", | ||
| "metadata": {}, | ||
| "source": [] | ||
| } | ||
| ], | ||
| "metadata": { | ||
| "kernelspec": { | ||
| "display_name": "base", | ||
| "display_name": ".venv312", | ||
| "language": "python", | ||
| "name": "python3" | ||
| }, | ||
|
|
@@ -47,7 +396,7 @@ | |
| "name": "python", | ||
| "nbconvert_exporter": "python", | ||
| "pygments_lexer": "ipython3", | ||
| "version": "3.12.7" | ||
| "version": "3.12.11" | ||
| } | ||
| }, | ||
| "nbformat": 4, | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
고생많으셨습니다! 다음 주 작업도 파이팅입니다 :)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
리뷰 감사합니다!