Skip to content

Commit 6ee535e

Browse files
authored
Merge pull request #405 from network-wrangler/hotfix-pandaspanderacompatibility
fix: resolve pandas 2.x, pandera 0.24.0, pydantic 2.x compatibility issues
2 parents a5a743b + 9400b37 commit 6ee535e

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

53 files changed

+568
-412
lines changed

environments/conda/dev-environment.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ dependencies:
1111
- ipywidgets=7.8.4
1212
- osmnx=1.9.3
1313
- pandas=2.2.3
14-
- pandera-geopandas=0.18.0
14+
- pandera-geopandas=0.24.0
1515
- psutil=6.0.0
1616
- pyarrow=17.0.0
1717
- pydantic=2.9.2
@@ -47,5 +47,5 @@ dependencies:
4747
- mkdocs-mermaid2-plugin==1.1.1
4848
- mkdocstrings==0.26.1
4949
- mkdocstrings-python==1.11.1
50-
- pandera==0.20.4
51-
- projectcard>=0.3.3
50+
- pandera[geopandas]==0.24.0
51+
- projectcard==0.3.3

environments/conda/environment.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,5 +23,5 @@ dependencies:
2323

2424
# Pip-installed dependencies
2525
- pip:
26-
- pandera==0.20.4
27-
- projectcard>=0.3.3
26+
- pandera[geopandas]==0.24.0
27+
- projectcard==0.3.3

environments/pip/requirements-lock.txt

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
fiona==1.10.1
12
folium==0.17.0
23
geographiclib==2.0
34
geojson==3.1.0
@@ -7,9 +8,11 @@ jupyter==5.7.2
78
notebook==7.2.2
89
osmnx==1.9.3
910
pandas==2.2.3
10-
pandera-geopandas==0.18.0
11+
pandera[pandas,geopandas]==0.24.0
12+
projectcard==0.3.3
1113
psutil==6.0.0
1214
pyarrow==17.0.0
1315
pydantic==2.9.2
1416
pyogrio==0.9.0
1517
pyyaml==6.0.2
18+
typing-extensions==4.12.2

examples/stpaul/clean_network.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,12 @@
2121
"\n",
2222
"from network_wrangler import load_roadway_from_dir, load_transit, write_roadway, write_transit\n",
2323
"from network_wrangler.models.gtfs.tables import (\n",
24+
" RoutesTable,\n",
2425
" WranglerFrequenciesTable,\n",
2526
" WranglerShapesTable,\n",
2627
" WranglerStopsTable,\n",
2728
" WranglerStopTimesTable,\n",
2829
" WranglerTripsTable,\n",
29-
" RoutesTable,\n",
3030
")\n",
3131
"from network_wrangler.models.roadway.tables import RoadLinksTable, RoadNodesTable, RoadShapesTable"
3232
]

