File tree Expand file tree Collapse file tree 1 file changed +7
-9
lines changed Expand file tree Collapse file tree 1 file changed +7
-9
lines changed Original file line number Diff line number Diff line change @@ -109,20 +109,18 @@ def _data_setter(self, X) -> None:
109109 has_treated_units = False
110110
111111 with self :
112- if has_treated_units :
113- # Get the number of treated units from the model coordinates
114- treated_units_coord = getattr (self , "coords" , {}).get (
115- "treated_units" , []
116- )
117- n_treated_units = (
118- len (treated_units_coord ) if treated_units_coord is not None else 1
119- )
112+ # Get the number of treated units from the model coordinates
113+ treated_units_coord = getattr (self , "coords" , {}).get ("treated_units" , [])
114+ n_treated_units = len (treated_units_coord ) if treated_units_coord else 1
115+
116+ if n_treated_units > 1 or has_treated_units :
117+ # Multi-unit case or single unit with treated_units dimension
120118 pm .set_data (
121119 {"X" : X , "y" : np .zeros ((new_no_of_observations , n_treated_units ))},
122120 coords = {"obs_ind" : np .arange (new_no_of_observations )},
123121 )
124122 else :
125- # Legacy case - this shouldn't happen with new WeightedSumFitter
123+ # Other model types (e.g., LinearRegression) without treated_units dimension
126124 pm .set_data (
127125 {"X" : X , "y" : np .zeros (new_no_of_observations )},
128126 coords = {"obs_ind" : np .arange (new_no_of_observations )},
You can’t perform that action at this time.
0 commit comments