Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions src/data_designer/config/analysis/__init__.py

This file was deleted.

15 changes: 8 additions & 7 deletions src/data_designer/config/analysis/column_profilers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@

from abc import ABC
from enum import Enum
from typing import TypeAlias
from typing import Optional, Union

from pydantic import BaseModel, Field
from rich.panel import Panel
from rich.table import Column, Table
from typing_extensions import TypeAlias

from ..base import ConfigBase
from ..utils.visualization import ColorPalette
Expand Down Expand Up @@ -37,20 +38,20 @@ def create_report_section(self) -> Panel:

class JudgeScoreProfilerConfig(ConfigBase):
model_alias: str
summary_score_sample_size: int | None = Field(default=20, ge=1)
summary_score_sample_size: Optional[int] = Field(default=20, ge=1)


class JudgeScoreSample(BaseModel):
score: int | str
score: Union[int, str]
reasoning: str


class JudgeScoreDistributions(BaseModel):
scores: dict[str, list[int | str]]
scores: dict[str, list[Union[int, str]]]
reasoning: dict[str, list[str]]
distribution_types: dict[str, ColumnDistributionType]
distributions: dict[str, CategoricalDistribution | NumericalDistribution | MissingValue]
histograms: dict[str, CategoricalHistogramData | MissingValue]
distributions: dict[str, Union[CategoricalDistribution, NumericalDistribution, MissingValue]]
histograms: dict[str, Union[CategoricalHistogramData, MissingValue]]


class JudgeScoreSummary(BaseModel):
Expand All @@ -62,7 +63,7 @@ class JudgeScoreSummary(BaseModel):
class JudgeScoreProfilerResults(ColumnProfilerResults):
column_name: str
summaries: dict[str, JudgeScoreSummary]
score_distributions: JudgeScoreDistributions | MissingValue
score_distributions: Union[JudgeScoreDistributions, MissingValue]

def create_report_section(self) -> Panel:
layout = Table.grid(Column(), expand=True, padding=(2, 0))
Expand Down
76 changes: 39 additions & 37 deletions src/data_designer/config/analysis/column_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@

from abc import ABC, abstractmethod
from enum import Enum
from typing import Annotated, Any, Literal, TypeAlias
from typing import Annotated, Any, Literal, Optional, Union

from pandas import Series
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
from typing_extensions import Self
from typing_extensions import Self, TypeAlias

from ..columns import DataDesignerColumnType
from ..sampler_params import SamplerType
Expand Down Expand Up @@ -39,27 +39,27 @@ def create_report_row_data(self) -> dict[str, str]: ...

class GeneralColumnStatistics(BaseColumnStatistics):
column_name: str
num_records: int | MissingValue
num_null: int | MissingValue
num_unique: int | MissingValue
num_records: Union[int, MissingValue]
num_null: Union[int, MissingValue]
num_unique: Union[int, MissingValue]
pyarrow_dtype: str
simple_dtype: str
column_type: Literal["general"] = "general"

@field_validator("num_null", "num_unique", "num_records", mode="before")
def general_statistics_ensure_python_integers(cls, v: int | MissingValue) -> int | MissingValue:
def general_statistics_ensure_python_integers(cls, v: Union[int, MissingValue]) -> Union[int, MissingValue]:
return v if isinstance(v, MissingValue) else prepare_number_for_reporting(v, int)

@property
def percent_null(self) -> float | MissingValue:
def percent_null(self) -> Union[float, MissingValue]:
return (
self.num_null
if self._is_missing_value(self.num_null)
else prepare_number_for_reporting(100 * self.num_null / (self.num_records + EPSILON), float)
)