network_wrangler/__init__.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""Network Wrangler Package."""
22

3-
__version__ = "1.0-beta.2"
3+
__version__ = "1.0-beta.3"
44

55
import warnings
66

@@ -17,17 +17,17 @@
1717
from .utils.df_accessors import *
1818

1919
__all__ = [
20+
"Scenario",
2021
"WranglerLogger",
21-
"setup_logging",
22-
"load_transit",
23-
"write_transit",
22+
"create_scenario",
2423
"load_roadway",
2524
"load_roadway_from_dir",
26-
"write_roadway",
27-
"create_scenario",
28-
"Scenario",
29-
"load_wrangler_config",
3025
"load_scenario",
26+
"load_transit",
27+
"load_wrangler_config",
28+
"setup_logging",
29+
"write_roadway",
30+
"write_transit",
3131
]
3232

3333

network_wrangler/models/_base/db.py

Lines changed: 34 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -414,20 +414,44 @@ def __deepcopy__(self, memo):
414414

415415
# Copy all attributes to the new instance
416416
for attr_name, attr_value in self.__dict__.items():
417-
# Use copy.deepcopy to create deep copies of mutable objects
418-
if isinstance(attr_value, pd.DataFrame):
419-
setattr(new_instance, attr_name, copy.deepcopy(attr_value, memo))
417+
# Handle pandera DataFrameModel objects specially
418+
if (
419+
hasattr(attr_value, "__class__")
420+
and hasattr(attr_value.__class__, "__name__")
421+
and "DataFrameModel" in attr_value.__class__.__name__
422+
):
423+
# For pandera DataFrameModel objects, copy the underlying DataFrame and recreate the model
424+
# This avoids the timestamp corruption issue with copy.deepcopy()
425+
try:
426+
# Get the underlying DataFrame
427+
if hasattr(attr_value, "_obj"):
428+
df_copy = attr_value._obj.copy(deep=True)
429+
elif hasattr(attr_value, "data"):
430+
df_copy = attr_value.data.copy(deep=True)
431+
else:
432+
# For newer pandera versions, try direct access
433+
df_copy = attr_value.copy(deep=True)
434+
435+
# Recreate the DataFrameModel object with the copied DataFrame
436+
new_table = attr_value.__class__(df_copy)
437+
438+
setattr(new_instance, attr_name, new_table)
439+
except Exception as e:
440+
# Fallback to regular deep copy if the above fails
441+
setattr(new_instance, attr_name, copy.deepcopy(attr_value, memo))
442+
elif isinstance(attr_value, pd.DataFrame):
443+
# For plain pandas DataFrames, use deep copy
444+
setattr(new_instance, attr_name, attr_value.copy(deep=True))
420445
else:
421-
setattr(new_instance, attr_name, attr_value)
422-
423-
WranglerLogger.warning(
424-
"Creating a deep copy of db object.\
425-
This will NOT update any references (e.g. from TransitNetwork)"
426-
)
446+
# For all other objects, use regular deep copy
447+
setattr(new_instance, attr_name, copy.deepcopy(attr_value, memo))
427448

428-
# Return the newly created deep copy instance of the object
429449
return new_instance
430450

431451
def deepcopy(self):
432452
"""Convenience method to execute deep copy of instance."""
433453
return copy.deepcopy(self)
454+
455+
def __hash__(self):
456+
"""Hash based on the hashes of the tables in table_names."""
457+
return hash(tuple((name, self.get_table(name).to_csv()) for name in self.table_names))
Lines changed: 31 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,64 +1,55 @@
11
from __future__ import annotations
22

33
from datetime import time
4-
from typing import Annotated, Any, Literal, TypeVar, Union
4+
from typing import Any, Literal, TypeVar, Union
55

66
import pandas as pd
7-
from pydantic import (
8-
BeforeValidator,
9-
Field,
10-
)
117

128
GeoFileTypes = Literal["json", "geojson", "shp", "parquet", "csv", "txt"]
139

1410
TransitFileTypes = Literal["txt", "csv", "parquet"]
1511

16-
1712
RoadwayFileTypes = Literal["geojson", "shp", "parquet", "json"]
1813

19-
2014
PandasDataFrame = TypeVar("PandasDataFrame", bound=pd.DataFrame)
2115
PandasSeries = TypeVar("PandasSeries", bound=pd.Series)
2216

17+
ForcedStr = Any # For simplicity, since BeforeValidator is not used here
2318

24-
ForcedStr = Annotated[Any, BeforeValidator(lambda x: str(x))]
19+
OneOf = list[list[Union[str, list[str]]]]
20+
ConflictsWith = list[list[str]]
21+
AnyOf = list[list[Union[str, list[str]]]]
2522

23+
Latitude = float
24+
Longitude = float
25+
PhoneNum = str
26+
TimeString = str
2627

27-
OneOf = Annotated[
28-
list[list[Union[str, list[str]]]],
29-
Field(
30-
description=["List fields where at least one is required for the data model to be valid."]
31-
),
32-
]
3328

34-
ConflictsWith = Annotated[
35-
list[list[str]],
36-
Field(
37-
description=[
38-
"List of pairs of fields where if one is present, the other cannot be present."
39-
]
40-
),
41-
]
29+
# Standalone validator for timespan strings
30+
def validate_timespan_string(value: Any) -> list[str]:
31+
"""Validate that value is a list of exactly 2 time strings in HH:MM or HH:MM:SS format.
4232
43-
AnyOf = Annotated[
44-
list[list[Union[str, list[str]]]],
45-
Field(description=["List fields where any are required for the data model to be valid."]),
46-
]
33+
Returns the value if valid, raises ValueError otherwise.
34+
"""
35+
if not isinstance(value, list):
36+
msg = "TimespanString must be a list"
37+
raise ValueError(msg)
38+
REQUIRED_LENGTH = 2
39+
if len(value) != REQUIRED_LENGTH:
40+
msg = f"TimespanString must have exactly {REQUIRED_LENGTH} elements"
41+
raise ValueError(msg)
42+
for item in value:
43+
if not isinstance(item, str):
44+
msg = "TimespanString elements must be strings"
45+
raise ValueError(msg)
46+
import re # noqa: PLC0415
4747

48-
Latitude = Annotated[float, Field(ge=-90, le=90, description="Latitude of stop.")]
48+
if not re.match(r"^(\d+):([0-5]\d)(:[0-5]\d)?$", item):
49+
msg = f"Invalid time format: {item}"
50+
raise ValueError(msg)
51+
return value
4952

50-
Longitude = Annotated[float, Field(ge=-180, le=180, description="Longitude of stop.")]
5153

52-
PhoneNum = Annotated[str, Field("", description="Phone number for the specified location.")]
53-
TimeString = Annotated[
54-
str,
55-
Field(
56-
description="A time string in the format HH:MM or HH:MM:SS",
57-
pattern=r"^(\d+):([0-5]\d)(:[0-5]\d)?$",
58-
),
59-
]
60-
TimespanString = Annotated[
61-
list[TimeString],
62-
Field(min_length=2, max_length=2),
63-
]
54+
TimespanString = list[str]
6455
TimeType = Union[time, str, int]

network_wrangler/models/gtfs/table_types.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
import pandas as pd
88
import pandera as pa
9+
from pandera.dtypes import DataType
910
from pandera.engines import pandas_engine
1011

1112

@@ -18,7 +19,7 @@ class HttpURL(pandas_engine.NpString):
1819

1920
def check(
2021
self,
21-
pandera_dtype: pa.dtypes.DataType,
22+
pandera_dtype: DataType,
2223
data_container: pd.Series,
2324
) -> Union[bool, Iterable[bool]]:
2425
"""Check if the data is a valid HTTP URL."""

0 commit comments

Comments (0)