2020from scipy .stats import dirichlet , gamma , norm , uniform
2121from statsmodels .nonparametric .smoothers_lowess import lowess
2222
23- default_lowess_kwargs = {"frac" : 0.2 , "it" : 0 }
24- RANDOM_SEED = 8927
25- rng = np .random .default_rng (RANDOM_SEED )
23+ default_lowess_kwargs : dict [ str , float | int ] = {"frac" : 0.2 , "it" : 0 }
24+ RANDOM_SEED : int = 8927
25+ rng : np . random . Generator = np .random .default_rng (RANDOM_SEED )
2626
2727
2828def _smoothed_gaussian_random_walk (
29- gaussian_random_walk_mu , gaussian_random_walk_sigma , N , lowess_kwargs
30- ):
29+ gaussian_random_walk_mu : float ,
30+ gaussian_random_walk_sigma : float ,
31+ N : int ,
32+ lowess_kwargs : dict ,
33+ ) -> tuple [np .ndarray , np .ndarray ]:
3134 """
32- Generates Gaussian random walk data and applies LOWESS
35+ Generates Gaussian random walk data and applies LOWESS.
3336
3437 :param gaussian_random_walk_mu:
3538 Mean of the random walk
@@ -48,12 +51,12 @@ def _smoothed_gaussian_random_walk(
4851
4952
5053def generate_synthetic_control_data (
51- N = 100 ,
52- treatment_time = 70 ,
53- grw_mu = 0.25 ,
54- grw_sigma = 1 ,
55- lowess_kwargs = default_lowess_kwargs ,
56- ):
54+ N : int = 100 ,
55+ treatment_time : int = 70 ,
56+ grw_mu : float = 0.25 ,
57+ grw_sigma : float = 1 ,
58+ lowess_kwargs : dict = default_lowess_kwargs ,
59+ ) -> tuple [ pd . DataFrame , np . ndarray ] :
5760 """
5861 Generates data for synthetic control example.
5962
@@ -73,7 +76,6 @@ def generate_synthetic_control_data(
7376 >>> from causalpy.data.simulate_data import generate_synthetic_control_data
7477 >>> df, weightings_true = generate_synthetic_control_data(treatment_time=70)
7578 """
76-
7779 # 1. Generate non-treated variables
7880 df = pd .DataFrame (
7981 {
@@ -108,8 +110,12 @@ def generate_synthetic_control_data(
108110
109111
110112def generate_time_series_data (
111- N = 100 , treatment_time = 70 , beta_temp = - 1 , beta_linear = 0.5 , beta_intercept = 3
112- ):
113+ N : int = 100 ,
114+ treatment_time : int = 70 ,
115+ beta_temp : float = - 1 ,
116+ beta_linear : float = 0.5 ,
117+ beta_intercept : float = 3 ,
118+ ) -> pd .DataFrame :
113119 """
114120 Generates interrupted time series example data
115121
@@ -155,7 +161,9 @@ def generate_time_series_data(
155161 return df
156162
157163
158- def generate_time_series_data_seasonal (treatment_time ):
164+ def generate_time_series_data_seasonal (
165+ treatment_time : pd .Timestamp ,
166+ ) -> pd .DataFrame :
159167 """
160168 Generates 10 years of monthly data with seasonality
161169 """
@@ -169,11 +177,13 @@ def generate_time_series_data_seasonal(treatment_time):
169177 t = df .index ,
170178 ).set_index ("date" , drop = True )
171179 month_effect = np .array ([11 , 13 , 12 , 15 , 19 , 23 , 21 , 28 , 20 , 17 , 15 , 12 ])
172- df ["y" ] = 0.2 * df ["t" ] + 2 * month_effect [df .month .values - 1 ]
180+ df ["y" ] = 0.2 * df ["t" ] + 2 * month_effect [np . asarray ( df .month .values ) - 1 ]
173181
174182 N = df .shape [0 ]
175183 idx = np .arange (N )[df .index > treatment_time ]
176- df ["causal effect" ] = 100 * gamma (10 ).pdf (np .arange (0 , N , 1 ) - np .min (idx ))
184+ df ["causal effect" ] = 100 * gamma (10 ).pdf (
185+ np .array (np .arange (0 , N , 1 )) - int (np .min (idx ))
186+ )
177187
178188 df ["y" ] += df ["causal effect" ]
179189 df ["y" ] += norm (0 , 2 ).rvs (N )
@@ -183,7 +193,9 @@ def generate_time_series_data_seasonal(treatment_time):
183193 return df
184194
185195
186- def generate_time_series_data_simple (treatment_time , slope = 0.0 ):
196+ def generate_time_series_data_simple (
197+ treatment_time : pd .Timestamp , slope : float = 0.0
198+ ) -> pd .DataFrame :
187199 """Generate simple interrupted time series data, with no seasonality or temporal
188200 structure.
189201 """
@@ -205,7 +217,7 @@ def generate_time_series_data_simple(treatment_time, slope=0.0):
205217 return df
206218
207219
208- def generate_did ():
220+ def generate_did () -> pd . DataFrame :
209221 """
210222 Generate Difference in Differences data
211223
@@ -223,8 +235,14 @@ def generate_did():
223235
224236 # local functions
225237 def outcome (
226- t , control_intercept , treat_intercept_delta , trend , Δ , group , post_treatment
227- ):
238+ t : np .ndarray ,
239+ control_intercept : float ,
240+ treat_intercept_delta : float ,
241+ trend : float ,
242+ Δ : float ,
243+ group : np .ndarray ,
244+ post_treatment : np .ndarray ,
245+ ) -> np .ndarray :
228246 """Compute the outcome of each unit"""
229247 return (
230248 control_intercept
@@ -244,21 +262,21 @@ def outcome(
244262 df ["post_treatment" ] = df ["t" ] > intervention_time
245263
246264 df ["y" ] = outcome (
247- df ["t" ],
265+ np . asarray ( df ["t" ]) ,
248266 control_intercept ,
249267 treat_intercept_delta ,
250268 trend ,
251269 Δ ,
252- df ["group" ],
253- df ["post_treatment" ],
270+ np . asarray ( df ["group" ]) ,
271+ np . asarray ( df ["post_treatment" ]) ,
254272 )
255273 df ["y" ] += rng .normal (0 , 0.1 , df .shape [0 ])
256274 return df
257275
258276
259277def generate_regression_discontinuity_data (
260- N = 100 , true_causal_impact = 0.5 , true_treatment_threshold = 0.0
261- ):
278+ N : int = 100 , true_causal_impact : float = 0.5 , true_treatment_threshold : float = 0.0
279+ ) -> pd . DataFrame :
262280 """
263281 Generate regression discontinuity example data
264282
@@ -272,12 +290,12 @@ def generate_regression_discontinuity_data(
272290 ... ) # doctest: +SKIP
273291 """
274292
275- def is_treated (x ) :
293+ def is_treated (x : np . ndarray ) -> np . ndarray :
276294 """Check if x was treated"""
277295 return np .greater_equal (x , true_treatment_threshold )
278296
279- def impact (x ) :
280- """Assign true_causal_impact to all treaated entries"""
297+ def impact (x : np . ndarray ) -> np . ndarray :
298+ """Assign true_causal_impact to all treated entries"""
281299 y = np .zeros (len (x ))
282300 y [is_treated (x )] = true_causal_impact
283301 return y
@@ -289,8 +307,11 @@ def impact(x):
289307
290308
291309def generate_ancova_data (
292- N = 200 , pre_treatment_means = np .array ([10 , 12 ]), treatment_effect = 2 , sigma = 1
293- ):
310+ N : int = 200 ,
311+ pre_treatment_means : np .ndarray = np .array ([10 , 12 ]),
312+ treatment_effect : int = 2 ,
313+ sigma : int = 1 ,
314+ ) -> pd .DataFrame :
294315 """
295316 Generate ANCOVA example data
296317
@@ -310,7 +331,7 @@ def generate_ancova_data(
310331 return df
311332
312333
313- def generate_geolift_data ():
334+ def generate_geolift_data () -> pd . DataFrame :
314335 """Generate synthetic data for a geolift example. This will consist of 6 untreated
315336 countries. The treated unit `Denmark` is a weighted combination of the untreated
316337 units. We additionally specify a treatment effect which takes effect after the
@@ -360,7 +381,7 @@ def generate_geolift_data():
360381 return df
361382
362383
363- def generate_multicell_geolift_data ():
384+ def generate_multicell_geolift_data () -> pd . DataFrame :
364385 """Generate synthetic data for a geolift example. This will consist of 6 untreated
365386 countries. The treated unit `Denmark` is a weighted combination of the untreated
366387 units. We additionally specify a treatment effect which takes effect after the
@@ -422,7 +443,9 @@ def generate_multicell_geolift_data():
422443# -----------------
423444
424445
425- def generate_seasonality (n = 12 , amplitude = 1 , length_scale = 0.5 ):
446+ def generate_seasonality (
447+ n : int = 12 , amplitude : int = 1 , length_scale : float = 0.5
448+ ) -> np .ndarray :
426449 """Generate monthly seasonality by sampling from a Gaussian process with a
427450 Gaussian kernel, using numpy code"""
428451 # Generate the covariance matrix
@@ -436,14 +459,26 @@ def generate_seasonality(n=12, amplitude=1, length_scale=0.5):
436459 return seasonality
437460
438461
def periodic_kernel(
    x1: np.ndarray,
    x2: np.ndarray,
    period: float = 1,
    length_scale: float = 1.0,
    amplitude: float = 1,
) -> np.ndarray:
    """
    Evaluate a periodic kernel for a Gaussian process.

    The value is computed element-wise as
    ``amplitude**2 * exp(-2 * sin(pi * |x1 - x2| / period)**2 / length_scale**2)``,
    using NumPy broadcasting between ``x1`` and ``x2``.

    :param x1:
        First array of input locations
    :param x2:
        Second array of input locations; must broadcast against ``x1``
    :param period:
        Distance between inputs after which the kernel value repeats
    :param length_scale:
        Length scale dividing the squared-sine term inside the exponential
    :param amplitude:
        Scale factor; the kernel equals ``amplitude**2`` where ``x1 == x2``
    :returns:
        Array of kernel values with the broadcast shape of ``x1 - x2``
    """
    distance = np.abs(x1 - x2)
    # sin(pi * d / period) ** 2 is periodic in d with period `period`.
    squared_sine = np.sin(np.pi * distance / period) ** 2
    return amplitude**2 * np.exp(-2 * squared_sine / length_scale**2)
444473
445474
446- def create_series (n = 52 , amplitude = 1 , length_scale = 2 , n_years = 4 , intercept = 3 ):
475+ def create_series (
476+ n : int = 52 ,
477+ amplitude : int = 1 ,
478+ length_scale : int = 2 ,
479+ n_years : int = 4 ,
480+ intercept : int = 3 ,
481+ ) -> np .ndarray :
447482 """
448483 Returns numpy tile with generated seasonality data repeated over
449484 multiple years
0 commit comments