Skip to content

Commit 904aa7d

Browse files
CristianLarameta-codesync[bot]
authored andcommitted
Pandas: Upgrade to 3.0 and fix compatibility issues (#4838)
Summary: Pandas 3.0 introduces several breaking changes (https://pandas.pydata.org/docs/whatsnew/v3.0.0.html#other-api-changes) that required fixes across the codebase: 1. StringDtype inference: Pandas 3.0 infers string columns to use StringDtype instead of object dtype. This breaks DataFrame comparisons since our Data class expects object dtype (defined in COLUMN_DATA_TYPES). Fixed by setting `pd.options.future.infer_string = False` in modules that construct DataFrames. 2. Deprecated inplace on replace(): The `inplace=True` parameter on Series.replace() is deprecated. Changed to assignment pattern: `df["col"] = df["col"].replace(...)`. 3. read_json() no longer accepts strings: pd.read_json() now requires file paths or file-like objects, not raw JSON strings. Wrapped JSON strings with StringIO(). 4. Read-only arrays from DataFrame.to_numpy(): Arrays returned by to_numpy() are now read-only views. Added .copy() before passing to torch.from_numpy() which requires writable arrays. 5. Test dtype mismatches: Tests comparing DataFrames failed because manually constructed expected DataFrames had different dtypes than production code. Fixed by wrapping expected DataFrames with Data(df=...).df to ensure consistent dtype casting. Pull Request resolved: #4838 Reviewed By: mgrange1998 Differential Revision: D91825190 Pulled By: CristianLara fbshipit-source-id: 242e097f1f3c90c4c9bb18cb7dc68ea4609adecd
1 parent 2ea3fd1 commit 904aa7d

File tree

14 files changed

+302
-217
lines changed

14 files changed

+302
-217
lines changed

ax/adapter/data_utils.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,15 @@
1515

1616
from __future__ import annotations
1717

18+
import functools
1819
import warnings
1920
from collections.abc import Iterable
2021
from copy import deepcopy
2122
from dataclasses import dataclass, InitVar
22-
from typing import Any
23+
from typing import Any, Callable
2324

2425
import numpy as np
26+
import pandas as pd
2527
from ax.core.data import Data, MAP_KEY
2628
from ax.core.experiment import Experiment
2729
from ax.core.map_metric import MapMetric
@@ -351,6 +353,32 @@ def extract_experiment_data(
351353
return ExperimentData(arm_data=arm_data, observation_data=observation_data)
352354

353355

356+
def _use_object_dtype_for_strings(
357+
func: Callable[..., Any],
358+
) -> Callable[..., Any]:
359+
"""Decorator to disable pandas 3.0 StringDtype inference.
360+
361+
This ensures string columns like arm_name keep object dtype to match
362+
Data.COLUMN_DATA_TYPES. See: https://pandas.pydata.org/docs/whatsnew/v3.0.0.html
363+
364+
On older pandas versions that don't have the future.infer_string option,
365+
this decorator is a no-op since the StringDtype inference doesn't exist.
366+
"""
367+
368+
@functools.wraps(func)
369+
def wrapper(*args: Any, **kwargs: Any) -> Any:
370+
# Check if the future.infer_string option exists (pandas 3.0+)
371+
if hasattr(pd.options, "future") and hasattr(pd.options.future, "infer_string"):
372+
with pd.option_context("future.infer_string", False):
373+
return func(*args, **kwargs)
374+
else:
375+
# Older pandas version - no StringDtype inference to disable
376+
return func(*args, **kwargs)
377+
378+
return wrapper
379+
380+
381+
@_use_object_dtype_for_strings
354382
def _extract_arm_data(experiment: Experiment) -> DataFrame:
355383
"""Extract a dataframe containing the trial index, arm name,
356384
parameterizations, and metadata from the given experiment.
@@ -383,6 +411,7 @@ def _extract_arm_data(experiment: Experiment) -> DataFrame:
383411
return df
384412

385413

414+
@_use_object_dtype_for_strings
386415
def _extract_observation_data(
387416
experiment: Experiment,
388417
data_loader_config: DataLoaderConfig,

ax/adapter/tests/test_data_utils.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,11 @@
1111
from unittest import mock
1212

1313
import numpy as np
14-
from ax.adapter.data_utils import DataLoaderConfig, extract_experiment_data
14+
from ax.adapter.data_utils import (
15+
_use_object_dtype_for_strings,
16+
DataLoaderConfig,
17+
extract_experiment_data,
18+
)
1519
from ax.adapter.registry import Generators
1620
from ax.core.data import Data, MAP_KEY
1721
from ax.core.observation import Observation, ObservationData, ObservationFeatures
@@ -97,6 +101,7 @@ def test_extract_experiment_data_empty(self) -> None:
97101
)
98102
self.assertEqual(experiment_data, experiment_data)
99103

104+
@_use_object_dtype_for_strings
100105
def test_extract_experiment_data_non_map(self) -> None:
101106
# This is a 2 objective experiment with 2 trials, 1 arm each.
102107
observations = [[0.1, 1.0], [0.2, 2.0]]
@@ -248,6 +253,7 @@ def test_extract_experiment_data_non_map(self) -> None:
248253
)
249254
)
250255

256+
@_use_object_dtype_for_strings
251257
def test_extract_experiment_data_map(self) -> None:
252258
exp = get_branin_experiment_with_timestamp_map_metric(with_trials_and_data=True)
253259
t_0_metric = 55.602112642270264
@@ -261,7 +267,8 @@ def test_extract_experiment_data_map(self) -> None:
261267
expected_arm_df = DataFrame(
262268
[{"x1": 0.0, "x2": 0.0}, {"x1": 1.0, "x2": 1.0}],
263269
index=MultiIndex.from_tuples(
264-
[(0, "0_0"), (1, "1_0")], names=["trial_index", "arm_name"]
270+
[(0, "0_0"), (1, "1_0")],
271+
names=["trial_index", "arm_name"],
265272
),
266273
)
267274
assert_frame_equal(
@@ -359,6 +366,7 @@ def test_extract_experiment_data_map(self) -> None:
359366
# Check equality with self.
360367
self.assertEqual(experiment_data, experiment_data)
361368

369+
@_use_object_dtype_for_strings
362370
def test_extract_experiment_data_multiple_map(self) -> None:
363371
# Checks that multiple map metrics are correctly normalized.
364372
# Using a custom Data input to simplify testing.
@@ -467,6 +475,7 @@ def test_extract_experiment_data_batch_trials(self) -> None:
467475
for df in [experiment_data.arm_data, experiment_data.observation_data]:
468476
self.assertEqual(set(df.index.get_level_values("arm_name")), expected_arms)
469477

478+
@_use_object_dtype_for_strings
470479
def test_extract_experiment_data_with_metadata_columns(self) -> None:
471480
# Tests the case where the Data.df includes additional columns,
472481
# such as start_time and end_time, besides the usual required columns.
@@ -522,7 +531,7 @@ def test_extract_experiment_data_with_metadata_columns(self) -> None:
522531
names=["trial_index", "arm_name"],
523532
),
524533
columns=MultiIndex.from_tuples(
525-
tuples=[
534+
[
526535
("mean", "branin_a"),
527536
("mean", "branin_b"),
528537
("sem", "branin_a"),

ax/adapter/transforms/tests/test_cast_transform.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,11 @@
1111

1212
import numpy as np
1313
from ax.adapter.base import DataLoaderConfig
14-
from ax.adapter.data_utils import ExperimentData, extract_experiment_data
14+
from ax.adapter.data_utils import (
15+
_use_object_dtype_for_strings,
16+
ExperimentData,
17+
extract_experiment_data,
18+
)
1519
from ax.adapter.transforms.cast import Cast
1620
from ax.core.observation import Observation, ObservationData, ObservationFeatures
1721
from ax.core.parameter import (
@@ -449,6 +453,7 @@ def test_transform_experiment_data_flatten_with_missing_columns(self) -> None:
449453
)
450454
self.assertEqual(set(transformed.arm_data.columns), expected_columns)
451455

456+
@_use_object_dtype_for_strings
452457
def test_transform_experiment_data_cast(self) -> None:
453458
# Test for casting to the correct data type and dropping of Nones.
454459
experiment = get_experiment_with_observations(
@@ -495,6 +500,7 @@ def test_transform_experiment_data_cast(self) -> None:
495500
]
496501
assert_frame_equal(transformed.observation_data, expected_obs_data)
497502

503+
@_use_object_dtype_for_strings
498504
def test_transform_experiment_data_cast_map_data(self) -> None:
499505
# Check that indexing for removal of NaNs works correctly with data that
500506
# has a "step" column.

ax/adapter/transforms/tests/test_choice_encode_transform.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from copy import deepcopy
1010

1111
from ax.adapter.base import DataLoaderConfig
12-
from ax.adapter.data_utils import extract_experiment_data
12+
from ax.adapter.data_utils import _use_object_dtype_for_strings, extract_experiment_data
1313
from ax.adapter.transforms.choice_encode import (
1414
ChoiceToNumericChoice,
1515
OrderedChoiceToIntegerRange,
@@ -258,6 +258,7 @@ def test_hss_dependents_are_preserved(self) -> None:
258258
self.assertEqual(hss.parameters["x2"].parameter_type, ParameterType.INT)
259259
self.assertEqual(hss.parameters["x2"].dependents, {0: [], 1: ["x3"]})
260260

261+
@_use_object_dtype_for_strings
261262
def test_transform_experiment_data(self) -> None:
262263
parameterizations = [
263264
{"x": 2.2, "a": 2, "b": 10.0, "c": 10.0, "d": "r", "e": "q"},

ax/adapter/transforms/tests/test_one_hot_transform.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from copy import deepcopy
1010

1111
from ax.adapter.base import DataLoaderConfig
12-
from ax.adapter.data_utils import extract_experiment_data
12+
from ax.adapter.data_utils import _use_object_dtype_for_strings, extract_experiment_data
1313
from ax.adapter.transforms.one_hot import OH_PARAM_INFIX, OneHot
1414
from ax.core.observation import ObservationFeatures
1515
from ax.core.parameter import (
@@ -226,6 +226,7 @@ def test_heterogeneous_search_space(self) -> None:
226226
untf_obs = self.t.untransform_observation_features(obs_ft)
227227
self.assertFalse(any(obs.parameters.get("b") == "b" for obs in untf_obs))
228228

229+
@_use_object_dtype_for_strings
229230
def test_transform_experiment_data(self) -> None:
230231
parameterizations = [
231232
{"x": 2.2, "a": 2, "b": "b", "c": False, "d": 10.0},

ax/adapter/transforms/tests/test_unit_x_transform.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from copy import deepcopy
1010

1111
from ax.adapter.base import DataLoaderConfig
12-
from ax.adapter.data_utils import extract_experiment_data
12+
from ax.adapter.data_utils import _use_object_dtype_for_strings, extract_experiment_data
1313
from ax.adapter.transforms.unit_x import UnitX
1414
from ax.core.observation import ObservationFeatures
1515
from ax.core.parameter import ChoiceParameter, ParameterType, RangeParameter
@@ -205,6 +205,7 @@ def test_TransformNewSearchSpace(self) -> None:
205205
t.transform_search_space(new_search_space_with_target)
206206
self.assertEqual(new_search_space_with_target.parameters["x"].target_value, 0.5)
207207

208+
@_use_object_dtype_for_strings
208209
def test_transform_experiment_data(self) -> None:
209210
parameterizations = [
210211
{"x": 1.0, "y": 1.5, "z": 1.0, "a": 1, "b": "b"},

0 commit comments

Comments
 (0)