Skip to content

Commit 77dbbe9

Browse files
authored
Merge pull request #153 from pymc-labs/improve-data-import
Idiomatic data import and processing with pandas method chaining for all examples
2 parents 16f925b + 295de05 commit 77dbbe9

File tree

7 files changed: +807 additions, -794 deletions (lines changed)

causalpy/tests/test_integration_pymc_examples.py

Lines changed: 22 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -141,9 +141,11 @@ def test_rd_drinking():
141141

142142
@pytest.mark.integration
143143
def test_its():
144-
df = cp.load_data("its")
145-
df["date"] = pd.to_datetime(df["date"])
146-
df.set_index("date", inplace=True)
144+
df = (
145+
cp.load_data("its")
146+
.assign(date=lambda x: pd.to_datetime(x["date"]))
147+
.set_index("date")
148+
)
147149
treatment_time = pd.to_datetime("2017-01-01")
148150
result = cp.pymc_experiments.SyntheticControl(
149151
df,
@@ -159,9 +161,11 @@ def test_its():
159161

160162
@pytest.mark.integration
161163
def test_its_covid():
162-
df = cp.load_data("covid")
163-
df["date"] = pd.to_datetime(df["date"])
164-
df = df.set_index("date")
164+
df = (
165+
cp.load_data("covid")
166+
.assign(date=lambda x: pd.to_datetime(x["date"]))
167+
.set_index("date")
168+
)
165169
treatment_time = pd.to_datetime("2020-01-01")
166170
result = cp.pymc_experiments.SyntheticControl(
167171
df,
@@ -193,12 +197,14 @@ def test_sc():
193197

194198
@pytest.mark.integration
195199
def test_sc_brexit():
196-
df = cp.load_data("brexit")
197-
df["Time"] = pd.to_datetime(df["Time"])
198-
df.set_index("Time", inplace=True)
199-
df = df.iloc[df.index > "2009", :]
200+
df = (
201+
cp.load_data("brexit")
202+
.assign(Time=lambda x: pd.to_datetime(x["Time"]))
203+
.set_index("Time")
204+
.loc[lambda x: x.index >= "2009-01-01"]
205+
.drop(["Japan", "Italy", "US", "Spain"], axis=1)
206+
)
200207
treatment_time = pd.to_datetime("2016 June 24")
201-
df = df.drop(["Japan", "Italy", "US", "Spain"], axis=1)
202208
target_country = "UK"
203209
all_countries = df.columns
204210
other_countries = all_countries.difference({target_country})
@@ -235,9 +241,11 @@ def test_ancova():
235241

236242
@pytest.mark.integration
237243
def test_geolift1():
238-
df = cp.load_data("geolift1")
239-
df["time"] = pd.to_datetime(df["time"])
240-
df.set_index("time", inplace=True)
244+
df = (
245+
cp.load_data("geolift1")
246+
.assign(time=lambda x: pd.to_datetime(x["time"]))
247+
.set_index("time")
248+
)
241249
treatment_time = pd.to_datetime("2022-01-01")
242250
result = cp.pymc_experiments.SyntheticControl(
243251
df,

causalpy/tests/test_integration_skl_examples.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,11 @@ def test_rd_drinking():
4343

4444
@pytest.mark.integration
4545
def test_its():
46-
df = cp.load_data("its")
47-
df["date"] = pd.to_datetime(df["date"])
48-
df.set_index("date", inplace=True)
46+
df = (
47+
cp.load_data("its")
48+
.assign(date=lambda x: pd.to_datetime(x["date"]))
49+
.set_index("date")
50+
)
4951
treatment_time = pd.to_datetime("2017-01-01")
5052
result = cp.skl_experiments.SyntheticControl(
5153
df,

docs/notebooks/geolift1.ipynb

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -181,11 +181,12 @@
181181
}
182182
],
183183
"source": [
184-
"df = cp.load_data(\"geolift1\")\n",
185-
"# convert time column to datetime format\n",
186-
"df[\"time\"] = pd.to_datetime(df[\"time\"])\n",
187-
"df.set_index(\"time\", inplace=True)\n",
188-
"# define the treamtment time\n",
184+
"df = (\n",
185+
" cp.load_data(\"geolift1\")\n",
186+
" .assign(time=lambda x: pd.to_datetime(x[\"time\"]))\n",
187+
" .set_index(\"time\")\n",
188+
")\n",
189+
"\n",
189190
"treatment_time = pd.to_datetime(\"2022-01-01\")\n",
190191
"df.head()"
191192
]
@@ -589,7 +590,7 @@
589590
],
590591
"metadata": {
591592
"kernelspec": {
592-
"display_name": "Python 3.10.8 ('CausalPy')",
593+
"display_name": "CausalPy",
593594
"language": "python",
594595
"name": "python3"
595596
},
@@ -603,12 +604,12 @@
603604
"name": "python",
604605
"nbconvert_exporter": "python",
605606
"pygments_lexer": "ipython3",
606-
"version": "3.10.8 | packaged by conda-forge | (main, Nov 22 2022, 08:25:29) [Clang 14.0.6 ]"
607+
"version": "3.10.6 | packaged by conda-forge | (main, Aug 22 2022, 20:38:29) [Clang 13.0.1 ]"
607608
},
608609
"orig_nbformat": 4,
609610
"vscode": {
610611
"interpreter": {
611-
"hash": "46d31859cc45aa26a1223a391e7cf3023d69984b498bed11e66c690302b7e251"
612+
"hash": "02f5385db19eab57520277c5168790c7855381ee953bdbb5c89c321e1f17586e"
612613
}
613614
}
614615
},

