Skip to content

Commit 3757199

Browse files
committed
add causal impact arrow back in to DID plot
1 parent 4442d5b commit 3757199

File tree

5 files changed

+160
-67
lines changed

5 files changed

+160
-67
lines changed

causalpy/plotting.py

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,53 @@ def plot_pre_post(results, round_to=None):
153153
def plot_difference_in_differences(results, round_to=None):
154154
"""Generate plot for difference-in-differences"""
155155

156+
def _plot_causal_impact_arrow(results, ax):
157+
"""
158+
draw a vertical arrow between `y_pred_counterfactual` and
159+
`y_pred_counterfactual`
160+
"""
161+
# Calculate y values to plot the arrow between
162+
y_pred_treatment = (
163+
results.y_pred_treatment["posterior_predictive"]
164+
.mu.isel({"obs_ind": 1})
165+
.mean()
166+
.data
167+
)
168+
y_pred_counterfactual = (
169+
results.y_pred_counterfactual["posterior_predictive"].mu.mean().data
170+
)
171+
# Calculate the x position to plot at
172+
# Note that we force to be float to avoid a type error using np.ptp with boolean
173+
# values
174+
diff = np.ptp(
175+
np.array(
176+
results.x_pred_treatment[results.time_variable_name].values
177+
).astype(float)
178+
)
179+
x = (
180+
np.max(results.x_pred_treatment[results.time_variable_name].values)
181+
+ 0.1 * diff
182+
)
183+
# Plot the arrow
184+
ax.annotate(
185+
"",
186+
xy=(x, y_pred_counterfactual),
187+
xycoords="data",
188+
xytext=(x, y_pred_treatment),
189+
textcoords="data",
190+
arrowprops={"arrowstyle": "<-", "color": "green", "lw": 3},
191+
)
192+
# Plot text annotation next to arrow
193+
ax.annotate(
194+
"causal\nimpact",
195+
xy=(x, np.mean([y_pred_counterfactual, y_pred_treatment])),
196+
xycoords="data",
197+
xytext=(5, 0),
198+
textcoords="offset points",
199+
color="green",
200+
va="center",
201+
)
202+
156203
fig, ax = plt.subplots()
157204

158205
# Plot raw data
@@ -224,7 +271,7 @@ def plot_difference_in_differences(results, round_to=None):
224271
labels.append("Counterfactual")
225272

226273
# arrow to label the causal impact
227-
# _plot_causal_impact_arrow(ax)
274+
_plot_causal_impact_arrow(results, ax)
228275

229276
# formatting
230277
ax.set(

docs/source/_static/interrogate_badge.svg

Lines changed: 3 additions & 3 deletions
Loading

docs/source/notebooks/REFACTOR.ipynb

Lines changed: 89 additions & 43 deletions
Large diffs are not rendered by default.

docs/source/notebooks/did_pymc.ipynb

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,7 @@
170170
{
171171
"data": {
172172
"application/vnd.jupyter.widget-view+json": {
173-
"model_id": "94cefb3fbe2b4e478ccd6dc92e35d78e",
173+
"model_id": "401ee07af4d94db28bd014b7986f5bcc",
174174
"version_major": 2,
175175
"version_minor": 0
176176
},
@@ -267,11 +267,11 @@
267267
"Results:\n",
268268
"Causal impact = 0.5, $CI_{94\\%}$[0.4, 0.6]\n",
269269
"Model coefficients:\n",
270-
" Intercept \t1.1, 94% HDI [1, 1.1]\n",
271-
" post_treatment[T.True] \t0.99, 94% HDI [0.92, 1.1]\n",
272-
" group \t0.16, 94% HDI [0.094, 0.23]\n",
273-
" group:post_treatment[T.True]\t0.5, 94% HDI [0.4, 0.6]\n",
274-
" sigma \t0.082, 94% HDI [0.066, 0.1]\n"
270+
" Intercept 1.1, 94% HDI [1, 1.1]\n",
271+
" post_treatment[T.True] 0.99, 94% HDI [0.92, 1.1]\n",
272+
" group 0.16, 94% HDI [0.094, 0.23]\n",
273+
" group:post_treatment[T.True] 0.5, 94% HDI [0.4, 0.6]\n",
274+
" sigma 0.082, 94% HDI [0.066, 0.1]\n"
275275
]
276276
}
277277
],

