@@ -89,7 +89,7 @@ def build_model(self, X, y, coords) -> None:
8989        """Build the model, must be implemented by subclass.""" 
9090        raise  NotImplementedError ("This method must be implemented by a subclass" )
9191
92-     def  _data_setter (self , X ) ->  None :
92+     def  _data_setter (self , X :  xr . DataArray ) ->  None :
9393        """ 
9494        Set data for the model. 
9595
@@ -105,6 +105,9 @@ def _data_setter(self, X) -> None:
105105        """ 
106106        new_no_of_observations  =  X .shape [0 ]
107107
108+         # Use integer indices for obs_ind to avoid datetime compatibility issues with PyMC 
109+         obs_coords  =  np .arange (new_no_of_observations )
110+ 
108111        # Check if this model has multiple treated units 
109112        if  hasattr (self , "idata" ) and  self .idata  is  not   None :
110113            posterior  =  self .idata .posterior 
@@ -125,13 +128,13 @@ def _data_setter(self, X) -> None:
125128                # Multi-unit case or single unit with treated_units dimension 
126129                pm .set_data (
127130                    {"X" : X , "y" : np .zeros ((new_no_of_observations , n_treated_units ))},
128-                     coords = {"obs_ind" : np . arange ( new_no_of_observations ) },
131+                     coords = {"obs_ind" : obs_coords },
129132                )
130133            else :
131134                # Other model types (e.g., LinearRegression) without treated_units dimension 
132135                pm .set_data (
133136                    {"X" : X , "y" : np .zeros (new_no_of_observations )},
134-                     coords = {"obs_ind" : np . arange ( new_no_of_observations ) },
137+                     coords = {"obs_ind" : obs_coords },
135138                )
136139
137140    def  fit (self , X , y , coords : Optional [Dict [str , Any ]] =  None ) ->  None :
@@ -154,7 +157,7 @@ def fit(self, X, y, coords: Optional[Dict[str, Any]] = None) -> None:
154157            )
155158        return  self .idata 
156159
157-     def  predict (self , X ):
160+     def  predict (self , X :  xr . DataArray ):
158161        """ 
159162        Predict data given input data `X` 
160163
@@ -166,16 +169,19 @@ def predict(self, X):
166169        # sample_posterior_predictive() if provided in sample_kwargs. 
167170        random_seed  =  self .sample_kwargs .get ("random_seed" , None )
168171        self ._data_setter (X )
169-         with  self :   # sample with new input data 
172+         with  self :
170173            pp  =  pm .sample_posterior_predictive (
171174                self .idata ,
172175                var_names = ["y_hat" , "mu" ],
173176                progressbar = False ,
174177                random_seed = random_seed ,
175178            )
176179
177-         # TODO: This is a bit of a hack. Maybe it could be done properly in _data_setter? 
178-         if  isinstance (X , xr .DataArray ):
180+         # Assign coordinates from input X to ensure xarray operations work correctly 
181+         # This is necessary because PyMC uses integer indices internally, but we need 
182+         # to preserve the original coordinates (e.g., datetime indices) for proper 
183+         # alignment with other xarray operations like calculate_impact() 
184+         if  isinstance (X , xr .DataArray ) and  "obs_ind"  in  X .coords :
179185            pp ["posterior_predictive" ] =  pp ["posterior_predictive" ].assign_coords (
180186                obs_ind = X .obs_ind 
181187            )
0 commit comments