1515Difference in differences
1616"""
1717
18+ from typing import Union
19+
1820import arviz as az
1921import numpy as np
2022import pandas as pd
@@ -47,20 +49,24 @@ class DifferenceInDifferences(BaseExperiment):
4749
4850 .. note::
4951
50- There is no pre/post intervention data distinction for DiD, we fit all the
51- data available.
52- :param data:
53- A pandas dataframe
54- :param formula:
55- A statistical model formula
56- :param time_variable_name:
57- Name of the data column for the time variable
58- :param group_variable_name:
59- Name of the data column for the group variable
60- :param post_treatment_variable_name:
61- Name of the data column indicating post-treatment period (default: "post_treatment")
62- :param model:
63- A PyMC model for difference in differences
52+ There is no pre/post intervention data distinction for DiD, we fit
53+ all the data available.
54+
55+ Parameters
56+ ----------
57+ data : pd.DataFrame
58+ A pandas dataframe.
59+ formula : str
60+ A statistical model formula.
61+ time_variable_name : str
62+ Name of the data column for the time variable.
63+ group_variable_name : str
64+ Name of the data column for the group variable.
65+ post_treatment_variable_name : str, optional
66+ Name of the data column indicating post-treatment period.
67+ Defaults to "post_treatment".
68+ model : PyMCModel or RegressorMixin, optional
69+ A PyMC model for difference in differences. Defaults to None.
6470
6571 Example
6672 --------
@@ -92,8 +98,8 @@ def __init__(
9298 time_variable_name : str ,
9399 group_variable_name : str ,
94100 post_treatment_variable_name : str = "post_treatment" ,
95- model = None ,
96- ** kwargs ,
101+ model : Union [ PyMCModel , RegressorMixin ] | None = None ,
102+ ** kwargs : dict ,
97103 ) -> None :
98104 super ().__init__ (model = model )
99105 self .causal_impact : xr .DataArray | float | None
@@ -234,14 +240,14 @@ def __init__(
234240 f"{ self .group_variable_name } :{ self .post_treatment_variable_name } "
235241 )
236242 matched_key = next ((k for k in coef_map if interaction_term in k ), None )
237- att = coef_map .get (matched_key )
243+ att = coef_map .get (matched_key ) if matched_key is not None else None
238244 self .causal_impact = att
239245 else :
240246 raise ValueError ("Model type not recognized" )
241247
242248 return
243249
244- def input_validation (self ):
250+ def input_validation (self ) -> None :
245251 # Validate formula structure and interaction interaction terms
246252 self ._validate_formula_interaction_terms ()
247253
@@ -269,7 +275,7 @@ def input_validation(self):
269275 coded. Consisting of 0's and 1's only."""
270276 )
271277
272- def _validate_formula_interaction_terms (self ):
278+ def _validate_formula_interaction_terms (self ) -> None :
273279 """
274280 Validate that the formula contains at most one interaction term and no three-way or higher-order interactions.
275281 Raises FormulaException if more than one interaction term is found or if any interaction term has more than 2 variables.
@@ -299,7 +305,7 @@ def _validate_formula_interaction_terms(self):
299305 "Multiple interaction terms are not currently supported as they complicate interpretation of the causal effect."
300306 )
301307
302- def summary (self , round_to = None ) -> None :
308+ def summary (self , round_to : int | None = 2 ) -> None :
303309 """Print summary of main results and model coefficients.
304310
305311 :param round_to:
@@ -311,11 +317,13 @@ def summary(self, round_to=None) -> None:
311317 print (self ._causal_impact_summary_stat (round_to ))
312318 self .print_coefficients (round_to )
313319
314- def _causal_impact_summary_stat (self , round_to = None ) -> str :
320+ def _causal_impact_summary_stat (self , round_to : int | None = None ) -> str :
315321 """Computes the mean and 94% credible interval bounds for the causal impact."""
316322 return f"Causal impact = { convert_to_string (self .causal_impact , round_to = round_to )} "
317323
318- def _bayesian_plot (self , round_to = None , ** kwargs ) -> tuple [plt .Figure , plt .Axes ]:
324+ def _bayesian_plot (
325+ self , round_to : int | None = None , ** kwargs : dict
326+ ) -> tuple [plt .Figure , plt .Axes ]:
319327 """
320328 Plot the results
321329
@@ -463,9 +471,10 @@ def _plot_causal_impact_arrow(results, ax):
463471 )
464472 return fig , ax
465473
466- def _ols_plot (self , round_to = None , ** kwargs ) -> tuple [plt .Figure , plt .Axes ]:
474+ def _ols_plot (
475+ self , round_to : int | None = 2 , ** kwargs : dict
476+ ) -> tuple [plt .Figure , plt .Axes ]:
467477 """Generate plot for difference-in-differences"""
468- round_to = kwargs .get ("round_to" )
469478 fig , ax = plt .subplots ()
470479
471480 # Plot raw data
@@ -528,11 +537,15 @@ def _ols_plot(self, round_to=None, **kwargs) -> tuple[plt.Figure, plt.Axes]:
528537 va = "center" ,
529538 )
530539 # formatting
540+ # In OLS context, causal_impact should be a float, but mypy doesn't know this
541+ causal_impact_value = (
542+ float (self .causal_impact ) if self .causal_impact is not None else 0.0
543+ )
531544 ax .set (
532545 xlim = [- 0.05 , 1.1 ],
533546 xticks = [0 , 1 ],
534547 xticklabels = ["pre" , "post" ],
535- title = f"Causal impact = { round_num (self . causal_impact , round_to )} " ,
548+ title = f"Causal impact = { round_num (causal_impact_value , round_to )} " ,
536549 )
537550 ax .legend (fontsize = LEGEND_FONT_SIZE )
538551 return fig , ax
0 commit comments