docs/source/notebooks/did_pymc_banks.ipynb

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@
121121
"name": "stderr",
122122
"output_type": "stream",
123123
"text": [
124-
"/var/folders/pd/p2qnky2x3xl4w3mgc4lct2200000gn/T/ipykernel_66134/4155710090.py:21: FutureWarning: Downcasting behavior in `replace` is deprecated and will be removed in a future version. To retain the old behavior, explicitly call `result.infer_objects(copy=False)`. To opt-in to the future behavior, set `pd.set_option('future.no_silent_downcasting', True)`\n",
124+
"/var/folders/pd/p2qnky2x3xl4w3mgc4lct2200000gn/T/ipykernel_60207/4155710090.py:21: FutureWarning: Downcasting behavior in `replace` is deprecated and will be removed in a future version. To retain the old behavior, explicitly call `result.infer_objects(copy=False)`. To opt-in to the future behavior, set `pd.set_option('future.no_silent_downcasting', True)`\n",
125125
" df_long = df_long.replace({\"district\": {\"Sixth District\": 1, \"Eighth District\": 0}})\n"
126126
]
127127
},
@@ -351,7 +351,7 @@
351351
{
352352
"data": {
353353
"application/vnd.jupyter.widget-view+json": {
354-
"model_id": "ebc1927741d644158177ba56c4b935fd",
354+
"model_id": "dab4f8e91a2846a58e652963cae66c50",
355355
"version_major": 2,
356356
"version_minor": 0
357357
},
@@ -470,11 +470,11 @@
470470
"Results:\n",
471471
"Causal impact = 19, $CI_{94\\%}$[15, 22]\n",
472472
"Model coefficients:\n",
473-
" Intercept \t165, 94% HDI [163, 167]\n",
474-
" post_treatment[T.True] \t-33, 94% HDI [-36, -30]\n",
475-
" district \t-30, 94% HDI [-32, -27]\n",
476-
" district:post_treatment[T.True]\t19, 94% HDI [15, 22]\n",
477-
" sigma \t0.84, 94% HDI [0.085, 2.2]\n"
473+
" Intercept 165, 94% HDI [163, 167]\n",
474+
" post_treatment[T.True] -33, 94% HDI [-36, -30]\n",
475+
" district -30, 94% HDI [-32, -27]\n",
476+
" district:post_treatment[T.True] 19, 94% HDI [15, 22]\n",
477+
" sigma 0.84, 94% HDI [0.085, 2.2]\n"
478478
]
479479
}
480480
],
@@ -555,7 +555,7 @@
555555
{
556556
"data": {
557557
"application/vnd.jupyter.widget-view+json": {
558-
"model_id": "1fb16fb3cb6845ef9e6ad49d1380ce8a",
558+
"model_id": "08a93ad293c443419024dd7390bb17af",
559559
"version_major": 2,
560560
"version_minor": 0
561561
},
@@ -654,12 +654,12 @@
654654
"Results:\n",
655655
"Causal impact = 20, $CI_{94\\%}$[15, 26]\n",
656656
"Model coefficients:\n",
657-
" Intercept \t160, 94% HDI [157, 164]\n",
658-
" post_treatment[T.True] \t-28, 94% HDI [-33, -22]\n",
659-
" year \t-7.1, 94% HDI [-8.5, -5.7]\n",
660-
" district \t-29, 94% HDI [-34, -24]\n",
661-
" district:post_treatment[T.True]\t20, 94% HDI [15, 26]\n",
662-
" sigma \t2.4, 94% HDI [1.7, 3.2]\n"
657+
" Intercept 160, 94% HDI [157, 164]\n",
658+
" post_treatment[T.True] -28, 94% HDI [-33, -22]\n",
659+
" year -7.1, 94% HDI [-8.5, -5.7]\n",
660+
" district -29, 94% HDI [-34, -24]\n",
661+
" district:post_treatment[T.True] 20, 94% HDI [15, 26]\n",
662+
" sigma 2.4, 94% HDI [1.7, 3.2]\n"
663663
]
664664
}
665665
],

0 commit comments

Comments
 (0)