docs/notebooks/its_covid.ipynb

Lines changed: 10 additions & 7 deletions
Large diffs are not rendered by default.

docs/notebooks/sc2_pymc.ipynb

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -142,9 +142,12 @@
142142
}
143143
],
144144
"source": [
145-
"df = cp.load_data(\"its\")\n",
146-
"df[\"date\"] = pd.to_datetime(df[\"date\"])\n",
147-
"df.set_index(\"date\", inplace=True)\n",
145+
"df = (\n",
146+
" cp.load_data(\"its\")\n",
147+
" .assign(date=lambda x: pd.to_datetime(x[\"date\"]))\n",
148+
" .set_index(\"date\")\n",
149+
")\n",
150+
"\n",
148151
"treatment_time = pd.to_datetime(\"2017-01-01\")\n",
149152
"df.head()"
150153
]
@@ -228,7 +231,7 @@
228231
"name": "stderr",
229232
"output_type": "stream",
230233
"text": [
231-
"Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 2 seconds.\n",
234+
"Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 3 seconds.\n",
232235
"Sampling: [beta, sigma, y_hat]\n",
233236
"Sampling: [y_hat]\n",
234237
"Sampling: [y_hat]\n",
@@ -498,7 +501,7 @@
498501
],
499502
"metadata": {
500503
"kernelspec": {
501-
"display_name": "Python 3.10.8 ('CausalPy')",
504+
"display_name": "CausalPy",
502505
"language": "python",
503506
"name": "python3"
504507
},
@@ -512,12 +515,12 @@
512515
"name": "python",
513516
"nbconvert_exporter": "python",
514517
"pygments_lexer": "ipython3",
515-
"version": "3.10.8"
518+
"version": "3.10.6"
516519
},
517520
"orig_nbformat": 4,
518521
"vscode": {
519522
"interpreter": {
520-
"hash": "46d31859cc45aa26a1223a391e7cf3023d69984b498bed11e66c690302b7e251"
523+
"hash": "02f5385db19eab57520277c5168790c7855381ee953bdbb5c89c321e1f17586e"
521524
}
522525
}
523526
},

docs/notebooks/sc2_skl.ipynb

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -123,9 +123,12 @@
123123
}
124124
],
125125
"source": [
126-
"df = cp.load_data(\"its\")\n",
127-
"df[\"date\"] = pd.to_datetime(df[\"date\"])\n",
128-
"df.set_index(\"date\", inplace=True)\n",
126+
"df = (\n",
127+
" cp.load_data(\"its\")\n",
128+
" .assign(date=lambda x: pd.to_datetime(x[\"date\"]))\n",
129+
" .set_index(\"date\")\n",
130+
")\n",
131+
"\n",
129132
"treatment_time = pd.to_datetime(\"2017-01-01\")\n",
130133
"df.head()"
131134
]
@@ -188,7 +191,7 @@
188191
],
189192
"metadata": {
190193
"kernelspec": {
191-
"display_name": "Python 3.10.6 ('CausalPy')",
194+
"display_name": "CausalPy",
192195
"language": "python",
193196
"name": "python3"
194197
},
@@ -207,7 +210,7 @@
207210
"orig_nbformat": 4,
208211
"vscode": {
209212
"interpreter": {
210-
"hash": "46d31859cc45aa26a1223a391e7cf3023d69984b498bed11e66c690302b7e251"
213+
"hash": "02f5385db19eab57520277c5168790c7855381ee953bdbb5c89c321e1f17586e"
211214
}
212215
}
213216
},

docs/notebooks/sc_pymc_brexit.ipynb

Lines changed: 743 additions & 750 deletions
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)