Merged
7 changes: 0 additions & 7 deletions src/data_designer/engine/analysis/column_profilers/base.py

@@ -7,7 +7,6 @@
 from abc import ABC, abstractmethod

 import pandas as pd
-import pyarrow as pa
 from pydantic import BaseModel, model_validator
 from typing_extensions import Self

@@ -29,12 +28,6 @@ def validate_column_exists(self) -> Self:
             raise ValueError(f"Column {self.column_config.name!r} not found in DataFrame")
         return self

-    @model_validator(mode="after")
-    def ensure_pyarrow_backend(self) -> Self:
-        if not all(isinstance(dtype, pd.ArrowDtype) for dtype in self.df.dtypes):
-            self.df = pa.Table.from_pandas(self.df).to_pandas(types_mapper=pd.ArrowDtype)
-        return self
-
Comment on lines -32 to -37 (Contributor Author):
this validator is what was causing generation jobs to fail at the profiling step
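
For context, a minimal reproduction sketch of that failure mode (sample data assumed; the exact exception type and message vary across pyarrow versions):

```python
# Hypothetical repro: a mixed-type object column makes the pyarrow conversion
# raise, which is the call the removed validator ran unconditionally per column.
import pandas as pd
import pyarrow as pa

df = pd.DataFrame({"mixed": [1, "two", 3.0]})  # object dtype with mixed values

try:
    pa.Table.from_pandas(df).to_pandas(types_mapper=pd.ArrowDtype)
except Exception as e:
    # pyarrow typically raises ArrowInvalid/ArrowTypeError here; the args may
    # carry a second element naming the offending column.
    print(type(e).__name__, e.args)
```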

     def as_tuple(self) -> tuple[SingleColumnConfig, pd.DataFrame]:
         return (self.column_config, self.df)
6 changes: 3 additions & 3 deletions src/data_designer/engine/analysis/column_statistics.py

@@ -59,7 +59,7 @@ def calculate(self) -> Self:
         )

     def calculate_general_column_info(self) -> dict[str, Any]:
-        return calculate_general_column_info(self.column_config, self.df)
+        return calculate_general_column_info(self.column_config.name, self.df)

     def __repr__(self) -> str:
         params = []
@@ -93,7 +93,7 @@ def calculate_sampler_distribution(self) -> dict[str, Any]:
         return (
             {
                 "sampler_type": SamplerType(self.column_config.sampler_type),
-                **calculate_column_distribution(self.column_config, self.df, dist_type),
+                **calculate_column_distribution(self.column_config.name, self.df, dist_type),
             }
             if make_dist
             else {
@@ -109,7 +109,7 @@ class SeedDatasetColumnStatisticsCalculator(GeneralColumnStatisticsCalculator):

 class ValidationColumnStatisticsCalculator(GeneralColumnStatisticsCalculator):
     def calculate_validation_column_info(self) -> dict[str, Any]:
-        return calculate_validation_column_info(self.column_config, self.df)
+        return calculate_validation_column_info(self.column_config.name, self.df)


 class ExpressionColumnStatisticsCalculator(GeneralColumnStatisticsCalculator): ...
29 changes: 25 additions & 4 deletions src/data_designer/engine/analysis/dataset_profiler.py

@@ -6,6 +6,7 @@
 from functools import cached_property

 import pandas as pd
+import pyarrow as pa
 from pydantic import Field, field_validator

 from data_designer.config.analysis.column_profilers import ColumnProfilerConfigT
@@ -19,10 +20,8 @@
 from data_designer.engine.analysis.column_profilers.base import ColumnConfigWithDataFrame, ColumnProfiler
 from data_designer.engine.analysis.column_statistics import get_column_statistics_calculator
 from data_designer.engine.analysis.errors import DatasetProfilerConfigurationError
-from data_designer.engine.dataset_builders.multi_column_configs import (
-    DatasetBuilderColumnConfigT,
-    MultiColumnConfig,
-)
+from data_designer.engine.analysis.utils.column_statistics_calculations import has_pyarrow_backend
+from data_designer.engine.dataset_builders.multi_column_configs import DatasetBuilderColumnConfigT, MultiColumnConfig
 from data_designer.engine.registry.data_designer_registry import DataDesignerRegistry
 from data_designer.engine.resources.resource_provider import ResourceProvider

@@ -68,6 +67,7 @@ def profile_dataset(
         logger.info("📐 Measuring dataset column statistics:")

         self._validate_schema_consistency(list(dataset.columns))
+        dataset = self._convert_to_pyarrow_backend_if_needed(dataset)

         column_statistics = []
         for c in self.config.column_configs:
@@ -100,6 +100,27 @@ def profile_dataset(
             column_profiles=column_profiles if column_profiles else None,
         )

+    def _convert_to_pyarrow_backend_if_needed(self, dataset: pd.DataFrame) -> pd.DataFrame:
+        if not has_pyarrow_backend(dataset):
+            try:
+                dataset = pa.Table.from_pandas(dataset).to_pandas(types_mapper=pd.ArrowDtype)
+            except Exception as e:
+                # For ArrowTypeError, the second arg contains the more informative message
+                if isinstance(e, pa.lib.ArrowTypeError) and len(e.args) > 1:
+                    error_msg = str(e.args[1])
+                else:
+                    error_msg = str(e)
+                for col in dataset.columns:
+                    # Make sure column names are clear in the error message
+                    error_msg = error_msg.replace(col, f"'{col}'")
+                logger.warning("⚠️ Unable to convert the dataset to a PyArrow backend")
+                logger.warning(f" |-- Conversion Error Message: {error_msg}")
+                logger.warning(" |-- This is often due to at least one column having mixed data types")
+                logger.warning(
+                    " |-- Note: Reported data types will be inferred from the first non-null value of each column"
+                )
+        return dataset
+
     def _create_column_profiler(self, profiler_config: ColumnProfilerConfigT) -> ColumnProfiler:
         return self.registry.column_profilers.get_for_config_type(type(profiler_config))(
             config=profiler_config, resource_provider=self.resource_provider
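
For reference, a sketch of the success path of this dataset-level conversion (sample frame assumed; the exact printed dtype strings depend on the pandas/pyarrow versions):

```python
# Success path: a clean DataFrame converts wholesale to pyarrow-backed dtypes.
import pandas as pd
import pyarrow as pa

df = pd.DataFrame({"a": [1, 2, 3], "b": ["x", "y", "z"]})
converted = pa.Table.from_pandas(df).to_pandas(types_mapper=pd.ArrowDtype)
print(converted.dtypes)  # e.g. a: int64[pyarrow], b: string[pyarrow]
```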
src/data_designer/engine/analysis/utils/column_statistics_calculations.py
@@ -18,11 +18,7 @@
     MissingValue,
     NumericalDistribution,
 )
-from data_designer.config.column_configs import (
-    LLMTextColumnConfig,
-    SingleColumnConfig,
-    ValidationColumnConfig,
-)
+from data_designer.config.column_configs import LLMTextColumnConfig
 from data_designer.engine.column_generators.generators.llm_generators import (
     PromptType,
     RecordBasedPromptRenderer,
@@ -39,41 +35,54 @@


 def calculate_column_distribution(
-    column_config: SingleColumnConfig, df: pd.DataFrame, distribution_type: ColumnDistributionType
+    column_name: str, df: pd.DataFrame, distribution_type: ColumnDistributionType
 ) -> dict[str, CategoricalDistribution | NumericalDistribution | MissingValue | None]:
     distribution_type = ColumnDistributionType(distribution_type)
     try:
         if distribution_type == ColumnDistributionType.CATEGORICAL:
             return {
                 "distribution_type": ColumnDistributionType.CATEGORICAL,
-                "distribution": CategoricalDistribution.from_series(df[column_config.name]),
+                "distribution": CategoricalDistribution.from_series(df[column_name]),
             }

         if distribution_type == ColumnDistributionType.NUMERICAL:
             return {
                 "distribution_type": ColumnDistributionType.NUMERICAL,
-                "distribution": NumericalDistribution.from_series(df[column_config.name]),
+                "distribution": NumericalDistribution.from_series(df[column_name]),
             }
     except Exception as e:
-        logger.warning(f"{WARNING_PREFIX} failed to calculate column distribution for '{column_config.name}' {e}")
+        logger.warning(f"{WARNING_PREFIX} failed to calculate column distribution for '{column_name}' {e}")
     return {
         "distribution_type": ColumnDistributionType.UNKNOWN,
         "distribution": MissingValue.CALCULATION_FAILED,
     }


-def calculate_general_column_info(column_config: SingleColumnConfig, df: pd.DataFrame) -> dict[str, Any]:
+def calculate_general_column_info(column_name: str, df: pd.DataFrame) -> dict[str, Any]:
     try:
-        _df = pd.DataFrame(df[column_config.name].apply(ensure_hashable))
+        _df = pd.DataFrame(df[column_name].apply(ensure_hashable))

+        if has_pyarrow_backend(df):
+            pyarrow_dtype = str(df[column_name].dtype.pyarrow_dtype)
+            simple_dtype = convert_pyarrow_dtype_to_simple_dtype(df[column_name].dtype.pyarrow_dtype)
+        else:
+            # We do not log a warning at the column-level because it would be too noisy.
+            # However, there is a logged warning at the dataset-profiler level.
+            try:
+                simple_dtype = get_column_data_type_from_first_non_null_value(column_name, df)
+            except Exception:
+                simple_dtype = MissingValue.CALCULATION_FAILED
+            pyarrow_dtype = "n/a"
+
         return {
-            "pyarrow_dtype": str(df[column_config.name].dtype.pyarrow_dtype),
-            "simple_dtype": convert_pyarrow_dtype_to_simple_dtype(df[column_config.name].dtype.pyarrow_dtype),
-            "num_records": len(_df[column_config.name]),
-            "num_null": _df[column_config.name].isnull().sum(),
-            "num_unique": _df[column_config.name].nunique(),
+            "pyarrow_dtype": pyarrow_dtype,
+            "simple_dtype": simple_dtype,
+            "num_records": len(_df[column_name]),
+            "num_null": _df[column_name].isnull().sum(),
+            "num_unique": _df[column_name].nunique(),
         }
     except Exception as e:
-        logger.warning(f"{WARNING_PREFIX} failed to calculate general column info for '{column_config.name}': {e}")
+        logger.warning(f"{WARNING_PREFIX} failed to calculate general column info for '{column_name}': {e}")
         return {
             "pyarrow_dtype": MissingValue.CALCULATION_FAILED,
             "simple_dtype": MissingValue.CALCULATION_FAILED,
@@ -115,11 +124,9 @@ def calculate_prompt_token_stats(
     }


-def calculate_completion_token_stats(
-    column_config: LLMTextColumnConfig, df: pd.DataFrame
-) -> dict[str, float | MissingValue]:
+def calculate_completion_token_stats(column_name: str, df: pd.DataFrame) -> dict[str, float | MissingValue]:
     try:
-        tokens_per_record = df[column_config.name].apply(
+        tokens_per_record = df[column_name].apply(
             lambda value: len(TOKENIZER.encode(str(value), disallowed_special=()))
         )
         return {
@@ -128,9 +135,7 @@
             "completion_tokens_stddev": tokens_per_record.std(),
         }
     except Exception as e:
-        logger.warning(
-            f"{WARNING_PREFIX} failed to calculate completion token stats for column {column_config.name}: {e}"
-        )
+        logger.warning(f"{WARNING_PREFIX} failed to calculate completion token stats for column {column_name}: {e}")
         return {
             "completion_tokens_mean": MissingValue.CALCULATION_FAILED,
             "completion_tokens_median": MissingValue.CALCULATION_FAILED,
@@ -141,16 +146,16 @@
 def calculate_token_stats(column_config: LLMTextColumnConfig, df: pd.DataFrame) -> dict[str, float | MissingValue]:
     return {
         **calculate_prompt_token_stats(column_config, df),
-        **calculate_completion_token_stats(column_config, df),
+        **calculate_completion_token_stats(column_config.name, df),
     }


-def calculate_validation_column_info(column_config: ValidationColumnConfig, df: pd.DataFrame) -> dict[str, Any]:
+def calculate_validation_column_info(column_name: str, df: pd.DataFrame) -> dict[str, Any]:
     try:
-        return {"num_valid_records": df[column_config.name].apply(lambda x: ensure_boolean(x["is_valid"])).sum()}
+        return {"num_valid_records": df[column_name].apply(lambda x: ensure_boolean(x["is_valid"])).sum()}
     except Exception as e:
         logger.warning(
-            f"{WARNING_PREFIX} failed to calculate code validation column info for column {column_config.name}: {e}"
+            f"{WARNING_PREFIX} failed to calculate code validation column info for column {column_name}: {e}"
         )
         return {"num_valid_records": MissingValue.CALCULATION_FAILED}

@@ -160,22 +165,33 @@ def convert_pyarrow_dtype_to_simple_dtype(pyarrow_dtype: pa.DataType) -> str:
         return f"list[{convert_pyarrow_dtype_to_simple_dtype(pyarrow_dtype.value_type)}]"
     if isinstance(pyarrow_dtype, pa.StructType):
         return "dict"
-    pyarrow_dtype_str = str(pyarrow_dtype)
-    if "int" in pyarrow_dtype_str:
+    return convert_to_simple_dtype(str(pyarrow_dtype))
+
+
+def convert_to_simple_dtype(dtype: str) -> str:
+    if "int" in dtype:
         return "int"
-    if "double" in pyarrow_dtype_str:
+    if "double" in dtype:
         return "float"
-    if "float" in pyarrow_dtype_str:
+    if "float" in dtype:
         return "float"
-    if "string" in pyarrow_dtype_str:
+    if "str" in dtype:
         return "string"
-    if "timestamp" in pyarrow_dtype_str:
+    if "timestamp" in dtype:
         return "timestamp"
-    if "time" in pyarrow_dtype_str:
+    if "time" in dtype:
         return "time"
-    if "date" in pyarrow_dtype_str:
+    if "date" in dtype:
         return "date"
-    return pyarrow_dtype_str
+    return dtype
+
+
+def get_column_data_type_from_first_non_null_value(column_name: str, df: pd.DataFrame) -> str:
+    df_no_nulls = df[column_name].dropna()
+    if len(df_no_nulls) == 0:
+        return MissingValue.CALCULATION_FAILED
+    dtype = type(df_no_nulls.iloc[0]).__name__
+    return convert_to_simple_dtype(dtype)


 def ensure_hashable(x: Any) -> str:
@@ -207,3 +223,7 @@ def ensure_boolean(v: bool | str | int | None) -> bool:
     if v is None:
         return False
     raise ValueError(f"Invalid boolean value: {v}")
+
+
+def has_pyarrow_backend(df: pd.DataFrame) -> bool:
+    return all(isinstance(dtype, pd.ArrowDtype) for dtype in df.dtypes)
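
A small sketch of how the new helpers compose on the fallback path (sample frame assumed; the helper definitions are the ones added above):

```python
# Assumes has_pyarrow_backend and get_column_data_type_from_first_non_null_value
# from this module; the sample frame is illustrative only.
import pandas as pd

df = pd.DataFrame({"age": [None, 42, 37], "name": ["ada", "lin", None]})

print(has_pyarrow_backend(df))  # False: default NumPy-backed dtypes

# Fallback dtype inference from the first non-null value:
print(get_column_data_type_from_first_non_null_value("age", df))   # "float" (None padding makes the column float64)
print(get_column_data_type_from_first_non_null_value("name", df))  # "string" (str -> "string" via convert_to_simple_dtype)

# A pyarrow-backed copy flips the backend check:
print(has_pyarrow_backend(df.convert_dtypes(dtype_backend="pyarrow")))  # True
```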
11 changes: 0 additions & 11 deletions tests/engine/analysis/column_profilers/test_base.py

@@ -37,17 +37,6 @@ def test_column_config_with_dataframe_column_not_found_validation_error():
         ColumnConfigWithDataFrame(column_config=column_config, df=df)


-def test_column_config_with_dataframe_pyarrow_backend_conversion():
-    df = pd.DataFrame({"test_column": [1, 2, 3]})
-    column_config = SamplerColumnConfig(
-        name="test_column", sampler_type=SamplerType.CATEGORY, params={"values": [1, 2, 3]}
-    )
-
-    config_with_df = ColumnConfigWithDataFrame(column_config=column_config, df=df)
-
-    assert all(isinstance(dtype, pd.ArrowDtype) for dtype in config_with_df.df.dtypes)
-
-
 def test_column_config_with_dataframe_as_tuple_method():
     df = pd.DataFrame({"test_column": [1, 2, 3]})
     column_config = SamplerColumnConfig(