@@ -530,26 +530,30 @@ class InterventionTimeEstimator(PyMCModel):
530530 --------
531531 >>> import causalpy as cp
532532 >>> import numpy as np
533- >>> from causalpy.pymc_models import InterventionTimeEstimator
534- >>> df = cp.load_data("its")
535- >>> y = df["y"].values
536- >>> t = df["t"].values
537- >>> coords = {"seasons": range(12)} # The data is monthly
538- >>> estimator = InterventionTimeEstimator()
539- >>> # We are trying to capture an impulse in the number of death per month due to Covid.
540- >>> estimator.fit(
541- ... t,
542- ... y,
543- ... coords,
544- ... priors={"impulse":[]}
545- ... )
546- Inference data...
533+ >>> from patsy import build_design_matrices, dmatrices
534+ >>> from causalpy.pymc_models import InterventionTimeEstimator as ITE
535+ >>> data = cp.load_data("its")
536+ >>> formula="y ~ 1 + t + C(month)"
537+ >>> y, X = dmatrices(formula, data)
538+ >>> outcome_variable_name = y.design_info.column_names[0]
539+ >>> labels = X.design_info.column_names
540+ >>> _y, _X = np.asarray(y), np.asarray(X)
541+ >>> COORDS = {"coeffs":labels, "obs_ind": np.arange(_X.shape[0])}
542+ >>> model = ITE(sample_kwargs={"draws" : 10, "tune":10, "progressbar":False}) # Small draws/tune keep this example quick; omit sample_kwargs for a real analysis
543+ >>> model.set_time_range(None)
544+ >>> model.fit(X=_X, y=_y, coords=COORDS)
545+ Inference ...
547546 """
548547
549- def build_model (self , t , y , coords , time_range , grain_season , priors ):
548+ def __init__ (self , priors = {}, sample_kwargs = None ):
549+ super ().__init__ (sample_kwargs )
550+ self .priors = priors
551+
552+ def build_model (self , X , t , y , coords ):
550553 """
551554 Defines the PyMC model
552555
556+ :param X: A 2-D array of the covariates, with dims ("obs_ind", "coeffs")
553557 :param t: An array of values representing the time over which y is spread
554558 :param y: An array of values representing our outcome y
555559 :param coords: An optional dictionary with the coordinate names for our instruments.
@@ -564,80 +568,134 @@ def build_model(self, t, y, coords, time_range, grain_season, priors):
564568 with self :
565569 self .add_coords (coords )
566570
567- if time_range is None :
568- time_range = (t .min (), t .max ())
569-
571+ t = pm .Data ("t" , t , dims = "obs_ind" )
572+ X = pm .Data ("X" , X , dims = ["obs_ind" , "coeffs" ])
573+ y = pm .Data ("y" , y [:, 0 ], dims = "obs_ind" )
574+ lower_bound = pm .Data ("lower_bound" , self .time_range [0 ])
575+ upper_bound = pm .Data ("upper_bound" , self .time_range [1 ])
570576 # --- Priors ---
571577 switchpoint = pm .Uniform (
572- "switchpoint" , lower = time_range [ 0 ] , upper = time_range [ 1 ]
578+ "switchpoint" , lower = lower_bound , upper = upper_bound
573579 )
574- alpha = pm .Normal (name = "alpha" , mu = 0 , sigma = 50 )
575- beta = pm .Normal (name = "beta" , mu = 0 , sigma = 50 )
576- seasons = 0
577- if "seasons" in coords and len (coords ["seasons" ]) > 0 :
578- season_idx = np .arange (len (y )) // grain_season % len (coords ["seasons" ])
579- seasons_effect = pm .Normal (
580- "seasons_effect" , mu = 0 , sigma = 50 , dims = "seasons"
581- )
582- seasons = seasons_effect [season_idx ]
580+ beta = pm .Normal (name = "beta" , mu = 0 , sigma = 50 , dims = "coeffs" )
583581
584582 # --- Intervention effect ---
585583 level = trend = impulse = 0
586584
587- if "level" in priors :
585+ if "level" in self . priors :
588586 mu , sigma = (
589587 (0 , 50 )
590- if len (priors ["level" ]) != 2
591- else (priors ["level" ][0 ], priors ["level" ][1 ])
588+ if len (self . priors ["level" ]) != 2
589+ else (self . priors ["level" ][0 ], self . priors ["level" ][1 ])
592590 )
593591 level = pm .Normal (
594592 "level" ,
595593 mu = mu ,
596594 sigma = sigma ,
597595 )
598- if "trend" in priors :
596+ if "trend" in self . priors :
599597 mu , sigma = (
600598 (0 , 50 )
601- if len (priors ["trend" ]) != 2
602- else (priors ["trend" ][0 ], priors ["trend" ][1 ])
599+ if len (self . priors ["trend" ]) != 2
600+ else (self . priors ["trend" ][0 ], self . priors ["trend" ][1 ])
603601 )
604602 trend = pm .Normal ("trend" , mu = mu , sigma = sigma )
605- if "impulse" in priors :
603+ if "impulse" in self . priors :
606604 mu , sigma1 , sigma2 = (
607605 (0 , 50 , 50 )
608- if len (priors ["impulse" ]) != 3
606+ if len (self . priors ["impulse" ]) != 3
609607 else (
610- priors ["impulse" ][0 ],
611- priors ["impulse" ][1 ],
612- priors ["impulse" ][2 ],
608+ self . priors ["impulse" ][0 ],
609+ self . priors ["impulse" ][1 ],
610+ self . priors ["impulse" ][2 ],
613611 )
614612 )
615613 impulse_amplitude = pm .Normal ("impulse_amplitude" , mu = mu , sigma = sigma1 )
616614 decay_rate = pm .HalfNormal ("decay_rate" , sigma = sigma2 )
617- impulse = impulse_amplitude * pm .math .exp (
618- - decay_rate * abs (t - switchpoint )
615+ impulse = pm .Deterministic (
616+ "impulse" ,
617+ impulse_amplitude
618+ * pm .math .exp (- decay_rate * pm .math .abs (t - switchpoint )),
619619 )
620620
621621 # --- Parameterization ---
622622 weight = pm .math .sigmoid (t - switchpoint )
623- # Compute and store the modelled time series
624- mu_ts = pm .Deterministic (name = "mu_ts " , var = alpha + beta * t + seasons )
623+ # Compute and store the base time series
624+ mu = pm .Deterministic (name = "mu " , var = pm . math . dot ( X , beta ) )
625625 # Compute and store the modelled intervention effect
626626 mu_in = pm .Deterministic (
627627 name = "mu_in" , var = level + trend * (t - switchpoint ) + impulse
628628 )
629- # Compute and store the the sum of the intervention and the time series
630- mu = pm .Deterministic ("mu " , mu_ts + weight * mu_in )
629+ # Compute and store the sum of the base time series and the intervention's effect
630+ mu_ts = pm .Deterministic ("mu_ts " , mu + weight * mu_in )
631631 sigma = pm .HalfNormal ("sigma" , 1 )
632632
633633 # --- Likelihood ---
634- pm .Normal ("y_hat" , mu = mu , sigma = sigma , observed = y )
634+ # Likelihood of the base time series
635+ pm .Normal ("y_hat" , mu = mu , sigma = sigma , dims = "obs_ind" )
636+ # Likelihood of the base time series and the intervention's effect
637+ pm .Normal ("y_ts" , mu = mu_ts , sigma = sigma , observed = y , dims = "obs_ind" )
635638
636- def fit (self , t , y , coords , time_range = None , grain_season = 1 , priors = {}, n = 1000 ) :
637- """
638- Draw samples from posterior distribution
639+ def fit (self , X , y , coords : Optional [ Dict [ str , Any ]] = None ) -> None :
640+ """Draw samples from posterior, prior predictive, and posterior predictive
641+ distributions, placing them in the model's idata attribute.
639642 """
640- self .build_model (t , y , coords , time_range , grain_season , priors )
643+
644+ # Ensure random_seed is used in sample_prior_predictive() and
645+ # sample_posterior_predictive() if provided in sample_kwargs.
646+ random_seed = self .sample_kwargs .get ("random_seed" , None )
647+ t = X [:, - 1 ]
648+ if self .time_range is None :
649+ self .time_range = (t .min (), t .max ())
650+ self .build_model (X , t , y , coords )
641651 with self :
642- self .idata = pm .sample (n , progressbar = False , ** self .sample_kwargs )
652+ self .idata = pm .sample (max_treedepth = 15 , ** self .sample_kwargs )
653+ self .idata .extend (pm .sample_prior_predictive (random_seed = random_seed ))
654+ self .idata .extend (
655+ pm .sample_posterior_predictive (
656+ self .idata , progressbar = False , random_seed = random_seed
657+ )
658+ )
643659 return self .idata
660+
661+ def predict (self , X ):
662+ """
663+ Predict data given input data `X`
664+
665+ .. caution::
666+ Results in KeyError if model hasn't been fit.
667+ """
668+
669+ # Ensure random_seed is used in sample_posterior_predictive()
670+ # if provided in sample_kwargs.
671+ random_seed = self .sample_kwargs .get ("random_seed" , None )
672+ t = X [:, - 1 ]
673+ self ._data_setter (X , t )
674+ with self : # sample with new input data
675+ post_pred = pm .sample_posterior_predictive (
676+ self .idata ,
677+ var_names = ["y_hat" , "y_ts" , "mu" , "mu_ts" , "mu_in" ],
678+ progressbar = False ,
679+ random_seed = random_seed ,
680+ )
681+ return post_pred
682+
683+ def _data_setter (self , X , t ) -> None :
684+ """
685+ Set data for the model.
686+
687+ This method is used internally to register new data for the model for
688+ prediction.
689+ """
690+ new_no_of_observations = X .shape [0 ]
691+ with self :
692+ pm .set_data (
693+ {"X" : X , "t" : t , "y" : np .zeros (new_no_of_observations )},
694+ coords = {"obs_ind" : np .arange (new_no_of_observations )},
695+ )
696+
697+ def set_time_range (self , time_range ):
698+ """
699+ Set time_range.
700+ """
701+ self .time_range = time_range
0 commit comments