Skip to content

Commit 6e65b10

Browse files
authored
fix: analysis report when there is a column with mixed data types (#131)
* column config -> column name when possible
* fallback to dtype of first non-null element
* add unit tests
* add error message info to warning
* catch str_ too
1 parent ebc4024 commit 6e65b10

File tree

6 files changed

+153
-87
lines changed

6 files changed

+153
-87
lines changed

src/data_designer/engine/analysis/column_profilers/base.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
from abc import ABC, abstractmethod
88

99
import pandas as pd
10-
import pyarrow as pa
1110
from pydantic import BaseModel, model_validator
1211
from typing_extensions import Self
1312

@@ -29,12 +28,6 @@ def validate_column_exists(self) -> Self:
2928
raise ValueError(f"Column {self.column_config.name!r} not found in DataFrame")
3029
return self
3130

32-
@model_validator(mode="after")
33-
def ensure_pyarrow_backend(self) -> Self:
34-
if not all(isinstance(dtype, pd.ArrowDtype) for dtype in self.df.dtypes):
35-
self.df = pa.Table.from_pandas(self.df).to_pandas(types_mapper=pd.ArrowDtype)
36-
return self
37-
3831
def as_tuple(self) -> tuple[SingleColumnConfig, pd.DataFrame]:
3932
return (self.column_config, self.df)
4033

src/data_designer/engine/analysis/column_statistics.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def calculate(self) -> Self:
5959
)
6060

6161
def calculate_general_column_info(self) -> dict[str, Any]:
62-
return calculate_general_column_info(self.column_config, self.df)
62+
return calculate_general_column_info(self.column_config.name, self.df)
6363

