-
Notifications
You must be signed in to change notification settings - Fork 111
Expand file tree
/
Copy pathpanel_data.py
More file actions
404 lines (347 loc) · 13.4 KB
/
panel_data.py
File metadata and controls
404 lines (347 loc) · 13.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
import io
import numpy as np
import pandas as pd
from sklearn.utils import assert_all_finite
from doubleml.data.base_data import DoubleMLBaseData, DoubleMLData
from doubleml.data.utils.panel_data_utils import _is_valid_datetime_unit
class DoubleMLPanelData(DoubleMLData):
    """Double machine learning data-backend for panel data in long format.

    :class:`DoubleMLPanelData` objects can be initialized from
    :class:`pandas.DataFrame` as well as :class:`numpy.ndarray` objects.

    Parameters
    ----------
    data : :class:`pandas.DataFrame`
        The data.

    y_col : str
        The outcome variable.

    d_cols : str or list
        The treatment variable(s) indicating the treatment groups in terms of first time of treatment exposure.

    t_col : str
        The time variable indicating the time.

    id_col : str
        Unique unit identifier.

    x_cols : None, str or list
        The covariates.
        If ``None``, all variables (columns of ``data``) which are neither specified as outcome variable ``y_col``, nor
        treatment variables ``d_cols``, nor instrumental variables ``z_cols`` are used as covariates.
        Default is ``None``.

    z_cols : None, str or list
        The instrumental variable(s).
        Default is ``None``.

    static_panel : bool
        Indicates whether the data model corresponds to a static
        panel data approach (``True``) or to staggered adoption panel data
        (``False``). In the latter case, the treatment groups/values are defined in terms of the first time of
        treatment exposure.
        Default is ``False``.

    use_other_treat_as_covariate : bool
        Indicates whether in the multiple-treatment case the other treatment variables should be added as covariates.
        Default is ``True``.

    force_all_x_finite : bool or str
        Indicates whether to raise an error on infinite values and / or missings in the covariates ``x``.
        Possible values are: ``True`` (neither missings ``np.nan``, ``pd.NA`` nor infinite values ``np.inf`` are
        allowed), ``False`` (missings and infinite values are allowed), ``'allow-nan'`` (only missings are allowed).
        Note that the choice ``False`` and ``'allow-nan'`` are only reasonable if the machine learning methods used
        for the nuisance functions are capable to provide valid predictions with missings and / or infinite values
        in the covariates ``x``.
        Default is ``True``.

    datetime_unit : str
        The unit of the time and treatment variable (if datetime type).

    Examples
    --------
    >>> from doubleml.did.datasets import make_did_CS2021
    >>> from doubleml import DoubleMLPanelData
    >>> df = make_did_CS2021(n_obs=500)
    >>> dml_data = DoubleMLPanelData(
    ...     df,
    ...     y_col="y",
    ...     d_cols="d",
    ...     id_col="id",
    ...     t_col="t",
    ...     x_cols=["Z1", "Z2", "Z3", "Z4"],
    ...     datetime_unit="M"
    ... )
    """

    def __init__(
        self,
        data,
        y_col,
        d_cols,
        t_col,
        id_col,
        x_cols=None,
        z_cols=None,
        static_panel=False,
        use_other_treat_as_covariate=True,
        force_all_x_finite=True,
        datetime_unit="M",
    ):
        DoubleMLBaseData.__init__(self, data)
        self._static_panel = static_panel
        # we need to set id_col (needs _data) before call to the super __init__ because of the x_cols setter
        self.id_col = id_col
        self._set_id_var()
        # Set time column before calling parent constructor
        self.t_col = t_col
        self._datetime_unit = _is_valid_datetime_unit(datetime_unit)
        if not self.static_panel:
            # staggered adoption: treatment encodes the first exposure time, so
            # missing values (never-treated units) must be permitted in d
            cluster_cols = None
            force_all_d_finite = False
        else:
            # static panel: cluster on the unit identifier and require finite d
            cluster_cols = id_col
            force_all_d_finite = True
        DoubleMLData.__init__(
            self,
            data=data,
            y_col=y_col,
            d_cols=d_cols,
            x_cols=x_cols,
            z_cols=z_cols,
            cluster_cols=cluster_cols,
            use_other_treat_as_covariate=use_other_treat_as_covariate,
            force_all_x_finite=force_all_x_finite,
            force_all_d_finite=force_all_d_finite,
        )

        # reset index to ensure a simple RangeIndex
        self.data.reset_index(drop=True, inplace=True)

        # Set time variable array after data is loaded
        self._set_time_var()

        self._check_disjoint_sets_id_col()

        # initialize the unique values of g and t
        self._g_values = np.sort(np.unique(self.d))  # unique values of g
        self._t_values = np.sort(np.unique(self.t))  # unique values of t

        if self.n_treat != 1:
            raise ValueError("Only one treatment column is allowed for panel data.")

    def __str__(self):
        """Return a human-readable summary of the data backend."""
        data_summary = self._data_summary_str()
        buf = io.StringIO()
        self.data.info(verbose=False, buf=buf)
        df_info = buf.getvalue()
        res = (
            "================== DoubleMLPanelData Object ==================\n"
            + "\n------------------ Data summary ------------------\n"
            + data_summary
            + "\n------------------ DataFrame info ------------------\n"
            + df_info
        )
        return res

    def _data_summary_str(self):
        """Build the variable/observation summary used in ``__str__``."""
        data_summary = (
            f"Outcome variable: {self.y_col}\n"
            f"Treatment variable(s): {self.d_cols}\n"
            f"Covariates: {self.x_cols}\n"
            f"Instrument variable(s): {self.z_cols}\n"
            f"Time variable: {self.t_col}\n"
            f"Id variable: {self.id_col}\n"
            f"Static panel data: {self.static_panel}\n"
        )
        data_summary += f"No. Unique Ids: {self.n_ids}\n" f"No. Observations: {self.n_obs}"
        return data_summary

    @classmethod
    def from_arrays(cls, x, y, d, t, identifier, z=None, s=None, use_other_treat_as_covariate=True, force_all_x_finite=True):
        # TODO: Implement initialization from arrays
        raise NotImplementedError("from_arrays is not implemented for DoubleMLPanelData")

    @property
    def datetime_unit(self):
        """
        The unit of the time variable.
        """
        return self._datetime_unit

    @property
    def d(self):
        """
        Array of treatment variable;
        Dynamic! Depends on the currently set treatment variable;
        To get an array of all treatment variables (independent of the currently set treatment variable)
        call ``obj.data[obj.d_cols].values``.
        """
        # cast datetime treatments to the configured resolution so group values compare consistently
        if pd.api.types.is_datetime64_any_dtype(self._d):
            return self._d.values.astype(f"datetime64[{self.datetime_unit}]")
        else:
            return self._d.values

    @property
    def t(self):
        """
        Array of time variable.
        """
        # cast datetime periods to the configured resolution so time values compare consistently
        if pd.api.types.is_datetime64_any_dtype(self._t):
            return self._t.values.astype(f"datetime64[{self.datetime_unit}]")
        else:
            return self._t.values

    @property
    def id_col(self):
        """
        The id variable.
        """
        return self._id_col

    @id_col.setter
    def id_col(self, value):
        # only re-validate disjointness when the column is being *changed* after construction
        reset_value = hasattr(self, "_id_col")
        if not isinstance(value, str):
            raise TypeError(
                "The id variable id_col must be of str type. " f"{str(value)} of type {str(type(value))} was passed."
            )
        if value not in self.all_variables:
            raise ValueError("Invalid id variable id_col. " f"{value} is no data column.")
        self._id_col = value
        if reset_value:
            self._check_disjoint_sets()
            self._set_id_var()

    @property
    def id_var(self):
        """
        Array of id variable.
        """
        return self._id_var.values

    @property
    def id_var_unique(self):
        """
        Unique values of id variable.
        """
        return self._id_var_unique

    @property
    def n_ids(self):
        """
        The number of unique values for id_col.
        """
        return len(self._id_var_unique)

    @property
    def g_col(self):
        """
        The treatment variable indicating the time of treatment exposure.
        """
        return self._d_cols[0]

    @DoubleMLData.d_cols.setter
    def d_cols(self, value):
        if isinstance(value, str):
            value = [value]
        # Delegate to the parent property's setter. Name the parent class explicitly:
        # ``super(self.__class__, self.__class__)`` would resolve via the *runtime* class
        # and recurse infinitely if this setter were inherited by a subclass.
        DoubleMLData.d_cols.__set__(self, value)
        if hasattr(self, "_g_values"):
            self._g_values = np.sort(np.unique(self.d))  # update unique values of g

    @property
    def g_values(self):
        """
        The unique values of the treatment variable (groups) ``d``.
        """
        return self._g_values

    @property
    def n_groups(self):
        """
        The number of groups.
        """
        return len(self.g_values)

    @property
    def t_col(self):
        """
        The time variable.
        """
        return self._t_col

    @t_col.setter
    def t_col(self, value):
        if value is None:
            raise TypeError("Invalid time variable t_col. Time variable required for panel data.")
        if not isinstance(value, str):
            raise TypeError(
                "The time variable t_col must be of str type. " f"{str(value)} of type {str(type(value))} was passed."
            )
        # Check if data exists (during initialization it might not)
        if hasattr(self, "_data") and value not in self.all_variables:
            raise ValueError(f"Invalid time variable t_col. {value} is no data column.")
        self._t_col = value
        # Update time variable array if data is already loaded
        if hasattr(self, "_data"):
            self._set_time_var()
        if hasattr(self, "_t_values"):
            self._t_values = np.sort(np.unique(self.t))  # update unique values of t

    @property
    def t_values(self):
        """
        The unique values of the time variable ``t``.
        """
        return self._t_values

    @property
    def n_t_periods(self):
        """
        The number of time periods.
        """
        return len(self.t_values)

    @property
    def static_panel(self):
        """Indicates whether the data model corresponds to a static panel data approach."""
        return self._static_panel

    def _get_optional_col_sets(self):
        """Extend the parent's optional column sets with the id and time columns."""
        base_optional_col_sets = super()._get_optional_col_sets()
        id_col_set = {self.id_col}
        t_col_set = {self.t_col}
        return [id_col_set, t_col_set] + base_optional_col_sets

    def _check_disjoint_sets(self):
        # apply the standard checks from the DoubleMLData class
        super(DoubleMLPanelData, self)._check_disjoint_sets()
        self._check_disjoint_sets_id_col()
        self._check_disjoint_sets_t_col()

    def _check_disjoint_sets_id_col(self):
        """Check that the id column does not overlap with any other variable set."""
        # special checks for the additional id variable (and the time variable)
        id_col_set = {self.id_col}
        y_col_set = {self.y_col}
        x_cols_set = set(self.x_cols)
        d_cols_set = set(self.d_cols)
        z_cols_set = set(self.z_cols or [])
        t_col_set = {self.t_col}  # t_col is not None for panel data
        # s_col not tested as not relevant for panel data

        id_col_check_args = [
            (y_col_set, "outcome variable", "``y_col``"),
            (d_cols_set, "treatment variable", "``d_cols``"),
            (x_cols_set, "covariate", "``x_cols``"),
            (z_cols_set, "instrumental variable", "``z_cols``"),
            (t_col_set, "time variable", "``t_col``"),
        ]
        for set1, name, argument in id_col_check_args:
            self._check_disjoint(
                set1=set1,
                name1=name,
                arg1=argument,
                set2=id_col_set,
                name2="identifier variable",
                arg2="``id_col``",
            )

    def _check_disjoint_sets_t_col(self):
        """Check that time column is disjoint from other variable sets."""
        t_col_set = {self.t_col}
        y_col_set = {self.y_col}
        x_cols_set = set(self.x_cols)
        d_cols_set = set(self.d_cols)
        z_cols_set = set(self.z_cols or [])
        id_col_set = {self.id_col}

        t_checks_args = [
            (y_col_set, "outcome variable", "``y_col``"),
            (d_cols_set, "treatment variable", "``d_cols``"),
            (x_cols_set, "covariate", "``x_cols``"),
            (z_cols_set, "instrumental variable", "``z_cols``"),
            (id_col_set, "identifier variable", "``id_col``"),
        ]
        for set1, name, argument in t_checks_args:
            self._check_disjoint(
                set1=set1,
                name1=name,
                arg1=argument,
                set2=t_col_set,
                name2="time variable",
                arg2="``t_col``",
            )

    def _set_id_var(self):
        """Cache the id column as a Series and its unique values (must be finite)."""
        assert_all_finite(self.data.loc[:, self.id_col])
        self._id_var = self.data.loc[:, self.id_col]
        self._id_var_unique = np.unique(self._id_var.values)

    def _set_time_var(self):
        """Set the time variable array."""
        if hasattr(self, "_data") and self.t_col in self.data.columns:
            t_values = self.data.loc[:, self.t_col]
            # only integer, floating or datetime time columns are supported
            expected_dtypes = (np.integer, np.floating, np.datetime64)
            try:
                valid_type = any(np.issubdtype(t_values.dtype, dt) for dt in expected_dtypes)
            except TypeError:
                # np.issubdtype raises TypeError for dtypes it cannot interpret (e.g. object)
                valid_type = False
            if not valid_type:
                raise ValueError(f"Invalid data type for time variable: expected one of {expected_dtypes}.")
            else:
                self._t = t_values