@property
def percent_unique(self) -> float | MissingValue:
def percent_unique(self) -> Union[float, MissingValue]:
return (
self.num_unique
if self._is_missing_value(self.num_unique)
Expand All @@ -78,17 +78,17 @@ def _general_display_row(self) -> dict[str, str]:
def create_report_row_data(self) -> dict[str, str]:
return self._general_display_row

def _is_missing_value(self, v: float | int | MissingValue) -> bool:
def _is_missing_value(self, v: Union[float, int, MissingValue]) -> bool:
return v in set(MissingValue)


class LLMTextColumnStatistics(GeneralColumnStatistics):
completion_tokens_mean: float | MissingValue
completion_tokens_median: float | MissingValue
completion_tokens_stddev: float | MissingValue
prompt_tokens_mean: float | MissingValue
prompt_tokens_median: float | MissingValue
prompt_tokens_stddev: float | MissingValue
completion_tokens_mean: Union[float, MissingValue]
completion_tokens_median: Union[float, MissingValue]
completion_tokens_stddev: Union[float, MissingValue]
prompt_tokens_mean: Union[float, MissingValue]
prompt_tokens_median: Union[float, MissingValue]
prompt_tokens_stddev: Union[float, MissingValue]
column_type: Literal[DataDesignerColumnType.LLM_TEXT.value] = DataDesignerColumnType.LLM_TEXT.value

@field_validator(
Expand All @@ -100,7 +100,7 @@ class LLMTextColumnStatistics(GeneralColumnStatistics):
"prompt_tokens_stddev",
mode="before",
)
def llm_column_ensure_python_floats(cls, v: float | int | MissingValue) -> float | int | MissingValue:
def llm_column_ensure_python_floats(cls, v: Union[float, int, MissingValue]) -> Union[float, int, MissingValue]:
return v if isinstance(v, MissingValue) else prepare_number_for_reporting(v, float)

def create_report_row_data(self) -> dict[str, Any]:
Expand Down Expand Up @@ -136,7 +136,7 @@ class LLMJudgedColumnStatistics(LLMTextColumnStatistics):
class SamplerColumnStatistics(GeneralColumnStatistics):
sampler_type: SamplerType
distribution_type: ColumnDistributionType
distribution: CategoricalDistribution | NumericalDistribution | MissingValue | None
distribution: Optional[Union[CategoricalDistribution, NumericalDistribution, MissingValue]]
column_type: Literal[DataDesignerColumnType.SAMPLER.value] = DataDesignerColumnType.SAMPLER.value

def create_report_row_data(self) -> dict[str, str]:
Expand All @@ -148,7 +148,7 @@ def create_report_row_data(self) -> dict[str, str]:

class SeedDatasetColumnStatistics(GeneralColumnStatistics):
distribution_type: ColumnDistributionType
distribution: CategoricalDistribution | NumericalDistribution | MissingValue | None
distribution: Optional[Union[CategoricalDistribution, NumericalDistribution, MissingValue]]
column_type: Literal[DataDesignerColumnType.SEED_DATASET.value] = DataDesignerColumnType.SEED_DATASET.value

def create_report_row_data(self) -> dict[str, str]:
Expand All @@ -160,15 +160,15 @@ class ExpressionColumnStatistics(GeneralColumnStatistics):


class ValidationColumnStatistics(GeneralColumnStatistics):
num_valid_records: int | MissingValue
num_valid_records: Union[int, MissingValue]
column_type: Literal[DataDesignerColumnType.VALIDATION.value] = DataDesignerColumnType.VALIDATION.value

@field_validator("num_valid_records", mode="before")
def code_validation_column_ensure_python_integers(cls, v: int | MissingValue) -> int | MissingValue:
def code_validation_column_ensure_python_integers(cls, v: Union[int, MissingValue]) -> Union[int, MissingValue]:
return v if isinstance(v, MissingValue) else prepare_number_for_reporting(v, int)

@property
def percent_valid(self) -> float | MissingValue:
def percent_valid(self) -> Union[float, MissingValue]:
return (
self.num_valid_records
if self._is_missing_value(self.num_valid_records)
Expand All @@ -181,7 +181,7 @@ def create_report_row_data(self) -> dict[str, str]:


class CategoricalHistogramData(BaseModel):
categories: list[float | int | str]
categories: list[Union[float, int, str]]
counts: list[int]

@model_validator(mode="after")
Expand All @@ -198,12 +198,12 @@ def from_series(cls, series: Series) -> Self:


class CategoricalDistribution(BaseModel):
most_common_value: str | int
least_common_value: str | int
most_common_value: Union[str, int]
least_common_value: Union[str, int]
histogram: CategoricalHistogramData

@field_validator("most_common_value", "least_common_value", mode="before")
def ensure_python_types(cls, v: str | int) -> str | int:
def ensure_python_types(cls, v: Union[str, int]) -> Union[str, int]:
return str(v) if not is_int(v) else prepare_number_for_reporting(v, int)

@classmethod
Expand All @@ -217,14 +217,14 @@ def from_series(cls, series: Series) -> Self:


class NumericalDistribution(BaseModel):
min: float | int
max: float | int
min: Union[float, int]
max: Union[float, int]
mean: float
stddev: float
median: float

@field_validator("min", "max", "mean", "stddev", "median", mode="before")
def ensure_python_types(cls, v: float | int) -> float | int:
def ensure_python_types(cls, v: Union[float, int]) -> Union[float, int]:
return prepare_number_for_reporting(v, int if is_int(v) else float)

@classmethod
Expand All @@ -239,14 +239,16 @@ def from_series(cls, series: Series) -> Self:


ColumnStatisticsT: TypeAlias = Annotated[
GeneralColumnStatistics
| LLMTextColumnStatistics
| LLMCodeColumnStatistics
| LLMStructuredColumnStatistics
| LLMJudgedColumnStatistics
| SamplerColumnStatistics
| SeedDatasetColumnStatistics
| ValidationColumnStatistics
| ExpressionColumnStatistics,
Union[
GeneralColumnStatistics,
LLMTextColumnStatistics,
LLMCodeColumnStatistics,
LLMStructuredColumnStatistics,
LLMJudgedColumnStatistics,
SamplerColumnStatistics,
SeedDatasetColumnStatistics,
ValidationColumnStatistics,
ExpressionColumnStatistics,
],
Field(discriminator="column_type"),
]
9 changes: 5 additions & 4 deletions src/data_designer/config/analysis/dataset_profiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from functools import cached_property
from pathlib import Path
from typing import Optional, Union

from pydantic import BaseModel, Field, field_validator

Expand All @@ -18,8 +19,8 @@ class DatasetProfilerResults(BaseModel):
num_records: int
target_num_records: int
column_statistics: list[ColumnStatisticsT] = Field(..., min_length=1)
side_effect_column_names: list[str] | None = None
column_profiles: list[ColumnProfilerResultsT] | None = None
side_effect_column_names: Optional[list[str]] = None
column_profiles: Optional[list[ColumnProfilerResultsT]] = None

@field_validator("num_records", "target_num_records", mode="before")
def ensure_python_integers(cls, v: int) -> int:
Expand All @@ -42,8 +43,8 @@ def get_column_statistics_by_type(self, column_type: DataDesignerColumnType) ->

def to_report(
self,
save_path: str | Path | None = None,
include_sections: list[ReportSection | DataDesignerColumnType] | None = None,
save_path: Optional[Union[str, Path]] = None,
include_sections: Optional[list[Union[ReportSection, DataDesignerColumnType]]] = None,
) -> None:
"""Generate and print an analysis report based on the dataset profiling results.

Expand Down
3 changes: 0 additions & 3 deletions src/data_designer/config/analysis/utils/__init__.py

This file was deleted.

8 changes: 4 additions & 4 deletions src/data_designer/config/analysis/utils/reporting.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from enum import Enum
from pathlib import Path
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Optional, Union

from rich.align import Align
from rich.console import Console, Group
Expand Down Expand Up @@ -49,8 +49,8 @@ class ReportSection(str, Enum):

def generate_analysis_report(
analysis: DatasetProfilerResults,
save_path: str | Path | None = None,
include_sections: list[ReportSection | DataDesignerColumnType] | None = None,
save_path: Optional[Union[str, Path]] = None,
include_sections: Optional[list[Union[ReportSection, DataDesignerColumnType]]] = None,
) -> None:
"""Generate an analysis report for dataset profiling results.

Expand Down Expand Up @@ -166,7 +166,7 @@ def create_judge_score_summary_table(
layout = Table.grid(Column(), Column(), expand=True, padding=(0, 2))

histogram_table = create_rich_histogram_table(
{str(s): c for s, c in zip(histogram.categories, histogram.counts, strict=False)},
{str(s): c for s, c in zip(histogram.categories, histogram.counts)},
("score", "count"),
name_style=HIST_NAME_STYLE,
value_style=HIST_VALUE_STYLE,
Expand Down
8 changes: 4 additions & 4 deletions src/data_designer/config/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from abc import ABC, abstractmethod
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, Protocol, TypeVar
from typing import TYPE_CHECKING, Any, Generic, Optional, Protocol, TypeVar, Union

import pandas as pd
from pydantic import BaseModel, ConfigDict
Expand All @@ -14,9 +14,9 @@
from .utils.io_helpers import serialize_data

if TYPE_CHECKING:
from ..client.results.preview import PreviewResults
from .analysis.dataset_profiler import DatasetProfilerResults
from .config_builder import DataDesignerConfigBuilder
from .preview_results import PreviewResults

DEFAULT_NUM_RECORDS = 10

Expand Down Expand Up @@ -66,7 +66,7 @@ def to_dict(self) -> dict[str, Any]:
"""
return self.model_dump(mode="json")

def to_yaml(self, path: str | Path | None = None, *, indent: int | None = 2, **kwargs) -> str | None:
def to_yaml(self, path: Optional[Union[str, Path]] = None, *, indent: Optional[int] = 2, **kwargs) -> Optional[str]:
"""Convert the configuration to a YAML string or file.

Args:
Expand All @@ -84,7 +84,7 @@ def to_yaml(self, path: str | Path | None = None, *, indent: int | None = 2, **k
with open(path, "w") as f:
f.write(yaml_str)

def to_json(self, path: str | Path | None = None, *, indent: int | None = 2, **kwargs) -> str | None:
def to_json(self, path: Optional[Union[str, Path]] = None, *, indent: Optional[int] = 2, **kwargs) -> Optional[str]:
"""Convert the configuration to a JSON string or file.

Args:
Expand Down
Loading