6464
def __repr__(self) -> str:
6565
params = []
@@ -93,7 +93,7 @@ def calculate_sampler_distribution(self) -> dict[str, Any]:
9393
return (
9494
{
9595
"sampler_type": SamplerType(self.column_config.sampler_type),
96-
**calculate_column_distribution(self.column_config, self.df, dist_type),
96+
**calculate_column_distribution(self.column_config.name, self.df, dist_type),
9797
}
9898
if make_dist
9999
else {
@@ -109,7 +109,7 @@ class SeedDatasetColumnStatisticsCalculator(GeneralColumnStatisticsCalculator):
109109

110110
class ValidationColumnStatisticsCalculator(GeneralColumnStatisticsCalculator):
111111
def calculate_validation_column_info(self) -> dict[str, Any]:
112-
return calculate_validation_column_info(self.column_config, self.df)
112+
return calculate_validation_column_info(self.column_config.name, self.df)
113113

114114

115115
class ExpressionColumnStatisticsCalculator(GeneralColumnStatisticsCalculator): ...

src/data_designer/engine/analysis/dataset_profiler.py

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from functools import cached_property
77

88
import pandas as pd
9+
import pyarrow as pa
910
from pydantic import Field, field_validator
1011

1112
from data_designer.config.analysis.column_profilers import ColumnProfilerConfigT
@@ -19,10 +20,8 @@
1920
from data_designer.engine.analysis.column_profilers.base import ColumnConfigWithDataFrame, ColumnProfiler
2021
from data_designer.engine.analysis.column_statistics import get_column_statistics_calculator
2122
from data_designer.engine.analysis.errors import DatasetProfilerConfigurationError
22-
from data_designer.engine.dataset_builders.multi_column_configs import (
23-
DatasetBuilderColumnConfigT,
24-
MultiColumnConfig,
25-
)
23+
from data_designer.engine.analysis.utils.column_statistics_calculations import has_pyarrow_backend
24+
from data_designer.engine.dataset_builders.multi_column_configs import DatasetBuilderColumnConfigT, MultiColumnConfig
2625
from data_designer.engine.registry.data_designer_registry import DataDesignerRegistry
2726
from data_designer.engine.resources.resource_provider import ResourceProvider
2827

@@ -68,6 +67,7 @@ def profile_dataset(
6867
logger.info("📐 Measuring dataset column statistics:")
6968

7069
self._validate_schema_consistency(list(dataset.columns))
70+
dataset = self._convert_to_pyarrow_backend_if_needed(dataset)
7171

7272
column_statistics = []
7373
for c in self.config.column_configs:
@@ -100,6 +100,27 @@ def profile_dataset(
100100
column_profiles=column_profiles if column_profiles else None,
101101
)
102102

103+
def _convert_to_pyarrow_backend_if_needed(self, dataset: pd.DataFrame) -> pd.DataFrame:
104+
if not has_pyarrow_backend(dataset):
105+
try:
106+
dataset = pa.Table.from_pandas(dataset).to_pandas(types_mapper=pd.ArrowDtype)
107+
except Exception as e:
108+
# For ArrowTypeError, the second arg contains the more informative message
109+
if isinstance(e, pa.lib.ArrowTypeError) and len(e.args) > 1:
110+
error_msg = str(e.args[1])
111+
else:
112+
error_msg = str(e)
113+
for col in dataset.columns:
114+
# Make sure column names are clear in the error message
115+
error_msg = error_msg.replace(col, f"'{col}'")
116+
logger.warning("⚠️ Unable to convert the dataset to a PyArrow backend")
117+
logger.warning(f" |-- Conversion Error Message: {error_msg}")
118+
logger.warning(" |-- This is often due to at least one column having mixed data types")
119+
logger.warning(
120+
" |-- Note: Reported data types will be inferred from the first non-null value of each column"
121+
)
122+
return dataset
123+
103124
def _create_column_profiler(self, profiler_config: ColumnProfilerConfigT) -> ColumnProfiler:
104125
return self.registry.column_profilers.get_for_config_type(type(profiler_config))(
105126
config=profiler_config, resource_provider=self.resource_provider

src/data_designer/engine/analysis/utils/column_statistics_calculations.py

Lines changed: 57 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,7 @@
1818
MissingValue,
1919
NumericalDistribution,
2020
)
21-
from data_designer.config.column_configs import (
22-
LLMTextColumnConfig,
23-
SingleColumnConfig,
24-
ValidationColumnConfig,
25-
)
21+
from data_designer.config.column_configs import LLMTextColumnConfig
2622
from data_designer.engine.column_generators.generators.llm_generators import (
2723
PromptType,
2824
RecordBasedPromptRenderer,
@@ -39,41 +35,54 @@
3935

4036

4137
def calculate_column_distribution(
42-
column_config: SingleColumnConfig, df: pd.DataFrame, distribution_type: ColumnDistributionType
38+
column_name: str, df: pd.DataFrame, distribution_type: ColumnDistributionType
4339
) -> dict[str, CategoricalDistribution | NumericalDistribution | MissingValue | None]:
4440
distribution_type = ColumnDistributionType(distribution_type)
4541
try:
4642
if distribution_type == ColumnDistributionType.CATEGORICAL:
4743
return {
4844
"distribution_type": ColumnDistributionType.CATEGORICAL,
49-
"distribution": CategoricalDistribution.from_series(df[column_config.name]),
45+
"distribution": CategoricalDistribution.from_series(df[column_name]),
5046
}
5147

5248
if distribution_type == ColumnDistributionType.NUMERICAL:
5349
return {
5450
"distribution_type": ColumnDistributionType.NUMERICAL,
55-
"distribution": NumericalDistribution.from_series(df[column_config.name]),
51+
"distribution": NumericalDistribution.from_series(df[column_name]),
5652
}
5753
except Exception as e:
58-
logger.warning(f"{WARNING_PREFIX} failed to calculate column distribution for '{column_config.name}' {e}")
54+
logger.warning(f"{WARNING_PREFIX} failed to calculate column distribution for '{column_name}' {e}")
5955
return {
6056
"distribution_type": ColumnDistributionType.UNKNOWN,
6157
"distribution": MissingValue.CALCULATION_FAILED,
6258
}
6359

6460

65-
def calculate_general_column_info(column_config: SingleColumnConfig, df: pd.DataFrame) -> dict[str, Any]:
61+
def calculate_general_column_info(column_name: str, df: pd.DataFrame) -> dict[str, Any]:
6662
try:
67-
_df = pd.DataFrame(df[column_config.name].apply(ensure_hashable))
63+
_df = pd.DataFrame(df[column_name].apply(ensure_hashable))
64+
65+
if has_pyarrow_backend(df):
66+
pyarrow_dtype = str(df[column_name].dtype.pyarrow_dtype)
67+
simple_dtype = convert_pyarrow_dtype_to_simple_dtype(df[column_name].dtype.pyarrow_dtype)
68+
else:
69+
# We do not log a warning at the column-level because it would be too noisy.
70+
# However, there is a logged warning at the dataset-profiler level.
71+
try:
72+
simple_dtype = get_column_data_type_from_first_non_null_value(column_name, df)
73+
except Exception:
74+
simple_dtype = MissingValue.CALCULATION_FAILED
75+
pyarrow_dtype = "n/a"
76+
6877
return {
69-
"pyarrow_dtype": str(df[column_config.name].dtype.pyarrow_dtype),
70-
"simple_dtype": convert_pyarrow_dtype_to_simple_dtype(df[column_config.name].dtype.pyarrow_dtype),
71-
"num_records": len(_df[column_config.name]),
72-
"num_null": _df[column_config.name].isnull().sum(),
73-
"num_unique": _df[column_config.name].nunique(),
78+
"pyarrow_dtype": pyarrow_dtype,
79+
"simple_dtype": simple_dtype,
80+
"num_records": len(_df[column_name]),
81+
"num_null": _df[column_name].isnull().sum(),
82+
"num_unique": _df[column_name].nunique(),
7483
}
7584
except Exception as e:
76-
logger.warning(f"{WARNING_PREFIX} failed to calculate general column info for '{column_config.name}': {e}")
85+
logger.warning(f"{WARNING_PREFIX} failed to calculate general column info for '{column_name}': {e}")
7786
return {
7887
"pyarrow_dtype": MissingValue.CALCULATION_FAILED,
7988
"simple_dtype": MissingValue.CALCULATION_FAILED,
@@ -115,11 +124,9 @@ def calculate_prompt_token_stats(
115124
}
116125

117126

118-
def calculate_completion_token_stats(
119-
column_config: LLMTextColumnConfig, df: pd.DataFrame
120-
) -> dict[str, float | MissingValue]:
127+
def calculate_completion_token_stats(column_name: str, df: pd.DataFrame) -> dict[str, float | MissingValue]:
121128
try:
122-
tokens_per_record = df[column_config.name].apply(
129+
tokens_per_record = df[column_name].apply(
123130
lambda value: len(TOKENIZER.encode(str(value), disallowed_special=()))
124131
)
125132
return {
@@ -128,9 +135,7 @@ def calculate_completion_token_stats(
128135
"completion_tokens_stddev": tokens_per_record.std(),
129136
}
130137
except Exception as e:
131-
logger.warning(
132-
f"{WARNING_PREFIX} failed to calculate completion token stats for column {column_config.name}: {e}"
133-
)
138+
logger.warning(f"{WARNING_PREFIX} failed to calculate completion token stats for column {column_name}: {e}")
134139
return {
135140
"completion_tokens_mean": MissingValue.CALCULATION_FAILED,
136141
"completion_tokens_median": MissingValue.CALCULATION_FAILED,
@@ -141,16 +146,16 @@ def calculate_completion_token_stats(
141146
def calculate_token_stats(column_config: LLMTextColumnConfig, df: pd.DataFrame) -> dict[str, float | MissingValue]:
142147
return {
143148
**calculate_prompt_token_stats(column_config, df),
144-
**calculate_completion_token_stats(column_config, df),
149+
**calculate_completion_token_stats(column_config.name, df),
145150
}
146151

147152

148-
def calculate_validation_column_info(column_config: ValidationColumnConfig, df: pd.DataFrame) -> dict[str, Any]:
153+
def calculate_validation_column_info(column_name: str, df: pd.DataFrame) -> dict[str, Any]:
149154
try:
150-
return {"num_valid_records": df[column_config.name].apply(lambda x: ensure_boolean(x["is_valid"])).sum()}
155+
return {"num_valid_records": df[column_name].apply(lambda x: ensure_boolean(x["is_valid"])).sum()}
151156
except Exception as e:
152157
logger.warning(
153-
f"{WARNING_PREFIX} failed to calculate code validation column info for column {column_config.name}: {e}"
158+
f"{WARNING_PREFIX} failed to calculate code validation column info for column {column_name}: {e}"
154159
)
155160
return {"num_valid_records": MissingValue.CALCULATION_FAILED}
156161

@@ -160,22 +165,33 @@ def convert_pyarrow_dtype_to_simple_dtype(pyarrow_dtype: pa.DataType) -> str:
160165
return f"list[{convert_pyarrow_dtype_to_simple_dtype(pyarrow_dtype.value_type)}]"
161166
if isinstance(pyarrow_dtype, pa.StructType):
162167
return "dict"
163-
pyarrow_dtype_str = str(pyarrow_dtype)
164-
if "int" in pyarrow_dtype_str:
168+
return convert_to_simple_dtype(str(pyarrow_dtype))
169+
170+
171+
def convert_to_simple_dtype(dtype: str) -> str:
172+
if "int" in dtype:
165173
return "int"
166-
if "double" in pyarrow_dtype_str:
174+
if "double" in dtype:
167175
return "float"
168-
if "float" in pyarrow_dtype_str:
176+
if "float" in dtype:
169177
return "float"
170-
if "string" in pyarrow_dtype_str:
178+
if "str" in dtype:
171179
return "string"
172-
if "timestamp" in pyarrow_dtype_str:
180+
if "timestamp" in dtype:
173181
return "timestamp"
174-
if "time" in pyarrow_dtype_str:
182+
if "time" in dtype:
175183
return "time"
176-
if "date" in pyarrow_dtype_str:
184+
if "date" in dtype:
177185
return "date"
178-
return pyarrow_dtype_str
186+
return dtype
187+
188+
189+
def get_column_data_type_from_first_non_null_value(column_name: str, df: pd.DataFrame) -> str:
190+
df_no_nulls = df[column_name].dropna()
191+
if len(df_no_nulls) == 0:
192+
return MissingValue.CALCULATION_FAILED
193+
dtype = type(df_no_nulls.iloc[0]).__name__
194+
return convert_to_simple_dtype(dtype)
179195

180196

181197
def ensure_hashable(x: Any) -> str:
@@ -207,3 +223,7 @@ def ensure_boolean(v: bool | str | int | None) -> bool:
207223
if v is None:
208224
return False
209225
raise ValueError(f"Invalid boolean value: {v}")
226+
227+
228+
def has_pyarrow_backend(df: pd.DataFrame) -> bool:
229+
return all(isinstance(dtype, pd.ArrowDtype) for dtype in df.dtypes)

tests/engine/analysis/column_profilers/test_base.py

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -37,17 +37,6 @@ def test_column_config_with_dataframe_column_not_found_validation_error():
3737
ColumnConfigWithDataFrame(column_config=column_config, df=df)
3838

3939

40-
def test_column_config_with_dataframe_pyarrow_backend_conversion():
41-
df = pd.DataFrame({"test_column": [1, 2, 3]})
42-
column_config = SamplerColumnConfig(
43-
name="test_column", sampler_type=SamplerType.CATEGORY, params={"values": [1, 2, 3]}
44-
)
45-
46-
config_with_df = ColumnConfigWithDataFrame(column_config=column_config, df=df)
47-
48-
assert all(isinstance(dtype, pd.ArrowDtype) for dtype in config_with_df.df.dtypes)
49-
50-
5140
def test_column_config_with_dataframe_as_tuple_method():
5241
df = pd.DataFrame({"test_column": [1, 2, 3]})
5342
column_config = SamplerColumnConfig(

0 commit comments

Comments (0)