@@ -41,6 +41,13 @@ class DoubleMLPanelData(DoubleMLData):
4141 The instrumental variable(s).
4242 Default is ``None``.
4343
44+ static_panel : bool
45+ Indicates whether the data model corresponds to a static
46+ panel data approach (``True``) or to staggered adoption panel data
47+ (``False``). In the latter case, the treatment groups/values are defined in terms of the first time of
48+ treatment exposure.
49+ Default is ``False``.
50+
4451 use_other_treat_as_covariate : bool
4552 Indicates whether in the multiple-treatment case the other treatment variables should be added as covariates.
4653 Default is ``True``.
@@ -82,31 +89,40 @@ def __init__(
8289 id_col ,
8390 x_cols = None ,
8491 z_cols = None ,
92+ static_panel = False ,
8593 use_other_treat_as_covariate = True ,
8694 force_all_x_finite = True ,
8795 datetime_unit = "M" ,
8896 ):
8997 DoubleMLBaseData .__init__ (self , data )
9098
99+ self ._static_panel = static_panel
100+
91101 # we need to set id_col (needs _data) before call to the super __init__ because of the x_cols setter
92102 self .id_col = id_col
93- self ._datetime_unit = _is_valid_datetime_unit (datetime_unit )
94103 self ._set_id_var ()
95-
96104 # Set time column before calling parent constructor
97105 self .t_col = t_col
106+ self ._datetime_unit = _is_valid_datetime_unit (datetime_unit )
107+
108+ if not self .static_panel :
109+ cluster_cols = None
110+ force_all_d_finite = False
111+ else :
112+ cluster_cols = id_col
113+ force_all_d_finite = True
98114
99- # Call parent constructor
100115 DoubleMLData .__init__ (
101116 self ,
102117 data = data ,
103118 y_col = y_col ,
104119 d_cols = d_cols ,
105120 x_cols = x_cols ,
106121 z_cols = z_cols ,
122+ cluster_cols = cluster_cols ,
107123 use_other_treat_as_covariate = use_other_treat_as_covariate ,
108124 force_all_x_finite = force_all_x_finite ,
109- force_all_d_finite = False ,
125+ force_all_d_finite = force_all_d_finite ,
110126 )
111127
112128 # reset index to ensure a simple RangeIndex
@@ -115,15 +131,15 @@ def __init__(
115131 # Set time variable array after data is loaded
116132 self ._set_time_var ()
117133
118- if self .n_treat != 1 :
119- raise ValueError ("Only one treatment column is allowed for panel data." )
120-
121134 self ._check_disjoint_sets_id_col ()
122135
123136 # initialize the unique values of g and t
124137 self ._g_values = np .sort (np .unique (self .d )) # unique values of g
125138 self ._t_values = np .sort (np .unique (self .t )) # unique values of t
126139
140+ if self .n_treat != 1 :
141+ raise ValueError ("Only one treatment column is allowed for panel data." )
142+
127143 def __str__ (self ):
128144 data_summary = self ._data_summary_str ()
129145 buf = io .StringIO ()
@@ -146,9 +162,10 @@ def _data_summary_str(self):
146162 f"Instrument variable(s): { self .z_cols } \n "
147163 f"Time variable: { self .t_col } \n "
148164 f"Id variable: { self .id_col } \n "
165+ f"Static panel data: { self .static_panel } \n "
149166 )
150167
151- data_summary += f"No. Unique Ids: { self .n_ids } \n " f"No. Observations: { self .n_obs } \n "
168+ data_summary += f"No. Unique Ids: { self .n_ids } \n " f"No. Observations: { self .n_obs } "
152169 return data_summary
153170
154171 @classmethod
@@ -296,6 +313,11 @@ def n_t_periods(self):
296313 """
297314 return len (self .t_values )
298315
316+ @property
317+ def static_panel (self ):
318+ """Indicates whether the data model corresponds to a static panel data approach."""
319+ return self ._static_panel
320+
299321 def _get_optional_col_sets (self ):
300322 base_optional_col_sets = super ()._get_optional_col_sets ()
301323 id_col_set = {self .id_col }
@@ -370,4 +392,9 @@ def _set_id_var(self):
370392 def _set_time_var (self ):
371393 """Set the time variable array."""
372394 if hasattr (self , "_data" ) and self .t_col in self .data .columns :
373- self ._t = self .data .loc [:, self .t_col ]
395+ t_values = self .data .loc [:, self .t_col ]
396+ expected_dtypes = (np .integer , np .floating , np .datetime64 )
397+ if not any (np .issubdtype (t_values .dtype , dt ) for dt in expected_dtypes ):
398+ raise ValueError (f"Invalid data type for time variable: expected one of { expected_dtypes } ." )
399+ else :
400+ self ._t = t_values
0 commit comments