Skip to content

Commit 331c929

Browse files
authored
Merge pull request #378 from DoubleML/j-static-panel
PLPR Static Panel model
2 parents aec7006 + fe573a6 commit 331c929

25 files changed

+2034
-68
lines changed

doubleml/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from .irm.ssm import DoubleMLSSM
1414
from .plm.lplr import DoubleMLLPLR
1515
from .plm.pliv import DoubleMLPLIV
16+
from .plm.plpr import DoubleMLPLPR
1617
from .plm.plr import DoubleMLPLR
1718
from .utils.blp import DoubleMLBLP
1819
from .utils.policytree import DoubleMLPolicyTree
@@ -43,6 +44,7 @@
4344
"DoubleMLPolicyTree",
4445
"DoubleMLSSM",
4546
"DoubleMLLPLR",
47+
"DoubleMLPLPR",
4648
]
4749

4850
try:

doubleml/data/base_data.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,7 @@ def _data_summary_str(self):
199199

200200
if hasattr(self, "is_cluster_data") and self.is_cluster_data:
201201
data_summary += f"Is cluster data: {self.is_cluster_data}\n"
202-
data_summary += f"No. Observations: {self.n_obs}\n"
202+
data_summary += f"No. Observations: {self.n_obs}"
203203
return data_summary
204204

205205
@classmethod

doubleml/data/panel_data.py

Lines changed: 36 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,13 @@ class DoubleMLPanelData(DoubleMLData):
4141
The instrumental variable(s).
4242
Default is ``None``.
4343
44+
static_panel : bool
45+
Indicates whether the data model corresponds to a static
46+
panel data approach (``True``) or to staggered adoption panel data
47+
(``False``). In the latter case, the treatment groups/values are defined in terms of the first time of
48+
treatment exposure.
49+
Default is ``False``.
50+
4451
use_other_treat_as_covariate : bool
4552
Indicates whether in the multiple-treatment case the other treatment variables should be added as covariates.
4653
Default is ``True``.
@@ -82,31 +89,40 @@ def __init__(
8289
id_col,
8390
x_cols=None,
8491
z_cols=None,
92+
static_panel=False,
8593
use_other_treat_as_covariate=True,
8694
force_all_x_finite=True,
8795
datetime_unit="M",
8896
):
8997
DoubleMLBaseData.__init__(self, data)
9098

99+
self._static_panel = static_panel
100+
91101
# we need to set id_col (needs _data) before call to the super __init__ because of the x_cols setter
92102
self.id_col = id_col
93-
self._datetime_unit = _is_valid_datetime_unit(datetime_unit)
94103
self._set_id_var()
95-
96104
# Set time column before calling parent constructor
97105
self.t_col = t_col
106+
self._datetime_unit = _is_valid_datetime_unit(datetime_unit)
107+
108+
if not self.static_panel:
109+
cluster_cols = None
110+
force_all_d_finite = False
111+
else:
112+
cluster_cols = id_col
113+
force_all_d_finite = True
98114

99-
# Call parent constructor
100115
DoubleMLData.__init__(
101116
self,
102117
data=data,
103118
y_col=y_col,
104119
d_cols=d_cols,
105120
x_cols=x_cols,
106121
z_cols=z_cols,
122+
cluster_cols=cluster_cols,
107123
use_other_treat_as_covariate=use_other_treat_as_covariate,
108124
force_all_x_finite=force_all_x_finite,
109-
force_all_d_finite=False,
125+
force_all_d_finite=force_all_d_finite,
110126
)
111127

