Skip to content

Commit 9d44b3d

Browse files
committed
fix: Multiple fixes for handling different data types in pandas column analysis
1 parent 67a97b4 commit 9d44b3d

File tree

7 files changed

+593
-41
lines changed

7 files changed

+593
-41
lines changed

.cursorrules

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,12 @@ Additional for integration tests:
9494
# Run local tests
9595
./bin/test-local
9696

97+
# Run a specific test file
98+
./bin/test-local tests/unit/test_file.py
99+
100+
# ... or specific test from file
101+
./bin/test-local tests/unit/test_file.py::TestClass::test_method
102+
97103
# Run specific test type
98104
export TEST_TYPE="unit|integration"
99105
export TOOLKIT_VERSION="local-build"

deepnote_toolkit/ocelots/pandas/analyze.py

Lines changed: 27 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,11 @@
66
import pandas as pd
77

88
from deepnote_toolkit.ocelots.constants import DEEPNOTE_INDEX_COLUMN
9+
from deepnote_toolkit.ocelots.pandas.utils import (
10+
is_type_datetime_or_timedelta,
11+
is_type_numeric,
12+
safe_convert_to_string,
13+
)
914
from deepnote_toolkit.ocelots.types import ColumnsStatsRecord, ColumnStats
1015

1116

@@ -24,7 +29,10 @@ def _get_categories(np_array):
2429
# special treatment for empty values
2530
num_nans = pandas_series.isna().sum().item()
2631

27-
counter = Counter(pandas_series.dropna().astype(str))
32+
try:
33+
counter = Counter(pandas_series.dropna().astype(str))
34+
except (TypeError, UnicodeDecodeError, AttributeError):
35+
counter = Counter(pandas_series.dropna().apply(safe_convert_to_string))
2836

2937
max_items = 3
3038
if num_nans > 0:
@@ -46,34 +54,12 @@ def _get_categories(np_array):
4654
return [{"name": name, "count": count} for name, count in categories]
4755

4856

49-
def _is_type_numeric(dtype):
50-
"""
51-
Returns True if dtype is numeric, False otherwise
52-
53-
Numeric means either a number (int, float, complex) or a datetime or timedelta.
54-
It means e.g. that a range of these values can be plotted on a histogram.
55-
"""
56-
57-
# datetime doesn't play nice with np.issubdtype, so we need to check explicitly
58-
if pd.api.types.is_datetime64_any_dtype(dtype) or pd.api.types.is_timedelta64_dtype(
59-
dtype
60-
):
61-
return True
62-
63-
try:
64-
return np.issubdtype(dtype, np.number)
65-
except TypeError:
66-
# np.issubdtype crashes on categorical column dtype, and also on others, e.g. geopandas types
67-
return False
68-
69-
7057
def _get_histogram(pd_series):
7158
try:
72-
if pd.api.types.is_datetime64_any_dtype(
73-
pd_series
74-
) or pd.api.types.is_timedelta64_dtype(pd_series):
75-
# convert datetime or timedelta to an integer so that a histogram can be created
59+
if is_type_datetime_or_timedelta(pd_series):
7660
np_array = np.array(pd_series.dropna().astype(int))
61+
elif np.issubdtype(pd_series.dtype, np.complexfloating):
62+
return None
7763
else:
7864
# let's drop infinite values because they break histograms
7965
np_array = np.array(pd_series.replace([np.inf, -np.inf], np.nan).dropna())
@@ -104,11 +90,22 @@ def _calculate_min_max(column):
10490
"""
10591
Calculate min and max values for a given column.
10692
"""
107-
if _is_type_numeric(column.dtype):
93+
if not is_type_numeric(column.dtype):
94+
return None, None
95+
96+
# Complex numbers cannot be compared for min/max
97+
# Check for datetime/timedelta types first, because np.issubdtype doesn't work reliably on them
98+
if not is_type_datetime_or_timedelta(column) and np.issubdtype(
99+
column.dtype, np.complexfloating
100+
):
101+
return None, None
102+
103+
try:
108104
min_value = str(min(column.dropna())) if len(column.dropna()) > 0 else None
109105
max_value = str(max(column.dropna())) if len(column.dropna()) > 0 else None
110106
return min_value, max_value
111-
return None, None
107+
except (TypeError, ValueError):
108+
return None, None
112109

113110

