Merged
26 changes: 10 additions & 16 deletions pyproject.toml
@@ -105,25 +105,19 @@ required-version = ">=0.7.10"
[tool.ruff]
line-length = 120
indent-width = 4
target-version = "py310"
Contributor Author commented:
we target Python 3.10


[tool.ruff.lint]
select = [
# "E", # pycodestyle errors
"W", # pycodestyle warnings
"F", # pyflakes
"I", # isort (import sorting)
# "N", # pep8-naming
# "UP", # pyupgrade (modern Python syntax)
# "ANN", # flake8-annotations (enforce type hints)
# "B", # fla e8-bugbear (common bugs)
# "C4", # flake8-comprehensions
# "DTZ", # flake8-datetimez (timezone awareness)
"ICN", # flake8-import-conventions
"PIE", # flake8-pie (misc lints)
# "RET", # flake8-return
# "SIM", # flake8-simplify
# "PTH", # flake8-use-pathlib
"TID", # flake8-tidy-imports (ban relative imports)
"W", # pycodestyle warnings
"F", # pyflakes
"I", # isort (import sorting)
"ICN", # flake8-import-conventions
"PIE", # flake8-pie (misc lints)
"TID", # flake8-tidy-imports (ban relative imports)
"UP006", # List[A] -> list[A]
"UP007", # Union[A, B] -> A | B
"UP045", # Optional[A] -> A | None
]
ignore = [
"ANN401", # Dynamically typed expressions (Any)
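The newly enabled pyupgrade rules are the mechanical rewrites applied throughout this PR: UP006 replaces typing-module generics with builtins (List[int] -> list[int]), UP007 replaces Union[A, B] with A | B, and UP045 replaces Optional[A] with A | None. A minimal before/after sketch of what they flag and fix (hypothetical example code, not taken from this repository):

# Before: typing-module forms flagged by the newly enabled rules on Python 3.10+.
from typing import List, Optional, Union

def first_before(values: List[int], default: Optional[str] = None) -> Union[int, str]:
    return values[0] if values else (default or "empty")

# After: what `ruff check --fix` rewrites them to (UP006, UP045, UP007 respectively).
def first_after(values: list[int], default: str | None = None) -> int | str:
    return values[0] if values else (default or "empty")

print(first_before([], None), first_after([7], None))  # empty 7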
13 changes: 6 additions & 7 deletions src/data_designer/config/analysis/column_profilers.py
@@ -3,7 +3,6 @@

from abc import ABC
from enum import Enum
from typing import Optional, Union

from pydantic import BaseModel, Field
from rich.panel import Panel
@@ -61,7 +60,7 @@ class JudgeScoreProfilerConfig(ConfigBase):
"""

model_alias: str
summary_score_sample_size: Optional[int] = Field(default=20, ge=1)
summary_score_sample_size: int | None = Field(default=20, ge=1)


class JudgeScoreSample(BaseModel):
@@ -75,7 +74,7 @@ class JudgeScoreSample(BaseModel):
reasoning: The reasoning or explanation provided by the judge for this score.
"""

score: Union[int, str]
score: int | str
reasoning: str


@@ -94,11 +93,11 @@
histograms: Mapping of each score dimension name to its histogram data.
"""

scores: dict[str, list[Union[int, str]]]
scores: dict[str, list[int | str]]
reasoning: dict[str, list[str]]
distribution_types: dict[str, ColumnDistributionType]
distributions: dict[str, Union[CategoricalDistribution, NumericalDistribution, MissingValue]]
histograms: dict[str, Union[CategoricalHistogramData, MissingValue]]
distributions: dict[str, CategoricalDistribution | NumericalDistribution | MissingValue]
histograms: dict[str, CategoricalHistogramData | MissingValue]


class JudgeScoreSummary(BaseModel):
@@ -132,7 +131,7 @@ class JudgeScoreProfilerResults(ColumnProfilerResults):

column_name: str
summaries: dict[str, JudgeScoreSummary]
score_distributions: Union[JudgeScoreDistributions, MissingValue]
score_distributions: JudgeScoreDistributions | MissingValue

def create_report_section(self) -> Panel:
layout = Table.grid(Column(), expand=True, padding=(2, 0))
74 changes: 37 additions & 37 deletions src/data_designer/config/analysis/column_statistics.py
@@ -5,7 +5,7 @@

from abc import ABC, abstractmethod
from enum import Enum
from typing import Any, Literal, Optional, Union
from typing import Any, Literal

from pandas import Series
from pydantic import BaseModel, ConfigDict, create_model, field_validator, model_validator
@@ -69,27 +69,27 @@ class GeneralColumnStatistics(BaseColumnStatistics):
"""

column_name: str
num_records: Union[int, MissingValue]
num_null: Union[int, MissingValue]
num_unique: Union[int, MissingValue]
num_records: int | MissingValue
num_null: int | MissingValue
num_unique: int | MissingValue
pyarrow_dtype: str
simple_dtype: str
column_type: Literal["general"] = "general"

@field_validator("num_null", "num_unique", "num_records", mode="before")
def general_statistics_ensure_python_integers(cls, v: Union[int, MissingValue]) -> Union[int, MissingValue]:
def general_statistics_ensure_python_integers(cls, v: int | MissingValue) -> int | MissingValue:
return v if isinstance(v, MissingValue) else prepare_number_for_reporting(v, int)

@property
def percent_null(self) -> Union[float, MissingValue]:
def percent_null(self) -> float | MissingValue:
return (
self.num_null
if self._is_missing_value(self.num_null)
else prepare_number_for_reporting(100 * self.num_null / (self.num_records + EPSILON), float)
)

@property
def percent_unique(self) -> Union[float, MissingValue]:
def percent_unique(self) -> float | MissingValue:
return (
self.num_unique
if self._is_missing_value(self.num_unique)
@@ -108,7 +108,7 @@ def _general_display_row(self) -> dict[str, str]:
def create_report_row_data(self) -> dict[str, str]:
return self._general_display_row

def _is_missing_value(self, v: Union[float, int, MissingValue]) -> bool:
def _is_missing_value(self, v: float | int | MissingValue) -> bool:
return v in set(MissingValue)


@@ -128,12 +128,12 @@ class LLMTextColumnStatistics(GeneralColumnStatistics):
column_type: Discriminator field, always "llm-text" for this statistics type.
"""

output_tokens_mean: Union[float, MissingValue]
output_tokens_median: Union[float, MissingValue]
output_tokens_stddev: Union[float, MissingValue]
input_tokens_mean: Union[float, MissingValue]
input_tokens_median: Union[float, MissingValue]
input_tokens_stddev: Union[float, MissingValue]
output_tokens_mean: float | MissingValue
output_tokens_median: float | MissingValue
output_tokens_stddev: float | MissingValue
input_tokens_mean: float | MissingValue
input_tokens_median: float | MissingValue
input_tokens_stddev: float | MissingValue
column_type: Literal[DataDesignerColumnType.LLM_TEXT.value] = DataDesignerColumnType.LLM_TEXT.value

@field_validator(
@@ -145,7 +145,7 @@ class LLMTextColumnStatistics(GeneralColumnStatistics):
"input_tokens_stddev",
mode="before",
)
def llm_column_ensure_python_floats(cls, v: Union[float, int, MissingValue]) -> Union[float, int, MissingValue]:
def llm_column_ensure_python_floats(cls, v: float | int | MissingValue) -> float | int | MissingValue:
return v if isinstance(v, MissingValue) else prepare_number_for_reporting(v, float)

def create_report_row_data(self) -> dict[str, Any]:
@@ -225,7 +225,7 @@ class SamplerColumnStatistics(GeneralColumnStatistics):

sampler_type: SamplerType
distribution_type: ColumnDistributionType
distribution: Optional[Union[CategoricalDistribution, NumericalDistribution, MissingValue]]
distribution: CategoricalDistribution | NumericalDistribution | MissingValue | None
column_type: Literal[DataDesignerColumnType.SAMPLER.value] = DataDesignerColumnType.SAMPLER.value

def create_report_row_data(self) -> dict[str, str]:
@@ -273,15 +273,15 @@ class ValidationColumnStatistics(GeneralColumnStatistics):
column_type: Discriminator field, always "validation" for this statistics type.
"""

num_valid_records: Union[int, MissingValue]
num_valid_records: int | MissingValue
column_type: Literal[DataDesignerColumnType.VALIDATION.value] = DataDesignerColumnType.VALIDATION.value

@field_validator("num_valid_records", mode="before")
def code_validation_column_ensure_python_integers(cls, v: Union[int, MissingValue]) -> Union[int, MissingValue]:
def code_validation_column_ensure_python_integers(cls, v: int | MissingValue) -> int | MissingValue:
return v if isinstance(v, MissingValue) else prepare_number_for_reporting(v, int)

@property
def percent_valid(self) -> Union[float, MissingValue]:
def percent_valid(self) -> float | MissingValue:
return (
self.num_valid_records
if self._is_missing_value(self.num_valid_records)
@@ -303,7 +303,7 @@ class CategoricalHistogramData(BaseModel):
counts: List of occurrence counts for each category.
"""

categories: list[Union[float, int, str]]
categories: list[float | int | str]
counts: list[int]

@model_validator(mode="after")
@@ -328,12 +328,12 @@ class CategoricalDistribution(BaseModel):
histogram: Complete frequency distribution showing all categories and their counts.
"""

most_common_value: Union[str, int]
least_common_value: Union[str, int]
most_common_value: str | int
least_common_value: str | int
histogram: CategoricalHistogramData

@field_validator("most_common_value", "least_common_value", mode="before")
def ensure_python_types(cls, v: Union[str, int]) -> Union[str, int]:
def ensure_python_types(cls, v: str | int) -> str | int:
return str(v) if not is_int(v) else prepare_number_for_reporting(v, int)

@classmethod
@@ -357,14 +357,14 @@ class NumericalDistribution(BaseModel):
median: Median value of the distribution.
"""

min: Union[float, int]
max: Union[float, int]
min: float | int
max: float | int
mean: float
stddev: float
median: float

@field_validator("min", "max", "mean", "stddev", "median", mode="before")
def ensure_python_types(cls, v: Union[float, int]) -> Union[float, int]:
def ensure_python_types(cls, v: float | int) -> float | int:
return prepare_number_for_reporting(v, int if is_int(v) else float)

@classmethod
@@ -378,17 +378,17 @@ def from_series(cls, series: Series) -> Self:
)


ColumnStatisticsT: TypeAlias = Union[
GeneralColumnStatistics,
LLMTextColumnStatistics,
LLMCodeColumnStatistics,
LLMStructuredColumnStatistics,
LLMJudgedColumnStatistics,
SamplerColumnStatistics,
SeedDatasetColumnStatistics,
ValidationColumnStatistics,
ExpressionColumnStatistics,
]
ColumnStatisticsT: TypeAlias = (
GeneralColumnStatistics
| LLMTextColumnStatistics
| LLMCodeColumnStatistics
| LLMStructuredColumnStatistics
| LLMJudgedColumnStatistics
| SamplerColumnStatistics
| SeedDatasetColumnStatistics
| ValidationColumnStatistics
| ExpressionColumnStatistics
)


DEFAULT_COLUMN_STATISTICS_MAP = {
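Union type aliases written with | are real runtime objects on Python 3.10+, so ColumnStatisticsT behaves exactly as the old Union[...] form did, including as the inner type of a pydantic discriminated field (the way dataset_profiler.py consumes it below). A self-contained sketch using hypothetical stand-in models, not the repository's classes:

from typing import Annotated, Literal, TypeAlias

from pydantic import BaseModel, Field

class CategoricalStats(BaseModel):
    column_type: Literal["categorical"] = "categorical"
    most_common_value: str

class NumericalStats(BaseModel):
    column_type: Literal["numerical"] = "numerical"
    mean: float

# PEP 604 alias: equivalent to Union[CategoricalStats, NumericalStats].
StatsT: TypeAlias = CategoricalStats | NumericalStats

class ProfileReport(BaseModel):
    # The discriminator routes each input dict to the right model via its column_type value.
    statistics: list[Annotated[StatsT, Field(discriminator="column_type")]]

report = ProfileReport(statistics=[{"column_type": "numerical", "mean": 1.5}])
print(type(report.statistics[0]).__name__)  # NumericalStats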
10 changes: 5 additions & 5 deletions src/data_designer/config/analysis/dataset_profiler.py
@@ -3,7 +3,7 @@

from functools import cached_property
from pathlib import Path
from typing import Annotated, Optional, Union
from typing import Annotated

from pydantic import BaseModel, Field, field_validator

@@ -34,8 +34,8 @@ class DatasetProfilerResults(BaseModel):
num_records: int
target_num_records: int
column_statistics: list[Annotated[ColumnStatisticsT, Field(discriminator="column_type")]] = Field(..., min_length=1)
side_effect_column_names: Optional[list[str]] = None
column_profiles: Optional[list[ColumnProfilerResultsT]] = None
side_effect_column_names: list[str] | None = None
column_profiles: list[ColumnProfilerResultsT] | None = None

@field_validator("num_records", "target_num_records", mode="before")
def ensure_python_integers(cls, v: int) -> int:
@@ -61,8 +61,8 @@ def get_column_statistics_by_type(self, column_type: DataDesignerColumnType) ->

def to_report(
self,
save_path: Optional[Union[str, Path]] = None,
include_sections: Optional[list[Union[ReportSection, DataDesignerColumnType]]] = None,
save_path: str | Path | None = None,
include_sections: list[ReportSection | DataDesignerColumnType] | None = None,
) -> None:
"""Generate and print an analysis report based on the dataset profiling results.

6 changes: 3 additions & 3 deletions src/data_designer/config/analysis/utils/reporting.py
@@ -5,7 +5,7 @@

from enum import Enum
from pathlib import Path
from typing import TYPE_CHECKING, Optional, Union
from typing import TYPE_CHECKING

from rich.align import Align
from rich.console import Console, Group
@@ -48,8 +48,8 @@ class ReportSection(str, Enum):

def generate_analysis_report(
analysis: DatasetProfilerResults,
save_path: Optional[Union[str, Path]] = None,
include_sections: Optional[list[Union[ReportSection, DataDesignerColumnType]]] = None,
save_path: str | Path | None = None,
include_sections: list[ReportSection | DataDesignerColumnType] | None = None,
) -> None:
"""Generate an analysis report for dataset profiling results.

6 changes: 3 additions & 3 deletions src/data_designer/config/base.py
@@ -4,7 +4,7 @@
from __future__ import annotations

from pathlib import Path
from typing import Any, Optional, Union
from typing import Any

import yaml
from pydantic import BaseModel, ConfigDict
@@ -31,7 +31,7 @@ def to_dict(self) -> dict[str, Any]:
"""
return self.model_dump(mode="json")

def to_yaml(self, path: Optional[Union[str, Path]] = None, *, indent: Optional[int] = 2, **kwargs) -> Optional[str]:
def to_yaml(self, path: str | Path | None = None, *, indent: int | None = 2, **kwargs) -> str | None:
"""Convert the configuration to a YAML string or file.

Args:
@@ -49,7 +49,7 @@ def to_yaml(self, path: Optional[Union[str, Path]] = None, *, indent: Optional[i
with open(path, "w") as f:
f.write(yaml_str)

def to_json(self, path: Optional[Union[str, Path]] = None, *, indent: Optional[int] = 2, **kwargs) -> Optional[str]:
def to_json(self, path: str | Path | None = None, *, indent: int | None = 2, **kwargs) -> str | None:
"""Convert the configuration to a JSON string or file.

Args:
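The updated to_yaml/to_json signatures follow the usual optional-path pattern: accept a str or pathlib.Path and write to it, or return the serialized string when no path is given. A minimal, self-contained sketch of that pattern (an illustrative helper, not the repository's implementation):

from pathlib import Path

def to_text(data: dict[str, object], path: str | Path | None = None) -> str | None:
    """Serialize `data`; write to `path` if given, otherwise return the string."""
    text = "\n".join(f"{key}: {value}" for key, value in data.items())
    if path is None:
        return text
    Path(path).write_text(text)
    return None

print(to_text({"line-length": 120}))      # returns "line-length: 120"
to_text({"line-length": 120}, "out.txt")  # writes the file, returns None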
12 changes: 6 additions & 6 deletions src/data_designer/config/column_configs.py
@@ -2,7 +2,7 @@
# SPDX-License-Identifier: Apache-2.0

from abc import ABC
from typing import Annotated, Literal, Optional, Type, Union
from typing import Annotated, Literal

from pydantic import BaseModel, Discriminator, Field, model_validator
from typing_extensions import Self
@@ -91,7 +91,7 @@ class SamplerColumnConfig(SingleColumnConfig):
sampler_type: SamplerType
params: Annotated[SamplerParamsT, Discriminator("sampler_type")]
conditional_params: dict[str, Annotated[SamplerParamsT, Discriminator("sampler_type")]] = {}
convert_to: Optional[str] = None
convert_to: str | None = None
column_type: Literal["sampler"] = "sampler"

@model_validator(mode="before")
@@ -146,8 +146,8 @@ class LLMTextColumnConfig(SingleColumnConfig):

prompt: str
model_alias: str
system_prompt: Optional[str] = None
multi_modal_context: Optional[list[ImageContext]] = None
system_prompt: str | None = None
multi_modal_context: list[ImageContext] | None = None
column_type: Literal["llm-text"] = "llm-text"

@property
@@ -222,7 +222,7 @@ class LLMStructuredColumnConfig(LLMTextColumnConfig):
column_type: Discriminator field, always "llm-structured" for this configuration type.
"""

output_format: Union[dict, Type[BaseModel]]
output_format: dict | type[BaseModel]
column_type: Literal["llm-structured"] = "llm-structured"

@model_validator(mode="after")
@@ -255,7 +255,7 @@ class Score(ConfigBase):

name: str = Field(..., description="A clear name for this score.")
description: str = Field(..., description="An informative and detailed assessment guide for using this score.")
options: dict[Union[int, str], str] = Field(..., description="Score options in the format of {score: description}.")
options: dict[int | str, str] = Field(..., description="Score options in the format of {score: description}.")


class LLMJudgeColumnConfig(LLMTextColumnConfig):
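The output_format annotation on LLMStructuredColumnConfig now uses the lowercase builtin form, dict | type[BaseModel], so either a JSON-schema dict or a pydantic model class satisfies it. A hedged sketch of how such a union might be normalized to a schema dict (a hypothetical helper, not the repository's code):

from pydantic import BaseModel

class Answer(BaseModel):
    verdict: str
    confidence: float

def normalize_output_format(output_format: dict | type[BaseModel]) -> dict:
    """Accept a JSON-schema dict or a BaseModel subclass; return a schema dict."""
    if isinstance(output_format, dict):
        return output_format
    return output_format.model_json_schema()

print(sorted(normalize_output_format(Answer)["properties"]))  # ['confidence', 'verdict']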