112128
# reset index to ensure a simple RangeIndex
@@ -115,15 +131,15 @@ def __init__(
115131
# Set time variable array after data is loaded
116132
self._set_time_var()
117133

118-
if self.n_treat != 1:
119-
raise ValueError("Only one treatment column is allowed for panel data.")
120-
121134
self._check_disjoint_sets_id_col()
122135

123136
# intialize the unique values of g and t
124137
self._g_values = np.sort(np.unique(self.d)) # unique values of g
125138
self._t_values = np.sort(np.unique(self.t)) # unique values of t
126139

140+
if self.n_treat != 1:
141+
raise ValueError("Only one treatment column is allowed for panel data.")
142+
127143
def __str__(self):
128144
data_summary = self._data_summary_str()
129145
buf = io.StringIO()
@@ -146,9 +162,10 @@ def _data_summary_str(self):
146162
f"Instrument variable(s): {self.z_cols}\n"
147163
f"Time variable: {self.t_col}\n"
148164
f"Id variable: {self.id_col}\n"
165+
f"Static panel data: {self.static_panel}\n"
149166
)
150167

151-
data_summary += f"No. Unique Ids: {self.n_ids}\n" f"No. Observations: {self.n_obs}\n"
168+
data_summary += f"No. Unique Ids: {self.n_ids}\n" f"No. Observations: {self.n_obs}"
152169
return data_summary
153170

