Skip to content

Commit cb963d0

Browse files
authored
Merge pull request #385 from DoubleML/sk-update-pandas-3
Sk-update-pandas-3
2 parents 331c929 + 4a5ea0d commit cb963d0

File tree

8 files changed

+18
-10
lines changed

8 files changed

+18
-10
lines changed

doubleml/data/panel_data.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -394,7 +394,11 @@ def _set_time_var(self):
394394
if hasattr(self, "_data") and self.t_col in self.data.columns:
395395
t_values = self.data.loc[:, self.t_col]
396396
expected_dtypes = (np.integer, np.floating, np.datetime64)
397-
if not any(np.issubdtype(t_values.dtype, dt) for dt in expected_dtypes):
397+
try:
398+
valid_type = any(np.issubdtype(t_values.dtype, dt) for dt in expected_dtypes)
399+
except TypeError:
400+
valid_type = False
401+
if not valid_type:
398402
raise ValueError(f"Invalid data type for time variable: expected one of {expected_dtypes}.")
399403
else:
400404
self._t = t_values

doubleml/did/utils/_plot.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,10 @@ def add_jitter(data, x_col, is_datetime=None, jitter_value=None):
2525
is_datetime = pd.api.types.is_datetime64_any_dtype(data[x_col])
2626

2727
# Initialize jittered_x with original values
28-
data["jittered_x"] = data[x_col]
28+
if is_datetime:
29+
data["jittered_x"] = data[x_col]
30+
else:
31+
data["jittered_x"] = data[x_col].astype(float)
2932

3033
for x_val in data[x_col].unique():
3134
mask = data[x_col] == x_val

doubleml/did/utils/tests/test_add_jitter.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from datetime import datetime, timedelta
22

3+
import numpy as np
34
import pandas as pd
45
import pytest
56

@@ -41,7 +42,7 @@ def test_add_jitter_numeric_no_duplicates(numeric_df_no_duplicates):
4142
"""Test that no jitter is added when there are no duplicates."""
4243
result = add_jitter(numeric_df_no_duplicates, "x")
4344
# No jitter should be added when there are no duplicates
44-
pd.testing.assert_series_equal(result["jittered_x"], result["x"], check_names=False)
45+
np.testing.assert_allclose(result["jittered_x"], result["x"])
4546

4647

4748
@pytest.mark.ci
@@ -121,7 +122,7 @@ def test_add_jitter_explicit_datetime_flag():
121122
df = pd.DataFrame({"x": ["2023-01-01", "2023-01-01", "2023-01-02"], "y": [10, 15, 20]})
122123

123124
# Without specifying is_datetime, it would treat as strings
124-
with pytest.raises(TypeError):
125+
with pytest.raises(ValueError):
125126
_ = add_jitter(df, "x")
126127

127128
# With is_datetime=True, it should convert and jitter as datetimes

doubleml/irm/tests/test_apo_exceptions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def test_apo_exception_data():
2525
msg = (
2626
r"The data must be of DoubleMLData or DoubleMLClusterData or DoubleMLDIDData or DoubleMLSSMData or "
2727
r"DoubleMLRDDData type\. Empty DataFrame\nColumns: \[\]\nIndex: \[\] of type "
28-
r"<class 'pandas\.core\.frame\.DataFrame'> was passed\."
28+
r"<class 'pandas\..*DataFrame'> was passed\."
2929
)
3030
with pytest.raises(TypeError, match=msg):
3131
_ = DoubleMLAPO(pd.DataFrame(), ml_g, ml_m, treatment_level=0)

doubleml/irm/tests/test_ssm_exceptions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def test_ssm_exception_data():
3535
msg = (
3636
r"The data must be of DoubleMLData or DoubleMLClusterData or DoubleMLDIDData or DoubleMLSSMData or "
3737
r"DoubleMLRDDData type\. Empty DataFrame\nColumns: \[\]\nIndex: \[\] of type "
38-
r"<class 'pandas\.core\.frame\.DataFrame'> was passed\."
38+
r"<class 'pandas\..*DataFrame'> was passed\."
3939
)
4040
with pytest.raises(TypeError, match=msg):
4141
_ = DoubleMLSSM(pd.DataFrame(), ml_g, ml_pi, ml_m)

doubleml/plm/tests/test_lplr_exceptions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323

2424
@pytest.mark.ci
2525
def test_lplr_exception_data():
26-
msg = r"The data must be of DoubleMLData.* type\.[\s\S]* of type " r"<class 'pandas\.core\.frame\.DataFrame'> was passed\."
26+
msg = r"The data must be of DoubleMLData.*type\."
2727
with pytest.raises(TypeError, match=msg):
2828
_ = DoubleMLLPLR(pd.DataFrame(), ml_M, ml_t, ml_m)
2929

doubleml/plm/tests/test_plpr_exceptions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@
6161

6262
@pytest.mark.ci
6363
def test_plpr_exception_data():
64-
msg = "The data must be of DoubleMLPanelData type. <class 'pandas.core.frame.DataFrame'> was passed."
64+
msg = r"The data must be of DoubleMLPanelData type. <class 'pandas\..*DataFrame'> was passed."
6565
with pytest.raises(TypeError, match=msg):
6666
_ = dml.DoubleMLPLPR(pd.DataFrame(), ml_l, ml_m)
6767
# not a panel data object

doubleml/utils/tests/test_policytree.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,8 +98,8 @@ def test_doubleml_exception_policytree():
9898
with pytest.raises(TypeError, match=msg):
9999
dml_policytree_predict.predict(features=1)
100100
msg = (
101-
r"The features must have the keys Index\(\[\'a\', \'b\', \'c\'\], dtype\=\'object\'\). "
102-
r"Features with keys Index\(\[\'d\'\], dtype=\'object\'\) were passed."
101+
r"The features must have the keys Index\(\[\'a\', \'b\', \'c\'\], dtype=.*?\)\. "
102+
r"Features with keys Index\(\[\'d\'\], dtype=.*?\) were passed\."
103103
)
104104
with pytest.raises(KeyError, match=msg):
105105
dml_policytree_predict.predict(features=pd.DataFrame({"d": [3, 4]}))

0 commit comments

Comments
 (0)