@@ -539,8 +539,9 @@ class InterventionTimeEstimator(PyMCModel):
539539 >>> labels = X.design_info.column_names
540540 >>> _y, _X = np.asarray(y), np.asarray(X)
541541 >>> COORDS = {"coeffs":labels, "obs_ind": np.arange(_X.shape[0])}
542- >>> model = ITE(sample_kwargs={"draws" : 10, "tune":10, "progressbar":False}) # For a quick overview. Remove sample_kwargs parameter for better performance
543- >>> model.set_time_range(None)
542+ >>> model = ITE(sample_kwargs={"draws" : 10, "tune":10, "progressbar":False})
543+ >>> model.set_time_range(None, data)
544+ >>> model.set_timeline(-1)
544545 >>> model.fit(X=_X, y=_y, coords=COORDS)
545546 Inference ...
546547 """
@@ -549,7 +550,7 @@ def __init__(self, priors={}, sample_kwargs=None):
549550 super ().__init__ (sample_kwargs )
550551 self .priors = priors
551552
552- def build_model (self , X , t , y , coords ):
553+ def build_model (self , X , y , coords ):
553554 """
554555 Defines the PyMC model
555556
@@ -568,7 +569,7 @@ def build_model(self, X, t, y, coords):
568569 with self :
569570 self .add_coords (coords )
570571
571- t = pm .Data ("t" , t , dims = "obs_ind" )
572+ t = pm .Data ("t" , X [:, self . timeline ] , dims = "obs_ind" )
572573 X = pm .Data ("X" , X , dims = ["obs_ind" , "coeffs" ])
573574 y = pm .Data ("y" , y [:, 0 ], dims = "obs_ind" )
574575 lower_bound = pm .Data ("lower_bound" , self .time_range [0 ])
@@ -644,10 +645,9 @@ def fit(self, X, y, coords: Optional[Dict[str, Any]] = None) -> None:
644645 # Ensure random_seed is used in sample_prior_predictive() and
645646 # sample_posterior_predictive() if provided in sample_kwargs.
646647 random_seed = self .sample_kwargs .get ("random_seed" , None )
647- t = X [:, - 1 ]
648648 if self .time_range is None :
649- self .time_range = (t . min (), t .max ())
650- self .build_model (X , t , y , coords )
649+ self .time_range = (X [:, self . timeline ]. min (), X [:, self . timeline ] .max ())
650+ self .build_model (X , y , coords )
651651 with self :
652652 self .idata = pm .sample (max_treedepth = 15 , ** self .sample_kwargs )
653653 self .idata .extend (pm .sample_prior_predictive (random_seed = random_seed ))
@@ -669,7 +669,7 @@ def predict(self, X):
669669 # Ensure random_seed is used in sample_prior_predictive() and
670670 # sample_posterior_predictive() if provided in sample_kwargs.
671671 random_seed = self .sample_kwargs .get ("random_seed" , None )
672- t = X [:, - 1 ]
672+ t = X [:, self . timeline ]
673673 self ._data_setter (X , t )
674674 with self : # sample with new input data
675675 post_pred = pm .sample_posterior_predictive (
@@ -693,3 +693,21 @@ def _data_setter(self, X, t) -> None:
693693 {"X" : X , "t" : t , "y" : np .zeros (new_no_of_observations )},
694694 coords = {"obs_ind" : np .arange (new_no_of_observations )},
695695 )
696+
697+ def set_time_range (self , time_range , data ):
698+ """
699+ Set time_range.
700+ """
701+ if time_range is None :
702+ self .time_range = time_range
703+ else :
704+ self .time_range = (
705+ data ["t" ].loc [time_range [0 ]],
706+ data ["t" ].loc [time_range [1 ]],
707+ )
708+
709+ def set_timeline (self , index ):
710+ """
711+ Set the index of the timeline in the given covariates
712+ """
713+ self .timeline = index
0 commit comments