15
15
Difference in differences
16
16
"""
17
17
18
+ import arviz as az
18
19
import numpy as np
19
20
import pandas as pd
21
+ import seaborn as sns
20
22
from matplotlib import pyplot as plt
21
23
from patsy import build_design_matrices , dmatrices
22
24
23
25
from causalpy .custom_exceptions import (
24
26
DataException ,
25
27
FormulaException ,
26
28
)
29
+ from causalpy .plot_utils import plot_xY
27
30
from causalpy .pymc_models import PyMCModel
28
31
from causalpy .skl_models import ScikitLearnModel
29
- from causalpy .utils import _is_variable_dummy_coded , convert_to_string
32
+ from causalpy .utils import _is_variable_dummy_coded , convert_to_string , round_num
30
33
31
34
from .base import BaseExperiment
32
35
36
+ LEGEND_FONT_SIZE = 12
37
+
33
38
34
39
class DifferenceInDifferences (BaseExperiment ):
35
40
"""A class to analyse data from Difference in Difference settings.
@@ -205,18 +210,6 @@ def input_validation(self):
205
210
coded. Consisting of 0's and 1's only."""
206
211
)
207
212
208
- def plot (self , round_to = None ) -> tuple [plt .Figure , plt .Axes ]:
209
- """
210
- Plot the results
211
-
212
- :param round_to:
213
- Number of decimals used to round results. Defaults to 2. Use "None" to return raw numbers.
214
- """
215
- # Get a BayesianPlotComponent or OLSPlotComponent depending on the model
216
- plot_component = self .model .get_plot_component ()
217
- fig , ax = plot_component .plot_difference_in_differences (self , round_to = round_to )
218
- return fig , ax
219
-
220
213
def summary (self , round_to = None ) -> None :
221
214
"""Print summary of main results and model coefficients.
222
215
@@ -232,3 +225,216 @@ def summary(self, round_to=None) -> None:
232
225
def _causal_impact_summary_stat (self , round_to = None ) -> str :
233
226
"""Computes the mean and 94% credible interval bounds for the causal impact."""
234
227
return f"Causal impact = { convert_to_string (self .causal_impact , round_to = round_to )} "
228
+
229
+ def bayesian_plot (self , round_to = None , ** kwargs ) -> tuple [plt .Figure , plt .Axes ]:
230
+ """
231
+ Plot the results
232
+
233
+ :param round_to:
234
+ Number of decimals used to round results. Defaults to 2. Use "None" to return raw numbers.
235
+ """
236
+ round_to = kwargs .get ("round_to" )
237
+
238
+ def _plot_causal_impact_arrow (results , ax ):
239
+ """
240
+ draw a vertical arrow between `y_pred_counterfactual` and
241
+ `y_pred_counterfactual`
242
+ """
243
+ # Calculate y values to plot the arrow between
244
+ y_pred_treatment = (
245
+ results .y_pred_treatment ["posterior_predictive" ]
246
+ .mu .isel ({"obs_ind" : 1 })
247
+ .mean ()
248
+ .data
249
+ )
250
+ y_pred_counterfactual = (
251
+ results .y_pred_counterfactual ["posterior_predictive" ].mu .mean ().data
252
+ )
253
+ # Calculate the x position to plot at
254
+ # Note that we force to be float to avoid a type error using np.ptp with boolean
255
+ # values
256
+ diff = np .ptp (
257
+ np .array (
258
+ results .x_pred_treatment [results .time_variable_name ].values
259
+ ).astype (float )
260
+ )
261
+ x = (
262
+ np .max (results .x_pred_treatment [results .time_variable_name ].values )
263
+ + 0.1 * diff
264
+ )
265
+ # Plot the arrow
266
+ ax .annotate (
267
+ "" ,
268
+ xy = (x , y_pred_counterfactual ),
269
+ xycoords = "data" ,
270
+ xytext = (x , y_pred_treatment ),
271
+ textcoords = "data" ,
272
+ arrowprops = {"arrowstyle" : "<-" , "color" : "green" , "lw" : 3 },
273
+ )
274
+ # Plot text annotation next to arrow
275
+ ax .annotate (
276
+ "causal\n impact" ,
277
+ xy = (x , np .mean ([y_pred_counterfactual , y_pred_treatment ])),
278
+ xycoords = "data" ,
279
+ xytext = (5 , 0 ),
280
+ textcoords = "offset points" ,
281
+ color = "green" ,
282
+ va = "center" ,
283
+ )
284
+
285
+ fig , ax = plt .subplots ()
286
+
287
+ # Plot raw data
288
+ sns .scatterplot (
289
+ self .data ,
290
+ x = self .time_variable_name ,
291
+ y = self .outcome_variable_name ,
292
+ hue = self .group_variable_name ,
293
+ alpha = 1 ,
294
+ legend = False ,
295
+ markers = True ,
296
+ ax = ax ,
297
+ )
298
+
299
+ # Plot model fit to control group
300
+ time_points = self .x_pred_control [self .time_variable_name ].values
301
+ h_line , h_patch = plot_xY (
302
+ time_points ,
303
+ self .y_pred_control .posterior_predictive .mu ,
304
+ ax = ax ,
305
+ plot_hdi_kwargs = {"color" : "C0" },
306
+ label = "Control group" ,
307
+ )
308
+ handles = [(h_line , h_patch )]
309
+ labels = ["Control group" ]
310
+
311
+ # Plot model fit to treatment group
312
+ time_points = self .x_pred_control [self .time_variable_name ].values
313
+ h_line , h_patch = plot_xY (
314
+ time_points ,
315
+ self .y_pred_treatment .posterior_predictive .mu ,
316
+ ax = ax ,
317
+ plot_hdi_kwargs = {"color" : "C1" },
318
+ label = "Treatment group" ,
319
+ )
320
+ handles .append ((h_line , h_patch ))
321
+ labels .append ("Treatment group" )
322
+
323
+ # Plot counterfactual - post-test for treatment group IF no treatment
324
+ # had occurred.
325
+ time_points = self .x_pred_counterfactual [self .time_variable_name ].values
326
+ if len (time_points ) == 1 :
327
+ parts = ax .violinplot (
328
+ az .extract (
329
+ self .y_pred_counterfactual ,
330
+ group = "posterior_predictive" ,
331
+ var_names = "mu" ,
332
+ ).values .T ,
333
+ positions = self .x_pred_counterfactual [self .time_variable_name ].values ,
334
+ showmeans = False ,
335
+ showmedians = False ,
336
+ widths = 0.2 ,
337
+ )
338
+ for pc in parts ["bodies" ]:
339
+ pc .set_facecolor ("C0" )
340
+ pc .set_edgecolor ("None" )
341
+ pc .set_alpha (0.5 )
342
+ else :
343
+ h_line , h_patch = plot_xY (
344
+ time_points ,
345
+ self .y_pred_counterfactual .posterior_predictive .mu ,
346
+ ax = ax ,
347
+ plot_hdi_kwargs = {"color" : "C2" },
348
+ label = "Counterfactual" ,
349
+ )
350
+ handles .append ((h_line , h_patch ))
351
+ labels .append ("Counterfactual" )
352
+
353
+ # arrow to label the causal impact
354
+ _plot_causal_impact_arrow (self , ax )
355
+
356
+ # formatting
357
+ ax .set (
358
+ xticks = self .x_pred_treatment [self .time_variable_name ].values ,
359
+ title = self ._causal_impact_summary_stat (round_to ),
360
+ )
361
+ ax .legend (
362
+ handles = (h_tuple for h_tuple in handles ),
363
+ labels = labels ,
364
+ fontsize = LEGEND_FONT_SIZE ,
365
+ )
366
+ return fig , ax
367
+
368
+ def ols_plot (self , round_to = None , ** kwargs ) -> tuple [plt .Figure , plt .Axes ]:
369
+ """Generate plot for difference-in-differences"""
370
+ round_to = kwargs .get ("round_to" )
371
+ fig , ax = plt .subplots ()
372
+
373
+ # Plot raw data
374
+ sns .lineplot (
375
+ self .data ,
376
+ x = self .time_variable_name ,
377
+ y = self .outcome_variable_name ,
378
+ hue = "group" ,
379
+ units = "unit" ,
380
+ estimator = None ,
381
+ alpha = 0.25 ,
382
+ ax = ax ,
383
+ )
384
+ # Plot model fit to control group
385
+ ax .plot (
386
+ self .x_pred_control [self .time_variable_name ],
387
+ self .y_pred_control ,
388
+ "o" ,
389
+ c = "C0" ,
390
+ markersize = 10 ,
391
+ label = "model fit (control group)" ,
392
+ )
393
+ # Plot model fit to treatment group
394
+ ax .plot (
395
+ self .x_pred_treatment [self .time_variable_name ],
396
+ self .y_pred_treatment ,
397
+ "o" ,
398
+ c = "C1" ,
399
+ markersize = 10 ,
400
+ label = "model fit (treament group)" ,
401
+ )
402
+ # Plot counterfactual - post-test for treatment group IF no treatment
403
+ # had occurred.
404
+ ax .plot (
405
+ self .x_pred_counterfactual [self .time_variable_name ],
406
+ self .y_pred_counterfactual ,
407
+ "go" ,
408
+ markersize = 10 ,
409
+ label = "counterfactual" ,
410
+ )
411
+ # arrow to label the causal impact
412
+ ax .annotate (
413
+ "" ,
414
+ xy = (1.05 , self .y_pred_counterfactual ),
415
+ xycoords = "data" ,
416
+ xytext = (1.05 , self .y_pred_treatment [1 ]),
417
+ textcoords = "data" ,
418
+ arrowprops = {"arrowstyle" : "<->" , "color" : "green" , "lw" : 3 },
419
+ )
420
+ ax .annotate (
421
+ "causal\n impact" ,
422
+ xy = (
423
+ 1.05 ,
424
+ np .mean ([self .y_pred_counterfactual [0 ], self .y_pred_treatment [1 ]]),
425
+ ),
426
+ xycoords = "data" ,
427
+ xytext = (5 , 0 ),
428
+ textcoords = "offset points" ,
429
+ color = "green" ,
430
+ va = "center" ,
431
+ )
432
+ # formatting
433
+ ax .set (
434
+ xlim = [- 0.05 , 1.1 ],
435
+ xticks = [0 , 1 ],
436
+ xticklabels = ["pre" , "post" ],
437
+ title = f"Causal impact = { round_num (self .causal_impact , round_to )} " ,
438
+ )
439
+ ax .legend (fontsize = LEGEND_FONT_SIZE )
440
+ return fig , ax
0 commit comments