Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 0 additions & 7 deletions src/data_designer/engine/analysis/column_profilers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from abc import ABC, abstractmethod

import pandas as pd
import pyarrow as pa
from pydantic import BaseModel, model_validator
from typing_extensions import Self

Expand All @@ -29,12 +28,6 @@ def validate_column_exists(self) -> Self:
raise ValueError(f"Column {self.column_config.name!r} not found in DataFrame")
return self

@model_validator(mode="after")
def ensure_pyarrow_backend(self) -> Self:
if not all(isinstance(dtype, pd.ArrowDtype) for dtype in self.df.dtypes):
self.df = pa.Table.from_pandas(self.df).to_pandas(types_mapper=pd.ArrowDtype)
return self

Comment on lines -32 to -37
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This validator is what was causing generation jobs to fail at the profiling step.

def as_tuple(self) -> tuple[SingleColumnConfig, pd.DataFrame]:
    """Return this object's column config and DataFrame as a ``(column_config, df)`` pair."""
    config_and_frame = (self.column_config, self.df)
    return config_and_frame

Expand Down
6 changes: 3 additions & 3 deletions src/data_designer/engine/analysis/column_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def calculate(self) -> Self:
)

def calculate_general_column_info(self) -> dict[str, Any]:
return calculate_general_column_info(self.column_config, self.df)
return calculate_general_column_info(self.column_config.name, self.df)

