Skip to content

Commit 6d34f01

Browse files
neuralsorcerer and meta-codesync[bot]
authored and committed
Support Pandas 3 (#302)
Summary: Pull Request resolved: #302 - Closes #297 Pull Request resolved: #301 Differential Revision: D92157858 Pulled By: talgalili fbshipit-source-id: ecf3b8cd3df62717755549ba13831bd4ce3dd75a
1 parent 3f405ef commit 6d34f01

17 files changed

+196
-57
lines changed

CHANGELOG.md

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -11,18 +11,22 @@
1111
- Added paired outcome-weight impact tests (`y*w0` vs `y*w1`) with confidence intervals.
1212
- Exposed in `BalanceDFOutcomes`, `Sample.diagnostics()`, and the CLI via
1313
`--weights_impact_on_outcome_method`.
14+
- **Pandas 3 support**
15+
- Updated compatibility and tests for pandas 3.x
1416

1517
## Bug Fixes
1618

1719
- **Removed deprecated setup build**
1820
- Replaced deprecated `setup.py` with `pyproject.toml` build in CI to avoid build failure.
1921
- **Hardened ID column candidate validation**
2022
- `guess_id_column()` now ignores duplicate candidate names and validates that candidates are non-empty strings.
23+
- **Hardened pandas 3 compatibility paths**
24+
- Updated string/NA handling and discrete checks for pandas 3 dtypes, and refreshed tests to accept string-backed dtypes.
2125

2226
## Packaging & Tests
2327

24-
- **Pandas 2.x compatibility and upper bound (<3.0.0)**
25-
- Constrained the pandas dependency to `>=2,<3.0.0` to avoid untested pandas 3.x API and dtype changes.
28+
- **Pandas 3.x compatibility**
29+
- Expanded the pandas dependency range to allow pandas 3.x releases.
2630

2731
## Breaking Changes
2832

README.md

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -69,8 +69,8 @@ REQUIRES = [
6969
# Numpy and pandas: carefully versioned for binary compatibility
7070
"numpy>=1.21.0,<2.0; python_version<'3.12'",
7171
"numpy>=1.24.0; python_version>='3.12'",
72-
"pandas>=1.5.0,<2.4.0; python_version<'3.12'",
73-
"pandas>=2.0.0; python_version>='3.12'",
72+
"pandas>=1.5.0,<4.0.0; python_version<'3.12'",
73+
"pandas>=2.0.0,<4.0.0; python_version>='3.12'",
7474
# Scientific stack
7575
"scipy>=1.7.0,<1.14.0; python_version<'3.12'",
7676
"scipy>=1.11.0; python_version>='3.12'",

balance/adjustment.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -211,7 +211,7 @@ def trim_weights(
211211
original_name = getattr(weights, "name", None)
212212

213213
if isinstance(weights, pd.Series):
214-
weights = weights.astype(np.float64, copy=False)
214+
weights = weights.astype(np.float64)
215215
elif isinstance(weights, (np.ndarray, list, tuple)):
216216
weights = pd.Series(
217217
np.asarray(weights, dtype=np.float64), dtype=np.float64, name=original_name

balance/balancedf_class.py

Lines changed: 13 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -2372,7 +2372,7 @@ def summary(
23722372
target_clause = f"Response rates (in the target):\n {target_response_rates}"
23732373

23742374
n_outcomes = self.df.shape[1]
2375-
list_outcomes = self.df.columns.values
2375+
list_outcomes = np.array(self.df.columns, dtype=object)
23762376
mean_outcomes_with_ci = mean_outcomes_with_ci
23772377
relative_response_rates = relative_response_rates
23782378
target_clause = target_clause
@@ -2458,6 +2458,18 @@ def __init__(self: "BalanceDFWeights", sample: Sample) -> None:
24582458
"""
24592459
super().__init__(sample.weight_column.to_frame(), sample, name="weights")
24602460

2461+
@property
2462+
def df(self: "BalanceDFWeights") -> pd.DataFrame:
2463+
"""Return the current weight column as a DataFrame.
2464+
2465+
Args:
2466+
self (BalanceDFWeights): The BalanceDFWeights instance.
2467+
2468+
Returns:
2469+
pd.DataFrame: DataFrame containing the current weight column.
2470+
"""
2471+
return self._sample.weight_column.to_frame()
2472+
24612473
# TODO: maybe add better control if there are no weights for unadjusted or target (the current default shows them in the legend, but not in the figure)
24622474
def plot(
24632475
self: "BalanceDFWeights", on_linked_samples: bool = True, **kwargs: Any

balance/sample_class.py

Lines changed: 12 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -11,6 +11,7 @@
1111
import inspect
1212
import logging
1313
from copy import deepcopy
14+
from importlib.metadata import version as importlib_version
1415
from typing import Any, Callable, Dict, List, Literal
1516

1617
import numpy as np
@@ -499,16 +500,23 @@ def from_frame(
499500
# for x in df.columns:
500501
# if (is_numeric_dtype(df[x])) and (not is_bool_dtype(df[x])):
501502
# df[x] = df[x].astype("float64")
502-
input_type = ["Int64", "Int32", "int64", "int32", "int16", "int8", "string"]
503+
input_type = ["Int64", "Int32", "int64", "int32", "int16", "int8"]
503504
output_type = [
504505
"float64",
505506
"float32", # This changes Int32Dtype() into dtype('int32') (from pandas to numpy)
506507
"float64",
507508
"float32",
508509
"float16",
509510
"float16", # Using float16 since float8 doesn't exist, see: https://stackoverflow.com/a/40507235/256662
510-
"object",
511511
]
512+
# TODO:(after 2026) that if pandas >=3, this doesn't cause issues for users importing data from SQL
513+
# In pandas < 3, convert string dtype to object for compatibility
514+
_pd_version = tuple(
515+
int(x) for x in importlib_version("pandas").split(".")[:2]
516+
)
517+
if _pd_version < (3, 0):
518+
input_type.append("string")
519+
output_type.append("object")
512520
for i_input, i_output in zip(input_type, output_type):
513521
sample._df = balance_util._pd_convert_all_types(
514522
sample._df, i_input, i_output
@@ -1122,7 +1130,8 @@ def set_weights(self, weights: pd.Series | float | None) -> None:
11221130
].astype("float64")
11231131

11241132
# Now assign the weights
1125-
self._df.loc[:, self.weight_column.name] = weights
1133+
weights_value = np.nan if weights is None else weights
1134+
self._df.loc[:, self.weight_column.name] = weights_value
11261135

11271136
self.weight_column = self._df[self.weight_column.name]
11281137

balance/stats_and_plots/weighted_stats.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -87,7 +87,7 @@ def _prepare_weighted_stat_args(
8787

8888
dtypes = v.dtypes if hasattr(v.dtypes, "__iter__") else [v.dtypes]
8989

90-
if not all(np.issubdtype(x, np.number) for x in dtypes):
90+
if not all(pd.api.types.is_numeric_dtype(x) for x in dtypes):
9191
raise TypeError("all columns must be numeric")
9292

9393
if inf_rm:

balance/testutil.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -157,6 +157,10 @@ def assertEqual(
157157
lazy: bool = kwargs.get("lazy", False)
158158
if isinstance(first, np.ndarray) or isinstance(second, np.ndarray):
159159
np.testing.assert_array_equal(first, second, **kwargs)
160+
elif isinstance(first, pd.api.extensions.ExtensionArray) or isinstance(
161+
second, pd.api.extensions.ExtensionArray
162+
):
163+
np.testing.assert_array_equal(np.array(first), np.array(second), **kwargs)
160164
elif isinstance(first, pd.DataFrame) or isinstance(second, pd.DataFrame):
161165
_assert_frame_equal_lazy(
162166
first,

balance/utils/data_transformation.py

Lines changed: 9 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -67,7 +67,7 @@ def add_na_indicator(
6767
filled_col = (
6868
df[c].cat.add_categories(replace_val_obj).fillna(replace_val_obj)
6969
)
70-
df[c] = filled_col.infer_objects(copy=False)
70+
df[c] = filled_col.infer_objects()
7171
elif c in non_numeric_cols:
7272
df[c] = _safe_fillna_and_infer(df[c], replace_val_obj)
7373
else:
@@ -319,19 +319,21 @@ def fct_lump(s: pd.Series, prop: float = 0.05) -> pd.Series:
319319
props = s.value_counts() / s.shape[0]
320320

321321
# Ensure proper dtype inference on the index
322-
props.index = props.index.infer_objects(copy=False)
322+
props.index = props.index.infer_objects()
323323

324324
small_categories = props[props < prop].index.tolist()
325325

326326
remainder_category_name = "_lumped_other"
327327
while remainder_category_name in props.index:
328328
remainder_category_name = remainder_category_name * 2
329329

330-
# Convert to object dtype
331-
s = s.astype("object")
330+
# Convert to object dtype unless already string dtype
331+
if not pd.api.types.is_string_dtype(s.dtype):
332+
s = s.astype("object")
332333

333334
# Replace small categories with the remainder category name
334-
s.loc[s.apply(lambda x: x in small_categories)] = remainder_category_name
335+
mask = s.isin(small_categories).fillna(False)
336+
s.loc[mask] = remainder_category_name
335337
return s
336338

337339

@@ -349,12 +351,12 @@ def fct_lump_by(s: pd.Series, by: pd.Series, prop: float = 0.05) -> pd.Series:
349351
pd.Series: pd.series, we keep the index of s as the index of the result.
350352
"""
351353
res = copy.deepcopy(s)
352-
pd.options.mode.copy_on_write = True
353354
# pandas groupby doesnt preserve order
354355
for subgroup in pd.unique(by):
355356
mask = by == subgroup
356357
grouped_res = fct_lump(res.loc[mask], prop=prop)
357358
# Ensure dtype compatibility before assignment
358-
res = res.astype("object")
359+
if not pd.api.types.is_string_dtype(res.dtype):
360+
res = res.astype("object")
359361
res.loc[mask] = grouped_res
360362
return res

balance/utils/input_validation.py

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -199,6 +199,7 @@ def _is_discrete_series(series: pd.Series) -> bool:
199199
return (
200200
is_binary_indicator
201201
or pd.api.types.is_object_dtype(series)
202+
or pd.api.types.is_string_dtype(series)
202203
or isinstance(series.dtype, pd.CategoricalDtype)
203204
or pd.api.types.is_bool_dtype(series)
204205
)
@@ -351,6 +352,7 @@ def _is_arraylike(o: Any) -> bool:
351352
return (
352353
isinstance(o, np.ndarray)
353354
or isinstance(o, pd.Series)
355+
or isinstance(o, pd.api.extensions.ExtensionArray)
354356
or (
355357
hasattr(pd.arrays, "NumpyExtensionArray")
356358
and isinstance(o, pd.arrays.NumpyExtensionArray)
@@ -400,7 +402,9 @@ def _return_type_creation_function(x: Any) -> Callable | Any:
400402
if isinstance(x, np.ndarray):
401403
return lambda obj: np.array(obj, dtype=x.dtype)
402404
# same with pd.arrays.PandasArray, pd.arrays.StringArray, etc.
403-
elif "pandas.core.arrays" in str(type(x)):
405+
elif isinstance(x, pd.api.extensions.ExtensionArray) or (
406+
"pandas.core.arrays" in str(type(x))
407+
):
404408
return lambda obj: pd.array(obj, dtype=x.dtype)
405409
else:
406410
return type(x)

balance/utils/model_matrix.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -378,7 +378,7 @@ def _prepare_input_model_matrix(
378378
if fix_columns_names:
379379
all_data.columns = all_data.columns.str.replace(
380380
r"[^\w]", "_", regex=True
381-
).infer_objects(copy=False)
381+
).infer_objects()
382382
all_data = _make_df_column_names_unique(all_data)
383383

384384
return {"all_data": all_data, "sample_n": sample_n}

0 commit comments

Comments
 (0)