@@ -560,23 +560,21 @@ class InterventionTimeEstimator(PyMCModel):
560560        ...     coords={"obs_ind": data.index}, 
561561        ...     ) 
562562        >>> COORDS = {"coeffs":labels, "obs_ind": np.arange(_X.shape[0])} 
563-         >>> model = ITE(time_variable_name="t",  treatment_effect_type="level", sample_kwargs={"draws" : 10, "tune":10, "progressbar":False}) 
563+         >>> model = ITE(treatment_effect_type="level", sample_kwargs={"draws" : 10, "tune":10, "progressbar":False}) 
564564        >>> model.set_time_range(None, data) 
565565        >>> model.fit(X=_X, y=_y, coords=COORDS) 
566566        Inference ... 
567567    """ 
568568
569569    def  __init__ (
570570        self ,
571-         time_variable_name : str ,
572571        treatment_effect_type : str  |  list [str ],
573572        treatment_effect_param = None ,
574573        sample_kwargs = None ,
575574    ):
576575        """ 
577576        Initializes the InterventionTimeEstimator model. 
578577
579-         :param time_variable_name: name of the column representing time among the covariates 
580578        :param treatment_effect_type: Optional dictionary that specifies prior parameters for the 
581579            intervention effects. Expected keys are: 
582580                - "level": [mu, sigma] 
@@ -586,7 +584,6 @@ def __init__(
586584            If the associated list is incomplete, default values will be used. 
587585        :param sample_kwargs: Optional dictionary of arguments passed to pm.sample(). 
588586        """ 
589-         self .time_variable_name  =  time_variable_name 
590587
591588        super ().__init__ (sample_kwargs )
592589
@@ -657,7 +654,7 @@ def build_model(self, X, y, coords):
657654        with  self :
658655            self .add_coords (coords )
659656
660-             t  =  pm .Data ("t" , X . sel ( coeffs = self . time_variable_name ), dims = "obs_ind" )
657+             t  =  pm .Data ("t" , np . arange ( len ( X ) ), dims = "obs_ind" )
661658            X  =  pm .Data ("X" , X , dims = ["obs_ind" , "coeffs" ])
662659            y  =  pm .Data ("y" , y , dims = "obs_ind" )
663660
@@ -768,7 +765,7 @@ def _data_setter(self, X) -> None:
768765            pm .set_data (
769766                {
770767                    "X" : X ,
771-                     "t" : X . sel ( coeffs = self . time_variable_name ),
768+                     "t" : np . arange ( len ( X ) ),
772769                    "y" : np .zeros (new_no_of_observations ),
773770                },
774771                coords = {"obs_ind" : np .arange (new_no_of_observations )},
@@ -794,11 +791,11 @@ def set_time_range(self, time_range, data):
794791        """ 
795792        if  time_range  is  None :
796793            self .time_range  =  (
797-                 data [ self . time_variable_name ]. min () ,
798-                 data [ self . time_variable_name ]. max ( ),
794+                 0 ,
795+                 len ( data ),
799796            )
800797        else :
801798            self .time_range  =  (
802-                 data [ self . time_variable_name ]. loc [ time_range [0 ]] ,
803-                 data [ self . time_variable_name ]. loc [ time_range [1 ]] ,
799+                 data . index . get_loc ( time_range [0 ]) ,
800+                 data . index . get_loc ( time_range [1 ]) ,
804801            )
0 commit comments