def __repr__(self) -> str:
params = []
Expand Down Expand Up @@ -93,7 +93,7 @@ def calculate_sampler_distribution(self) -> dict[str, Any]:
return (
{
"sampler_type": SamplerType(self.column_config.sampler_type),
**calculate_column_distribution(self.column_config, self.df, dist_type),
**calculate_column_distribution(self.column_config.name, self.df, dist_type),
}
if make_dist
else {
Expand All @@ -109,7 +109,7 @@ class SeedDatasetColumnStatisticsCalculator(GeneralColumnStatisticsCalculator):

class ValidationColumnStatisticsCalculator(GeneralColumnStatisticsCalculator):
def calculate_validation_column_info(self) -> dict[str, Any]:
return calculate_validation_column_info(self.column_config, self.df)
return calculate_validation_column_info(self.column_config.name, self.df)


class ExpressionColumnStatisticsCalculator(GeneralColumnStatisticsCalculator): ...
Expand Down
20 changes: 16 additions & 4 deletions src/data_designer/engine/analysis/dataset_profiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from functools import cached_property

import pandas as pd
import pyarrow as pa
from pydantic import Field, field_validator

from data_designer.config.analysis.column_profilers import ColumnProfilerConfigT
Expand All @@ -19,10 +20,8 @@
from data_designer.engine.analysis.column_profilers.base import ColumnConfigWithDataFrame, ColumnProfiler
from data_designer.engine.analysis.column_statistics import get_column_statistics_calculator
from data_designer.engine.analysis.errors import DatasetProfilerConfigurationError
from data_designer.engine.dataset_builders.multi_column_configs import (
DatasetBuilderColumnConfigT,
MultiColumnConfig,
)
from data_designer.engine.analysis.utils.column_statistics_calculations import has_pyarrow_backend
from data_designer.engine.dataset_builders.multi_column_configs import DatasetBuilderColumnConfigT, MultiColumnConfig
from data_designer.engine.registry.data_designer_registry import DataDesignerRegistry
from data_designer.engine.resources.resource_provider import ResourceProvider

Expand Down Expand Up @@ -68,6 +67,7 @@ def profile_dataset(
logger.info("📐 Measuring dataset column statistics:")

self._validate_schema_consistency(list(dataset.columns))
dataset = self._convert_to_pyarrow_backend_if_needed(dataset)

column_statistics = []
for c in self.config.column_configs:
Expand Down Expand Up @@ -100,6 +100,18 @@ def profile_dataset(
column_profiles=column_profiles if column_profiles else None,
)

def _convert_to_pyarrow_backend_if_needed(self, dataset: pd.DataFrame) -> pd.DataFrame:
if not has_pyarrow_backend(dataset):
try:
dataset = pa.Table.from_pandas(dataset).to_pandas(types_mapper=pd.ArrowDtype)
except Exception:
logger.warning(
"⚠️ Unable to convert the dataset to a PyArrow backend. This is often due to at least "
"one column having mixed data types. As a result, the reported data types "
"will be inferred from the type of the first non-null value of each column."
)
Comment on lines 119 to 121
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This warning is the clue to the user that their dataset has at least one column with mixed data types.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not blocking this needed bugfix, but is it possible to notify the user which column caused the issue?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good question. We can probably grab it from the exception message; let me check that.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The new message looks like this:

[17:19:19] [WARNING] ⚠️ Unable to convert the dataset to a PyArrow backend
[17:19:19] [WARNING]   |-- Conversion Error Message: Conversion failed for column 'nano_response' with type object
[17:19:19] [WARNING]   |-- This is often due to at least one column having mixed data types
[17:19:19] [WARNING]   |-- Note: Reported data types will be inferred from the first non-null value of each column

return dataset

def _create_column_profiler(self, profiler_config: ColumnProfilerConfigT) -> ColumnProfiler:
return self.registry.column_profilers.get_for_config_type(type(profiler_config))(
config=profiler_config, resource_provider=self.resource_provider
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,7 @@
MissingValue,
NumericalDistribution,
)
from data_designer.config.column_configs import (
LLMTextColumnConfig,
SingleColumnConfig,
ValidationColumnConfig,
)
from data_designer.config.column_configs import LLMTextColumnConfig
from data_designer.engine.column_generators.generators.llm_generators import (
PromptType,
RecordBasedPromptRenderer,
Expand All @@ -39,41 +35,54 @@


def calculate_column_distribution(
column_config: SingleColumnConfig, df: pd.DataFrame, distribution_type: ColumnDistributionType
column_name: str, df: pd.DataFrame, distribution_type: ColumnDistributionType
) -> dict[str, CategoricalDistribution | NumericalDistribution | MissingValue | None]:
distribution_type = ColumnDistributionType(distribution_type)
try:
if distribution_type == ColumnDistributionType.CATEGORICAL:
return {
"distribution_type": ColumnDistributionType.CATEGORICAL,
"distribution": CategoricalDistribution.from_series(df[column_config.name]),
"distribution": CategoricalDistribution.from_series(df[column_name]),
}

if distribution_type == ColumnDistributionType.NUMERICAL:
return {
"distribution_type": ColumnDistributionType.NUMERICAL,
"distribution": NumericalDistribution.from_series(df[column_config.name]),
"distribution": NumericalDistribution.from_series(df[column_name]),
}
except Exception as e:
logger.warning(f"{WARNING_PREFIX} failed to calculate column distribution for '{column_config.name}' {e}")
logger.warning(f"{WARNING_PREFIX} failed to calculate column distribution for '{column_name}' {e}")
return {
"distribution_type": ColumnDistributionType.UNKNOWN,
"distribution": MissingValue.CALCULATION_FAILED,
}


def calculate_general_column_info(column_config: SingleColumnConfig, df: pd.DataFrame) -> dict[str, Any]:
def calculate_general_column_info(column_name: str, df: pd.DataFrame) -> dict[str, Any]:
try:
_df = pd.DataFrame(df[column_config.name].apply(ensure_hashable))
_df = pd.DataFrame(df[column_name].apply(ensure_hashable))

if has_pyarrow_backend(df):
pyarrow_dtype = str(df[column_name].dtype.pyarrow_dtype)
simple_dtype = convert_pyarrow_dtype_to_simple_dtype(df[column_name].dtype.pyarrow_dtype)
else:
# We do not log a warning at the column-level because it would be too noisy.
# However, there is a logged warning at the dataset-profiler level.
try:
simple_dtype = get_column_data_type_from_first_non_null_value(column_name, df)
except Exception:
simple_dtype = MissingValue.CALCULATION_FAILED
pyarrow_dtype = "n/a"

return {
"pyarrow_dtype": str(df[column_config.name].dtype.pyarrow_dtype),
"simple_dtype": convert_pyarrow_dtype_to_simple_dtype(df[column_config.name].dtype.pyarrow_dtype),
"num_records": len(_df[column_config.name]),
"num_null": _df[column_config.name].isnull().sum(),
"num_unique": _df[column_config.name].nunique(),
"pyarrow_dtype": pyarrow_dtype,
"simple_dtype": simple_dtype,
"num_records": len(_df[column_name]),
"num_null": _df[column_name].isnull().sum(),
"num_unique": _df[column_name].nunique(),
}
except Exception as e:
logger.warning(f"{WARNING_PREFIX} failed to calculate general column info for '{column_config.name}': {e}")
logger.warning(f"{WARNING_PREFIX} failed to calculate general column info for '{column_name}': {e}")
return {
"pyarrow_dtype": MissingValue.CALCULATION_FAILED,
"simple_dtype": MissingValue.CALCULATION_FAILED,
Expand Down Expand Up @@ -115,11 +124,9 @@ def calculate_prompt_token_stats(
}


def calculate_completion_token_stats(
column_config: LLMTextColumnConfig, df: pd.DataFrame
) -> dict[str, float | MissingValue]:
def calculate_completion_token_stats(column_name: str, df: pd.DataFrame) -> dict[str, float | MissingValue]:
try:
tokens_per_record = df[column_config.name].apply(
tokens_per_record = df[column_name].apply(
lambda value: len(TOKENIZER.encode(str(value), disallowed_special=()))
)
return {
Expand All @@ -128,9 +135,7 @@ def calculate_completion_token_stats(
"completion_tokens_stddev": tokens_per_record.std(),
}
except Exception as e:
logger.warning(
f"{WARNING_PREFIX} failed to calculate completion token stats for column {column_config.name}: {e}"
)
logger.warning(f"{WARNING_PREFIX} failed to calculate completion token stats for column {column_name}: {e}")
return {
"completion_tokens_mean": MissingValue.CALCULATION_FAILED,
"completion_tokens_median": MissingValue.CALCULATION_FAILED,
Expand All @@ -141,16 +146,16 @@ def calculate_completion_token_stats(
def calculate_token_stats(column_config: LLMTextColumnConfig, df: pd.DataFrame) -> dict[str, float | MissingValue]:
return {
**calculate_prompt_token_stats(column_config, df),
**calculate_completion_token_stats(column_config, df),
**calculate_completion_token_stats(column_config.name, df),
}


def calculate_validation_column_info(column_config: ValidationColumnConfig, df: pd.DataFrame) -> dict[str, Any]:
def calculate_validation_column_info(column_name: str, df: pd.DataFrame) -> dict[str, Any]:
try:
return {"num_valid_records": df[column_config.name].apply(lambda x: ensure_boolean(x["is_valid"])).sum()}
return {"num_valid_records": df[column_name].apply(lambda x: ensure_boolean(x["is_valid"])).sum()}
except Exception as e:
logger.warning(
f"{WARNING_PREFIX} failed to calculate code validation column info for column {column_config.name}: {e}"
f"{WARNING_PREFIX} failed to calculate code validation column info for column {column_name}: {e}"
)
return {"num_valid_records": MissingValue.CALCULATION_FAILED}

Expand All @@ -160,22 +165,33 @@ def convert_pyarrow_dtype_to_simple_dtype(pyarrow_dtype: pa.DataType) -> str:
return f"list[{convert_pyarrow_dtype_to_simple_dtype(pyarrow_dtype.value_type)}]"
if isinstance(pyarrow_dtype, pa.StructType):
return "dict"
pyarrow_dtype_str = str(pyarrow_dtype)
if "int" in pyarrow_dtype_str:
return convert_to_simple_dtype(str(pyarrow_dtype))


def convert_to_simple_dtype(dtype: str) -> str:
if "int" in dtype:
return "int"
if "double" in pyarrow_dtype_str:
if "double" in dtype:
return "float"
if "float" in pyarrow_dtype_str:
if "float" in dtype:
return "float"
if "string" in pyarrow_dtype_str:
if "string" in dtype or dtype == "str":
return "string"
if "timestamp" in pyarrow_dtype_str:
if "timestamp" in dtype:
return "timestamp"
if "time" in pyarrow_dtype_str:
if "time" in dtype:
return "time"
if "date" in pyarrow_dtype_str:
if "date" in dtype:
return "date"
return pyarrow_dtype_str
return dtype


def get_column_data_type_from_first_non_null_value(column_name: str, df: pd.DataFrame) -> str:
    """Infer a simple dtype name for a column from its first non-null value.

    Args:
        column_name: Name of the column in ``df`` to inspect.
        df: DataFrame that contains ``column_name``.

    Returns:
        The Python type name of the first non-null value, passed through
        ``convert_to_simple_dtype``. When the column has no non-null values,
        ``MissingValue.CALCULATION_FAILED`` is returned instead.
    """
    non_null_values = df[column_name].dropna()
    if non_null_values.empty:
        # Nothing left to inspect, so no type can be inferred.
        return MissingValue.CALCULATION_FAILED
    first_value = non_null_values.iloc[0]
    return convert_to_simple_dtype(type(first_value).__name__)


def ensure_hashable(x: Any) -> str:
Expand Down Expand Up @@ -207,3 +223,7 @@ def ensure_boolean(v: bool | str | int | None) -> bool:
if v is None:
return False
raise ValueError(f"Invalid boolean value: {v}")


def has_pyarrow_backend(df: pd.DataFrame) -> bool:
    """Report whether every column dtype of ``df`` is a ``pd.ArrowDtype``.

    A DataFrame with no columns yields ``True`` because there is no
    column dtype that fails the check.
    """
    for column_dtype in df.dtypes:
        if not isinstance(column_dtype, pd.ArrowDtype):
            return False
    return True
11 changes: 0 additions & 11 deletions tests/engine/analysis/column_profilers/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,17 +37,6 @@ def test_column_config_with_dataframe_column_not_found_validation_error():
ColumnConfigWithDataFrame(column_config=column_config, df=df)


def test_column_config_with_dataframe_pyarrow_backend_conversion():
df = pd.DataFrame({"test_column": [1, 2, 3]})
column_config = SamplerColumnConfig(
name="test_column", sampler_type=SamplerType.CATEGORY, params={"values": [1, 2, 3]}
)

config_with_df = ColumnConfigWithDataFrame(column_config=column_config, df=df)

assert all(isinstance(dtype, pd.ArrowDtype) for dtype in config_with_df.df.dtypes)


def test_column_config_with_dataframe_as_tuple_method():
df = pd.DataFrame({"test_column": [1, 2, 3]})
column_config = SamplerColumnConfig(
Expand Down
Loading