Skip to content

Commit 60dc6e2

Browse files
committed
update data validation
Signed-off-by: Nathaniel <[email protected]>
1 parent 994799b commit 60dc6e2

File tree

2 files changed

+55
-88
lines changed

2 files changed

+55
-88
lines changed

causalpy/data_validation.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,4 +140,22 @@ class PropensityDataValidator:
140140

141141
def _input_validation(self):
142142
"""Validate the input data and model formula for correctness"""
143-
pass
143+
treatment = self.formula.split("~")[0]
144+
test = treatment.strip() in self.data.columns
145+
test = test & (self.outcome_variable in self.data.columns)
146+
if not test:
147+
raise DataException(
148+
f"""
149+
The treatment variable:
150+
{treatment} must appear in the data to be used
151+
as an outcome variable. And {self.outcome_variable}
152+
must also be available in the data to be re-weighted
153+
"""
154+
)
155+
T = self.data[treatment.strip()]
156+
check_binary = len(np.unique(T)) > 2
157+
if check_binary:
158+
raise DataException(
159+
"""Warning. The treatment variable is not 0-1 Binary.
160+
"""
161+
)

docs/source/notebooks/inv_prop_pymc.ipynb

Lines changed: 36 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
"cells": [
33
{
44
"cell_type": "code",
5-
"execution_count": 357,
5+
"execution_count": 378,
66
"metadata": {},
77
"outputs": [
88
{
@@ -50,7 +50,7 @@
5050
},
5151
{
5252
"cell_type": "code",
53-
"execution_count": 325,
53+
"execution_count": 376,
5454
"metadata": {},
5555
"outputs": [
5656
{
@@ -83,104 +83,53 @@
8383
" <tbody>\n",
8484
" <tr>\n",
8585
" <th>0</th>\n",
86-
" <td>1.374096</td>\n",
87-
" <td>0.373163</td>\n",
88-
" <td>1</td>\n",
89-
" <td>6.054919</td>\n",
86+
" <td>-0.700611</td>\n",
87+
" <td>0.215690</td>\n",
88+
" <td>0</td>\n",
89+
" <td>-1.060506</td>\n",
9090
" </tr>\n",
9191
" <tr>\n",
9292
" <th>1</th>\n",
93-
" <td>1.051587</td>\n",
94-
" <td>0.834493</td>\n",
95-
" <td>0</td>\n",
96-
" <td>2.927939</td>\n",
93+
" <td>0.880796</td>\n",
94+
" <td>1.082451</td>\n",
95+
" <td>1</td>\n",
96+
" <td>3.778433</td>\n",
9797
" </tr>\n",
9898
" <tr>\n",
9999
" <th>2</th>\n",
100-
" <td>-0.450553</td>\n",
101-
" <td>0.232016</td>\n",
100+
" <td>-0.121070</td>\n",
101+
" <td>0.767333</td>\n",
102102
" <td>0</td>\n",
103-
" <td>-0.043942</td>\n",
103+
" <td>0.617862</td>\n",
104104
" </tr>\n",
105105
" <tr>\n",
106106
" <th>3</th>\n",
107-
" <td>0.720264</td>\n",
108-
" <td>-0.539953</td>\n",
109-
" <td>0</td>\n",
110-
" <td>0.739484</td>\n",
111-
" </tr>\n",
112-
" <tr>\n",
113-
" <th>4</th>\n",
114-
" <td>0.778325</td>\n",
115-
" <td>1.534670</td>\n",
107+
" <td>0.149978</td>\n",
108+
" <td>1.146856</td>\n",
116109
" <td>1</td>\n",
117-
" <td>4.425341</td>\n",
118-
" </tr>\n",
119-
" <tr>\n",
120-
" <th>...</th>\n",
121-
" <td>...</td>\n",
122-
" <td>...</td>\n",
123-
" <td>...</td>\n",
124-
" <td>...</td>\n",
110+
" <td>2.831018</td>\n",
125111
" </tr>\n",
126112
" <tr>\n",
127-
" <th>9995</th>\n",
128-
" <td>0.890611</td>\n",
129-
" <td>1.266610</td>\n",
113+
" <th>4</th>\n",
114+
" <td>-0.506154</td>\n",
115+
" <td>0.113415</td>\n",
130116
" <td>0</td>\n",
131-
" <td>2.732242</td>\n",
132-
" </tr>\n",
133-
" <tr>\n",
134-
" <th>9996</th>\n",
135-
" <td>1.428810</td>\n",
136-
" <td>1.557557</td>\n",
137-
" <td>1</td>\n",
138-
" <td>5.068505</td>\n",
139-
" </tr>\n",
140-
" <tr>\n",
141-
" <th>9997</th>\n",
142-
" <td>1.678820</td>\n",
143-
" <td>1.254265</td>\n",
144-
" <td>1</td>\n",
145-
" <td>4.317824</td>\n",
146-
" </tr>\n",
147-
" <tr>\n",
148-
" <th>9998</th>\n",
149-
" <td>1.341190</td>\n",
150-
" <td>1.002567</td>\n",
151-
" <td>1</td>\n",
152-
" <td>4.527394</td>\n",
153-
" </tr>\n",
154-
" <tr>\n",
155-
" <th>9999</th>\n",
156-
" <td>1.330508</td>\n",
157-
" <td>0.702635</td>\n",
158-
" <td>1</td>\n",
159-
" <td>2.982631</td>\n",
117+
" <td>-0.106079</td>\n",
160118
" </tr>\n",
161119
" </tbody>\n",
162120
"</table>\n",
163-
"<p>10000 rows × 4 columns</p>\n",
164121
"</div>"
165122
],
166123
"text/plain": [
167-
" x1 x2 trt outcome\n",
168-
"0 1.374096 0.373163 1 6.054919\n",
169-
"1 1.051587 0.834493 0 2.927939\n",
170-
"2 -0.450553 0.232016 0 -0.043942\n",
171-
"3 0.720264 -0.539953 0 0.739484\n",
172-
"4 0.778325 1.534670 1 4.425341\n",
173-
"... ... ... ... ...\n",
174-
"9995 0.890611 1.266610 0 2.732242\n",
175-
"9996 1.428810 1.557557 1 5.068505\n",
176-
"9997 1.678820 1.254265 1 4.317824\n",
177-
"9998 1.341190 1.002567 1 4.527394\n",
178-
"9999 1.330508 0.702635 1 2.982631\n",
179-
"\n",
180-
"[10000 rows x 4 columns]"
124+
" x1 x2 trt outcome\n",
125+
"0 -0.700611 0.215690 0 -1.060506\n",
126+
"1 0.880796 1.082451 1 3.778433\n",
127+
"2 -0.121070 0.767333 0 0.617862\n",
128+
"3 0.149978 1.146856 1 2.831018\n",
129+
"4 -0.506154 0.113415 0 -0.106079"
181130
]
182131
},
183-
"execution_count": 325,
132+
"execution_count": 376,
184133
"metadata": {},
185134
"output_type": "execute_result"
186135
}
@@ -189,7 +138,7 @@
189138
"df1 = pd.DataFrame(np.random.multivariate_normal([0.5, 1], [[2, 1], [1, 1]], size=10000), columns=['x1', 'x2'])\n",
190139
"df1['trt'] = np.where(-0.5 + 0.25 * df1['x1'] + 0.75 * df1['x2'] + np.random.normal(0, 1, size=10000) > 0, 1, 0)\n",
191140
"df1['outcome'] = 2 * df1['trt'] + df1['x1'] + df1['x2'] + np.random.normal(0, 1, size=10000)\n",
192-
"df1"
141+
"df1.head()"
193142
]
194143
},
195144
{
@@ -208,7 +157,7 @@
208157
},
209158
{
210159
"cell_type": "code",
211-
"execution_count": 338,
160+
"execution_count": 379,
212161
"metadata": {},
213162
"outputs": [
214163
{
@@ -227,10 +176,10 @@
227176
{
228177
"data": {
229178
"text/plain": [
230-
"<causalpy.pymc_experiments.InversePropensityWeighting at 0x2aebe6110>"
179+
"<causalpy.pymc_experiments.InversePropensityWeighting at 0x32412ee50>"
231180
]
232181
},
233-
"execution_count": 338,
182+
"execution_count": 379,
234183
"metadata": {},
235184
"output_type": "execute_result"
236185
}
@@ -878,7 +827,7 @@
878827
},
879828
{
880829
"cell_type": "code",
881-
"execution_count": 373,
830+
"execution_count": 380,
882831
"metadata": {},
883832
"outputs": [
884833
{
@@ -969,7 +918,7 @@
969918
"4 40 0 0 20 19 4.989251"
970919
]
971920
},
972-
"execution_count": 373,
921+
"execution_count": 380,
973922
"metadata": {},
974923
"output_type": "execute_result"
975924
}
@@ -981,7 +930,7 @@
981930
},
982931
{
983932
"cell_type": "code",
984-
"execution_count": 365,
933+
"execution_count": 381,
985934
"metadata": {},
986935
"outputs": [
987936
{
@@ -1000,10 +949,10 @@
1000949
{
1001950
"data": {
1002951
"text/plain": [
1003-
"<causalpy.pymc_experiments.InversePropensityWeighting at 0x2e6f4ba90>"
952+
"<causalpy.pymc_experiments.InversePropensityWeighting at 0x3bdbaa4d0>"
1004953
]
1005954
},
1006-
"execution_count": 365,
955+
"execution_count": 381,
1007956
"metadata": {},
1008957
"output_type": "execute_result"
1009958
}

0 commit comments

Comments
 (0)