114111
def analyze_columns(
@@ -167,7 +164,7 @@ def analyze_columns(
167164
unique_count=_count_unique(column), nan_count=column.isnull().sum().item()
168165
)
169166

170-
if _is_type_numeric(column.dtype):
167+
if is_type_numeric(column.dtype):
171168
min_value, max_value = _calculate_min_max(column)
172169
columns[i].stats.min = min_value
173170
columns[i].stats.max = max_value
@@ -187,7 +184,7 @@ def analyze_columns(
187184
for i in range(max_columns_to_analyze, len(df.columns)):
188185
# Ignore columns that are not numeric
189186
column = df.iloc[:, i]
190-
if not _is_type_numeric(column.dtype):
187+
if not is_type_numeric(column.dtype):
191188
continue
192189

193190
column_name = columns[i].name

deepnote_toolkit/ocelots/pandas/utils.py

Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,21 @@
1+
import base64
2+
13
import numpy as np
24
import pandas as pd
35
from packaging.requirements import Requirement
46

57
from deepnote_toolkit.ocelots.constants import MAX_STRING_CELL_LENGTH
68

79

10+
def safe_convert_to_string(value):
11+
if isinstance(value, bytes):
12+
return base64.b64encode(value).decode("ascii")
13+
try:
14+
return str(value)
15+
except Exception:
16+
return "<unconvertible>"
17+
18+
819
# like fillna, but only fills NaT (not a time) values in datetime columns with the specified value
920
def fill_nat(df, value):
1021
df_datetime_columns = df.select_dtypes(
@@ -76,33 +87,38 @@ def deduplicate_columns(df):
7687
# Cast dataframe contents to strings and trim them to avoid sending too much data
7788
def cast_objects_to_string(df):
7889
def to_string_truncated(elem):
79-
elem_string = str(elem)
90+
elem_string = safe_convert_to_string(elem)
8091
return (
8192
(elem_string[: MAX_STRING_CELL_LENGTH - 1] + "…")
8293
if len(elem_string) > MAX_STRING_CELL_LENGTH
8394
else elem_string
8495
)
8596

8697
for column in df:
87-
if not _is_type_number(df[column].dtype):
98+
if not is_type_numeric(df[column].dtype):
8899
# if the dtype is not a number, we want to convert it to string and truncate
89100
df[column] = df[column].apply(to_string_truncated)
90101

91102
return df
92103

93104

94-
def _is_type_number(dtype):
105+
def is_type_datetime_or_timedelta(series_or_dtype):
95106
"""
96-
Returns True if dtype is a number, False otherwise. Datetime and timedelta will return False.
107+
Returns True if the series or dtype is datetime or timedelta, False otherwise.
108+
"""
109+
return pd.api.types.is_datetime64_any_dtype(
110+
series_or_dtype
111+
) or pd.api.types.is_timedelta64_dtype(series_or_dtype)
112+
97113

98-
The primary intent of this is to recognize a value that will be converted to a JSON number during serialization.
114+
def is_type_numeric(dtype):
99115
"""
116+
Returns True if dtype is numeric, False otherwise
100117
101-
if pd.api.types.is_datetime64_any_dtype(dtype) or pd.api.types.is_timedelta64_dtype(
102-
dtype
103-
):
104-
# np.issubdtype(dtype, np.number) returns True for timedelta, which we don't want
105-
return False
118+
Numeric means either a number (int, float, complex) or a datetime or timedelta.
119+
"""
120+
if is_type_datetime_or_timedelta(dtype):
121+
return True
106122

107123
try:
108124
return np.issubdtype(dtype, np.number)

deepnote_toolkit/ocelots/pyspark/implementation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -243,7 +243,7 @@ def select_column(field: StructField) -> Column:
243243
# We slice the binary field before encoding to avoid encoding a potentially big blob. Round slicing to
244244
# 4 bytes to avoid breaking multi-byte sequences
245245
if isinstance(field.dataType, BinaryType):
246-
sliced = F.substring(field, 1, keep_bytes)
246+
sliced = F.substring(F.col(field.name), 1, keep_bytes)
247247
return F.base64(sliced)
248248

249249
# String just needs to be trimmed

tests/unit/helpers/testing_dataframes.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -261,12 +261,14 @@ def create_dataframe_with_duplicate_column_names():
261261
datetime.datetime(2023, 1, 1, 12, 0, 0),
262262
datetime.datetime(2023, 1, 2, 12, 0, 0),
263263
],
264+
"binary": [b"hello", b"world"],
264265
}
265266
),
266267
"pyspark_schema": pst.StructType(
267268
[
268269
pst.StructField("list", pst.ArrayType(pst.IntegerType()), True),
269270
pst.StructField("datetime", pst.TimestampType(), True),
271+
pst.StructField("binary", pst.BinaryType(), True),
270272
]
271273
),
272274
},

0 commit comments

Comments
 (0)