154171
@classmethod
@@ -296,6 +313,11 @@ def n_t_periods(self):
296313
"""
297314
return len(self.t_values)
298315

316+
@property
317+
def static_panel(self):
318+
"""Indicates whether the data model corresponds to a static panel data approach."""
319+
return self._static_panel
320+
299321
def _get_optional_col_sets(self):
300322
base_optional_col_sets = super()._get_optional_col_sets()
301323
id_col_set = {self.id_col}
@@ -370,4 +392,9 @@ def _set_id_var(self):
370392
def _set_time_var(self):
371393
"""Set the time variable array."""
372394
if hasattr(self, "_data") and self.t_col in self.data.columns:
373-
self._t = self.data.loc[:, self.t_col]
395+
t_values = self.data.loc[:, self.t_col]
396+
expected_dtypes = (np.integer, np.floating, np.datetime64)
397+
if not any(np.issubdtype(t_values.dtype, dt) for dt in expected_dtypes):
398+
raise ValueError(f"Invalid data type for time variable: expected one of {expected_dtypes}.")
399+
else:
400+
self._t = t_values

doubleml/data/tests/test_panel_data.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -157,14 +157,26 @@ def test_panel_data_str():
157157
assert "Time variable: t" in dml_str
158158
assert "Id variable: id" in dml_str
159159
assert "No. Observations:" in dml_str
160+
assert "Static panel data:" in dml_str
161+
162+
163+
@pytest.fixture(scope="module", params=[True, False])
164+
def static_panel(request):
165+
return request.param
160166

161167

162168
@pytest.mark.ci
163-
def test_panel_data_properties():
169+
def test_panel_data_properties(static_panel):
164170
np.random.seed(3141)
165171
df = make_did_SZ2020(n_obs=100, return_type="DoubleMLPanelData")._data
166172
dml_data = DoubleMLPanelData(
167-
data=df, y_col="y", d_cols="d", t_col="t", id_col="id", x_cols=[f"Z{i + 1}" for i in np.arange(4)]
173+
data=df,
174+
y_col="y",
175+
d_cols="d",
176+
t_col="t",
177+
id_col="id",
178+
x_cols=[f"Z{i + 1}" for i in np.arange(4)],
179+
static_panel=static_panel,
168180
)
169181

170182
assert np.array_equal(dml_data.id_var, df["id"].values)
@@ -176,3 +188,10 @@ def test_panel_data_properties():
176188
assert dml_data.n_groups == len(np.unique(df["d"].values))
177189
assert np.array_equal(dml_data.t_values, np.sort(np.unique(df["t"].values)))
178190
assert dml_data.n_t_periods == len(np.unique(df["t"].values))
191+
192+
if static_panel:
193+
assert dml_data.static_panel is True
194+
assert dml_data.cluster_cols == ["id"]
195+
else:
196+
assert dml_data.static_panel is False
197+
assert dml_data.cluster_cols is None

doubleml/data/tests/test_panel_data_exceptions.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,19 @@ def test_time_col_none_exception(sample_data):
6161
)
6262

6363

64+
@pytest.mark.ci
65+
def test_time_var_data_type_exception(sample_data):
66+
# Test exception when time var is not int, float or datetime
67+
msg = (
68+
r"Invalid data type for time variable: expected one of \(<class 'numpy.integer'>, <class 'numpy.floating'>, "
69+
r"<class 'numpy.datetime64'>\)"
70+
)
71+
with pytest.raises(ValueError, match=msg):
72+
data_time_type = sample_data.copy()
73+
data_time_type["time"] = data_time_type["time"].astype(str)
74+
DoubleMLPanelData(data=data_time_type, y_col="y", d_cols="treatment", t_col="time", id_col="id")
75+
76+
6477
@pytest.mark.ci
6578
def test_overlapping_variables_exception(sample_data):
6679
# Test exception when id_col overlaps with another variable
@@ -80,9 +93,10 @@ def test_overlapping_variables_exception(sample_data):
8093
DoubleMLPanelData(data=sample_data, y_col="y", d_cols="id", t_col="time", id_col="id") # Using id as treatment
8194

8295
# Test time variable overlapping
96+
# using t_col="id", id_col="id" gives invalid data type for time variable exception first
8397
msg = r"At least one variable/column is set as time variable \(``t_col``\) and identifier variable \(``id_col``\)."
8498
with pytest.raises(ValueError, match=msg):
85-
DoubleMLPanelData(data=sample_data, y_col="y", d_cols="treatment", t_col="id", id_col="id") # Using id as time
99+
DoubleMLPanelData(data=sample_data, y_col="y", d_cols="treatment", t_col="time", id_col="time") # Using time as id
86100

87101

88102
@pytest.mark.ci

doubleml/double_ml.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,7 @@ def __str__(self):
192192

193193
additional_info = self._format_additional_info_str()
194194
if additional_info:
195-
representation += f"\n\n------------------ Additional Information ------------------\n" f"{additional_info}"
195+
representation += f"\n\n------------------ Additional Information -------------\n" f"{additional_info}"
196196
return representation
197197

198198
@property
@@ -1177,7 +1177,7 @@ def _initialize_ml_nuisance_params(self):
11771177
pass
11781178

11791179
@abstractmethod
1180-
def _nuisance_est(self, smpls, n_jobs_cv, return_models, external_predictions):
1180+
def _nuisance_est(self, smpls, n_jobs_cv, external_predictions, return_models):
11811181
pass
11821182

11831183
@abstractmethod

doubleml/plm/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
from .lplr import DoubleMLLPLR
66
from .pliv import DoubleMLPLIV
7+
from .plpr import DoubleMLPLPR
78
from .plr import DoubleMLPLR
89

9-
__all__ = ["DoubleMLPLR", "DoubleMLPLIV", "DoubleMLLPLR"]
10+
__all__ = ["DoubleMLPLR", "DoubleMLPLIV", "DoubleMLLPLR", "DoubleMLPLPR"]

doubleml/plm/datasets/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from .dgp_lplr_LZZ2020 import make_lplr_LZZ2020
88
from .dgp_pliv_CHS2015 import make_pliv_CHS2015
99
from .dgp_pliv_multiway_cluster_CKMS2021 import make_pliv_multiway_cluster_CKMS2021
10+
from .dgp_plpr_CP2025 import make_plpr_CP2025
1011
from .dgp_plr_CCDDHNR2018 import make_plr_CCDDHNR2018
1112
from .dgp_plr_turrell2018 import make_plr_turrell2018
1213

@@ -18,4 +19,5 @@
1819
"make_pliv_multiway_cluster_CKMS2021",
1920
"make_lplr_LZZ2020",
2021
"_make_pliv_data",
22+
"make_plpr_CP2025",
2123
]

0 commit comments

Comments
 (0)