diff --git a/pyproject.toml b/pyproject.toml index 49a6de9b..bdd783b6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -105,25 +105,19 @@ required-version = ">=0.7.10" [tool.ruff] line-length = 120 indent-width = 4 +target-version = "py310" [tool.ruff.lint] select = [ - # "E", # pycodestyle errors - "W", # pycodestyle warnings - "F", # pyflakes - "I", # isort (import sorting) - # "N", # pep8-naming - # "UP", # pyupgrade (modern Python syntax) - # "ANN", # flake8-annotations (enforce type hints) - # "B", # fla e8-bugbear (common bugs) - # "C4", # flake8-comprehensions - # "DTZ", # flake8-datetimez (timezone awareness) - "ICN", # flake8-import-conventions - "PIE", # flake8-pie (misc lints) - # "RET", # flake8-return - # "SIM", # flake8-simplify - # "PTH", # flake8-use-pathlib - "TID", # flake8-tidy-imports (ban relative imports) + "W", # pycodestyle warnings + "F", # pyflakes + "I", # isort (import sorting) + "ICN", # flake8-import-conventions + "PIE", # flake8-pie (misc lints) + "TID", # flake8-tidy-imports (ban relative imports) + "UP006", # List[A] -> list[A] + "UP007", # Union[A, B] -> A | B + "UP045", # Optional[A] -> A | None ] ignore = [ "ANN401", # Dynamically typed expressions (Any) diff --git a/src/data_designer/config/analysis/column_profilers.py b/src/data_designer/config/analysis/column_profilers.py index bf5437f8..a1d1d165 100644 --- a/src/data_designer/config/analysis/column_profilers.py +++ b/src/data_designer/config/analysis/column_profilers.py @@ -3,7 +3,6 @@ from abc import ABC from enum import Enum -from typing import Optional, Union from pydantic import BaseModel, Field from rich.panel import Panel @@ -61,7 +60,7 @@ class JudgeScoreProfilerConfig(ConfigBase): """ model_alias: str - summary_score_sample_size: Optional[int] = Field(default=20, ge=1) + summary_score_sample_size: int | None = Field(default=20, ge=1) class JudgeScoreSample(BaseModel): @@ -75,7 +74,7 @@ class JudgeScoreSample(BaseModel): reasoning: The reasoning or explanation provided by the judge for this score. """ - score: Union[int, str] + score: int | str reasoning: str @@ -94,11 +93,11 @@ class JudgeScoreDistributions(BaseModel): histograms: Mapping of each score dimension name to its histogram data. """ - scores: dict[str, list[Union[int, str]]] + scores: dict[str, list[int | str]] reasoning: dict[str, list[str]] distribution_types: dict[str, ColumnDistributionType] - distributions: dict[str, Union[CategoricalDistribution, NumericalDistribution, MissingValue]] - histograms: dict[str, Union[CategoricalHistogramData, MissingValue]] + distributions: dict[str, CategoricalDistribution | NumericalDistribution | MissingValue] + histograms: dict[str, CategoricalHistogramData | MissingValue] class JudgeScoreSummary(BaseModel): @@ -132,7 +131,7 @@ class JudgeScoreProfilerResults(ColumnProfilerResults): column_name: str summaries: dict[str, JudgeScoreSummary] - score_distributions: Union[JudgeScoreDistributions, MissingValue] + score_distributions: JudgeScoreDistributions | MissingValue def create_report_section(self) -> Panel: layout = Table.grid(Column(), expand=True, padding=(2, 0)) diff --git a/src/data_designer/config/analysis/column_statistics.py b/src/data_designer/config/analysis/column_statistics.py index c0aa46b6..cdb88696 100644 --- a/src/data_designer/config/analysis/column_statistics.py +++ b/src/data_designer/config/analysis/column_statistics.py @@ -5,7 +5,7 @@ from abc import ABC, abstractmethod from enum import Enum -from typing import Any, Literal, Optional, Union +from typing import Any, Literal from pandas import Series from pydantic import BaseModel, ConfigDict, create_model, field_validator, model_validator @@ -69,19 +69,19 @@ class GeneralColumnStatistics(BaseColumnStatistics): """ column_name: str - num_records: Union[int, MissingValue] - num_null: Union[int, MissingValue] - num_unique: Union[int, MissingValue] + num_records: int | MissingValue + num_null: int | MissingValue + num_unique: int | MissingValue pyarrow_dtype: str simple_dtype: str column_type: Literal["general"] = "general" @field_validator("num_null", "num_unique", "num_records", mode="before") - def general_statistics_ensure_python_integers(cls, v: Union[int, MissingValue]) -> Union[int, MissingValue]: + def general_statistics_ensure_python_integers(cls, v: int | MissingValue) -> int | MissingValue: return v if isinstance(v, MissingValue) else prepare_number_for_reporting(v, int) @property - def percent_null(self) -> Union[float, MissingValue]: + def percent_null(self) -> float | MissingValue: return ( self.num_null if self._is_missing_value(self.num_null) @@ -89,7 +89,7 @@ def percent_null(self) -> Union[float, MissingValue]: ) @property - def percent_unique(self) -> Union[float, MissingValue]: + def percent_unique(self) -> float | MissingValue: return ( self.num_unique if self._is_missing_value(self.num_unique) @@ -108,7 +108,7 @@ def _general_display_row(self) -> dict[str, str]: def create_report_row_data(self) -> dict[str, str]: return self._general_display_row - def _is_missing_value(self, v: Union[float, int, MissingValue]) -> bool: + def _is_missing_value(self, v: float | int | MissingValue) -> bool: return v in set(MissingValue) @@ -128,12 +128,12 @@ class LLMTextColumnStatistics(GeneralColumnStatistics): column_type: Discriminator field, always "llm-text" for this statistics type. """ - output_tokens_mean: Union[float, MissingValue] - output_tokens_median: Union[float, MissingValue] - output_tokens_stddev: Union[float, MissingValue] - input_tokens_mean: Union[float, MissingValue] - input_tokens_median: Union[float, MissingValue] - input_tokens_stddev: Union[float, MissingValue] + output_tokens_mean: float | MissingValue + output_tokens_median: float | MissingValue + output_tokens_stddev: float | MissingValue + input_tokens_mean: float | MissingValue + input_tokens_median: float | MissingValue + input_tokens_stddev: float | MissingValue column_type: Literal[DataDesignerColumnType.LLM_TEXT.value] = DataDesignerColumnType.LLM_TEXT.value @field_validator( @@ -145,7 +145,7 @@ class LLMTextColumnStatistics(GeneralColumnStatistics): "input_tokens_stddev", mode="before", ) - def llm_column_ensure_python_floats(cls, v: Union[float, int, MissingValue]) -> Union[float, int, MissingValue]: + def llm_column_ensure_python_floats(cls, v: float | int | MissingValue) -> float | int | MissingValue: return v if isinstance(v, MissingValue) else prepare_number_for_reporting(v, float) def create_report_row_data(self) -> dict[str, Any]: @@ -225,7 +225,7 @@ class SamplerColumnStatistics(GeneralColumnStatistics): sampler_type: SamplerType distribution_type: ColumnDistributionType - distribution: Optional[Union[CategoricalDistribution, NumericalDistribution, MissingValue]] + distribution: CategoricalDistribution | NumericalDistribution | MissingValue | None column_type: Literal[DataDesignerColumnType.SAMPLER.value] = DataDesignerColumnType.SAMPLER.value def create_report_row_data(self) -> dict[str, str]: @@ -273,15 +273,15 @@ class ValidationColumnStatistics(GeneralColumnStatistics): column_type: Discriminator field, always "validation" for this statistics type. """ - num_valid_records: Union[int, MissingValue] + num_valid_records: int | MissingValue column_type: Literal[DataDesignerColumnType.VALIDATION.value] = DataDesignerColumnType.VALIDATION.value @field_validator("num_valid_records", mode="before") - def code_validation_column_ensure_python_integers(cls, v: Union[int, MissingValue]) -> Union[int, MissingValue]: + def code_validation_column_ensure_python_integers(cls, v: int | MissingValue) -> int | MissingValue: return v if isinstance(v, MissingValue) else prepare_number_for_reporting(v, int) @property - def percent_valid(self) -> Union[float, MissingValue]: + def percent_valid(self) -> float | MissingValue: return ( self.num_valid_records if self._is_missing_value(self.num_valid_records) @@ -303,7 +303,7 @@ class CategoricalHistogramData(BaseModel): counts: List of occurrence counts for each category. """ - categories: list[Union[float, int, str]] + categories: list[float | int | str] counts: list[int] @model_validator(mode="after") @@ -328,12 +328,12 @@ class CategoricalDistribution(BaseModel): histogram: Complete frequency distribution showing all categories and their counts. """ - most_common_value: Union[str, int] - least_common_value: Union[str, int] + most_common_value: str | int + least_common_value: str | int histogram: CategoricalHistogramData @field_validator("most_common_value", "least_common_value", mode="before") - def ensure_python_types(cls, v: Union[str, int]) -> Union[str, int]: + def ensure_python_types(cls, v: str | int) -> str | int: return str(v) if not is_int(v) else prepare_number_for_reporting(v, int) @classmethod @@ -357,14 +357,14 @@ class NumericalDistribution(BaseModel): median: Median value of the distribution. """ - min: Union[float, int] - max: Union[float, int] + min: float | int + max: float | int mean: float stddev: float median: float @field_validator("min", "max", "mean", "stddev", "median", mode="before") - def ensure_python_types(cls, v: Union[float, int]) -> Union[float, int]: + def ensure_python_types(cls, v: float | int) -> float | int: return prepare_number_for_reporting(v, int if is_int(v) else float) @classmethod @@ -378,17 +378,17 @@ def from_series(cls, series: Series) -> Self: ) -ColumnStatisticsT: TypeAlias = Union[ - GeneralColumnStatistics, - LLMTextColumnStatistics, - LLMCodeColumnStatistics, - LLMStructuredColumnStatistics, - LLMJudgedColumnStatistics, - SamplerColumnStatistics, - SeedDatasetColumnStatistics, - ValidationColumnStatistics, - ExpressionColumnStatistics, -] +ColumnStatisticsT: TypeAlias = ( + GeneralColumnStatistics + | LLMTextColumnStatistics + | LLMCodeColumnStatistics + | LLMStructuredColumnStatistics + | LLMJudgedColumnStatistics + | SamplerColumnStatistics + | SeedDatasetColumnStatistics + | ValidationColumnStatistics + | ExpressionColumnStatistics +) DEFAULT_COLUMN_STATISTICS_MAP = { diff --git a/src/data_designer/config/analysis/dataset_profiler.py b/src/data_designer/config/analysis/dataset_profiler.py index f0976293..704fa48d 100644 --- a/src/data_designer/config/analysis/dataset_profiler.py +++ b/src/data_designer/config/analysis/dataset_profiler.py @@ -3,7 +3,7 @@ from functools import cached_property from pathlib import Path -from typing import Annotated, Optional, Union +from typing import Annotated from pydantic import BaseModel, Field, field_validator @@ -34,8 +34,8 @@ class DatasetProfilerResults(BaseModel): num_records: int target_num_records: int column_statistics: list[Annotated[ColumnStatisticsT, Field(discriminator="column_type")]] = Field(..., min_length=1) - side_effect_column_names: Optional[list[str]] = None - column_profiles: Optional[list[ColumnProfilerResultsT]] = None + side_effect_column_names: list[str] | None = None + column_profiles: list[ColumnProfilerResultsT] | None = None @field_validator("num_records", "target_num_records", mode="before") def ensure_python_integers(cls, v: int) -> int: @@ -61,8 +61,8 @@ def get_column_statistics_by_type(self, column_type: DataDesignerColumnType) -> def to_report( self, - save_path: Optional[Union[str, Path]] = None, - include_sections: Optional[list[Union[ReportSection, DataDesignerColumnType]]] = None, + save_path: str | Path | None = None, + include_sections: list[ReportSection | DataDesignerColumnType] | None = None, ) -> None: """Generate and print an analysis report based on the dataset profiling results. diff --git a/src/data_designer/config/analysis/utils/reporting.py b/src/data_designer/config/analysis/utils/reporting.py index 72899faf..0ad48c06 100644 --- a/src/data_designer/config/analysis/utils/reporting.py +++ b/src/data_designer/config/analysis/utils/reporting.py @@ -5,7 +5,7 @@ from enum import Enum from pathlib import Path -from typing import TYPE_CHECKING, Optional, Union +from typing import TYPE_CHECKING from rich.align import Align from rich.console import Console, Group @@ -48,8 +48,8 @@ class ReportSection(str, Enum): def generate_analysis_report( analysis: DatasetProfilerResults, - save_path: Optional[Union[str, Path]] = None, - include_sections: Optional[list[Union[ReportSection, DataDesignerColumnType]]] = None, + save_path: str | Path | None = None, + include_sections: list[ReportSection | DataDesignerColumnType] | None = None, ) -> None: """Generate an analysis report for dataset profiling results. diff --git a/src/data_designer/config/base.py b/src/data_designer/config/base.py index 0f3bbf5f..365e9428 100644 --- a/src/data_designer/config/base.py +++ b/src/data_designer/config/base.py @@ -4,7 +4,7 @@ from __future__ import annotations from pathlib import Path -from typing import Any, Optional, Union +from typing import Any import yaml from pydantic import BaseModel, ConfigDict @@ -31,7 +31,7 @@ def to_dict(self) -> dict[str, Any]: """ return self.model_dump(mode="json") - def to_yaml(self, path: Optional[Union[str, Path]] = None, *, indent: Optional[int] = 2, **kwargs) -> Optional[str]: + def to_yaml(self, path: str | Path | None = None, *, indent: int | None = 2, **kwargs) -> str | None: """Convert the configuration to a YAML string or file. Args: @@ -49,7 +49,7 @@ def to_yaml(self, path: Optional[Union[str, Path]] = None, *, indent: Optional[i with open(path, "w") as f: f.write(yaml_str) - def to_json(self, path: Optional[Union[str, Path]] = None, *, indent: Optional[int] = 2, **kwargs) -> Optional[str]: + def to_json(self, path: str | Path | None = None, *, indent: int | None = 2, **kwargs) -> str | None: """Convert the configuration to a JSON string or file. Args: diff --git a/src/data_designer/config/column_configs.py b/src/data_designer/config/column_configs.py index 48fe529e..f38931c9 100644 --- a/src/data_designer/config/column_configs.py +++ b/src/data_designer/config/column_configs.py @@ -2,7 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 from abc import ABC -from typing import Annotated, Literal, Optional, Type, Union +from typing import Annotated, Literal from pydantic import BaseModel, Discriminator, Field, model_validator from typing_extensions import Self @@ -91,7 +91,7 @@ class SamplerColumnConfig(SingleColumnConfig): sampler_type: SamplerType params: Annotated[SamplerParamsT, Discriminator("sampler_type")] conditional_params: dict[str, Annotated[SamplerParamsT, Discriminator("sampler_type")]] = {} - convert_to: Optional[str] = None + convert_to: str | None = None column_type: Literal["sampler"] = "sampler" @model_validator(mode="before") @@ -146,8 +146,8 @@ class LLMTextColumnConfig(SingleColumnConfig): prompt: str model_alias: str - system_prompt: Optional[str] = None - multi_modal_context: Optional[list[ImageContext]] = None + system_prompt: str | None = None + multi_modal_context: list[ImageContext] | None = None column_type: Literal["llm-text"] = "llm-text" @property @@ -222,7 +222,7 @@ class LLMStructuredColumnConfig(LLMTextColumnConfig): column_type: Discriminator field, always "llm-structured" for this configuration type. """ - output_format: Union[dict, Type[BaseModel]] + output_format: dict | type[BaseModel] column_type: Literal["llm-structured"] = "llm-structured" @model_validator(mode="after") @@ -255,7 +255,7 @@ class Score(ConfigBase): name: str = Field(..., description="A clear name for this score.") description: str = Field(..., description="An informative and detailed assessment guide for using this score.") - options: dict[Union[int, str], str] = Field(..., description="Score options in the format of {score: description}.") + options: dict[int | str, str] = Field(..., description="Score options in the format of {score: description}.") class LLMJudgeColumnConfig(LLMTextColumnConfig): diff --git a/src/data_designer/config/column_types.py b/src/data_designer/config/column_types.py index cbfce4f7..7583d4ac 100644 --- a/src/data_designer/config/column_types.py +++ b/src/data_designer/config/column_types.py @@ -1,7 +1,6 @@ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 -from typing import Union from typing_extensions import TypeAlias @@ -27,17 +26,17 @@ plugin_manager = PluginManager() -ColumnConfigT: TypeAlias = Union[ - ExpressionColumnConfig, - LLMCodeColumnConfig, - LLMJudgeColumnConfig, - LLMStructuredColumnConfig, - LLMTextColumnConfig, - SamplerColumnConfig, - SeedDatasetColumnConfig, - ValidationColumnConfig, - EmbeddingColumnConfig, -] +ColumnConfigT: TypeAlias = ( + ExpressionColumnConfig + | LLMCodeColumnConfig + | LLMJudgeColumnConfig + | LLMStructuredColumnConfig + | LLMTextColumnConfig + | SamplerColumnConfig + | SeedDatasetColumnConfig + | ValidationColumnConfig + | EmbeddingColumnConfig +) ColumnConfigT = plugin_manager.inject_into_column_config_type_union(ColumnConfigT) DataDesignerColumnType = create_str_enum_from_discriminated_type_union( @@ -63,7 +62,7 @@ ) -def column_type_used_in_execution_dag(column_type: Union[str, DataDesignerColumnType]) -> bool: +def column_type_used_in_execution_dag(column_type: str | DataDesignerColumnType) -> bool: """Return True if the column type is used in the workflow execution DAG.""" column_type = resolve_string_enum(column_type, DataDesignerColumnType) dag_column_types = { @@ -79,7 +78,7 @@ def column_type_used_in_execution_dag(column_type: Union[str, DataDesignerColumn return column_type in dag_column_types -def column_type_is_model_generated(column_type: Union[str, DataDesignerColumnType]) -> bool: +def column_type_is_model_generated(column_type: str | DataDesignerColumnType) -> bool: """Return True if the column type is a model-generated column.""" column_type = resolve_string_enum(column_type, DataDesignerColumnType) model_generated_column_types = { diff --git a/src/data_designer/config/config_builder.py b/src/data_designer/config/config_builder.py index db382f88..1dff587c 100644 --- a/src/data_designer/config/config_builder.py +++ b/src/data_designer/config/config_builder.py @@ -6,7 +6,6 @@ import json import logging from pathlib import Path -from typing import Optional, Union from pygments import highlight from pygments.formatters import HtmlFormatter @@ -69,7 +68,7 @@ class BuilderConfig(ExportableConfigBase): """ data_designer: DataDesignerConfig - datastore_settings: Optional[DatastoreSettings] + datastore_settings: DatastoreSettings | None class DataDesignerConfigBuilder: @@ -79,7 +78,7 @@ class DataDesignerConfigBuilder: """ @classmethod - def from_config(cls, config: Union[dict, str, Path, BuilderConfig]) -> Self: + def from_config(cls, config: dict | str | Path | BuilderConfig) -> Self: """Create a DataDesignerConfigBuilder from an existing configuration. Args: @@ -130,7 +129,7 @@ def from_config(cls, config: Union[dict, str, Path, BuilderConfig]) -> Self: return builder - def __init__(self, model_configs: Optional[Union[list[ModelConfig], str, Path]] = None): + def __init__(self, model_configs: list[ModelConfig] | str | Path | None = None): """Initialize a new DataDesignerConfigBuilder instance. Args: @@ -142,10 +141,10 @@ def __init__(self, model_configs: Optional[Union[list[ModelConfig], str, Path]] self._column_configs = {} self._model_configs = _load_model_configs(model_configs) self._processor_configs: list[ProcessorConfigT] = [] - self._seed_config: Optional[SeedConfig] = None + self._seed_config: SeedConfig | None = None self._constraints: list[ColumnConstraintT] = [] self._profilers: list[ColumnProfilerConfigT] = [] - self._datastore_settings: Optional[DatastoreSettings] = None + self._datastore_settings: DatastoreSettings | None = None @property def model_configs(self) -> list[ModelConfig]: @@ -206,10 +205,10 @@ def delete_model_config(self, alias: str) -> Self: def add_column( self, - column_config: Optional[ColumnConfigT] = None, + column_config: ColumnConfigT | None = None, *, - name: Optional[str] = None, - column_type: Optional[DataDesignerColumnType] = None, + name: str | None = None, + column_type: DataDesignerColumnType | None = None, **kwargs, ) -> Self: """Add a Data Designer column configuration to the current Data Designer configuration. @@ -246,9 +245,9 @@ def add_column( def add_constraint( self, - constraint: Optional[ColumnConstraintT] = None, + constraint: ColumnConstraintT | None = None, *, - constraint_type: Optional[ConstraintType] = None, + constraint_type: ConstraintType | None = None, **kwargs, ) -> Self: """Add a constraint to the current Data Designer configuration. @@ -298,9 +297,9 @@ def add_constraint( def add_processor( self, - processor_config: Optional[ProcessorConfigT] = None, + processor_config: ProcessorConfigT | None = None, *, - processor_type: Optional[ProcessorType] = None, + processor_type: ProcessorType | None = None, **kwargs, ) -> Self: """Add a processor to the current Data Designer configuration. @@ -495,7 +494,7 @@ def get_processor_configs(self) -> dict[BuildStage, list[ProcessorConfigT]]: """ return self._processor_configs - def get_seed_config(self) -> Optional[SeedConfig]: + def get_seed_config(self) -> SeedConfig | None: """Get the seed config for the current Data Designer configuration. Returns: @@ -503,7 +502,7 @@ def get_seed_config(self) -> Optional[SeedConfig]: """ return self._seed_config - def get_seed_datastore_settings(self) -> Optional[DatastoreSettings]: + def get_seed_datastore_settings(self) -> DatastoreSettings | None: """Get most recent datastore settings for the current Data Designer configuration. Returns: @@ -522,7 +521,7 @@ def num_columns_of_type(self, column_type: DataDesignerColumnType) -> int: """ return len(self.get_columns_of_type(column_type)) - def set_seed_datastore_settings(self, datastore_settings: Optional[DatastoreSettings]) -> Self: + def set_seed_datastore_settings(self, datastore_settings: DatastoreSettings | None) -> Self: """Set the datastore settings for the seed dataset. Args: @@ -563,7 +562,7 @@ def with_seed_dataset( dataset_reference: SeedDatasetReference, *, sampling_strategy: SamplingStrategy = SamplingStrategy.ORDERED, - selection_strategy: Optional[Union[IndexRange, PartitionBlock]] = None, + selection_strategy: IndexRange | PartitionBlock | None = None, ) -> Self: """Add a seed dataset to the current Data Designer configuration. @@ -591,7 +590,7 @@ def with_seed_dataset( self._column_configs[column_name] = SeedDatasetColumnConfig(name=column_name) return self - def write_config(self, path: Union[str, Path], indent: Optional[int] = 2, **kwargs) -> None: + def write_config(self, path: str | Path, indent: int | None = 2, **kwargs) -> None: """Write the current configuration to a file. Args: @@ -662,7 +661,7 @@ def _repr_html_(self) -> str: return REPR_HTML_TEMPLATE.format(css=css, highlighted_html=highlighted_html) -def _load_model_configs(model_configs: Optional[Union[list[ModelConfig], str, Path]] = None) -> list[ModelConfig]: +def _load_model_configs(model_configs: list[ModelConfig] | str | Path | None = None) -> list[ModelConfig]: """Resolves the provided model_configs, which may be a string or Path to a model configuration file. If None or empty, returns default model configurations if possible, otherwise raises an error. """ diff --git a/src/data_designer/config/data_designer_config.py b/src/data_designer/config/data_designer_config.py index d90deb41..f3c0125b 100644 --- a/src/data_designer/config/data_designer_config.py +++ b/src/data_designer/config/data_designer_config.py @@ -3,7 +3,7 @@ from __future__ import annotations -from typing import Annotated, Optional +from typing import Annotated from pydantic import Field @@ -33,8 +33,8 @@ class DataDesignerConfig(ExportableConfigBase): """ columns: list[Annotated[ColumnConfigT, Field(discriminator="column_type")]] = Field(min_length=1) - model_configs: Optional[list[ModelConfig]] = None - seed_config: Optional[SeedConfig] = None - constraints: Optional[list[ColumnConstraintT]] = None - profilers: Optional[list[ColumnProfilerConfigT]] = None - processors: Optional[list[Annotated[ProcessorConfigT, Field(discriminator="processor_type")]]] = None + model_configs: list[ModelConfig] | None = None + seed_config: SeedConfig | None = None + constraints: list[ColumnConstraintT] | None = None + profilers: list[ColumnProfilerConfigT] | None = None + processors: list[Annotated[ProcessorConfigT, Field(discriminator="processor_type")]] | None = None diff --git a/src/data_designer/config/datastore.py b/src/data_designer/config/datastore.py index f29b0292..ab78bae4 100644 --- a/src/data_designer/config/datastore.py +++ b/src/data_designer/config/datastore.py @@ -5,7 +5,7 @@ import logging from pathlib import Path -from typing import TYPE_CHECKING, Optional, Union +from typing import TYPE_CHECKING import pandas as pd import pyarrow.parquet as pq @@ -28,10 +28,10 @@ class DatastoreSettings(BaseModel): ..., description="Datastore endpoint. Use 'https://huggingface.co' for the Hugging Face Hub.", ) - token: Optional[str] = Field(default=None, description="If needed, token to use for authentication.") + token: str | None = Field(default=None, description="If needed, token to use for authentication.") -def get_file_column_names(file_reference: Union[str, Path, HfFileSystem], file_type: str) -> list[str]: +def get_file_column_names(file_reference: str | Path | HfFileSystem, file_type: str) -> list[str]: """Get column names from a dataset file. Args: @@ -80,7 +80,7 @@ def fetch_seed_dataset_column_names(seed_dataset_reference: SeedDatasetReference def fetch_seed_dataset_column_names_from_datastore( repo_id: str, filename: str, - datastore_settings: Optional[Union[DatastoreSettings, dict]] = None, + datastore_settings: DatastoreSettings | dict | None = None, ) -> list[str]: file_type = filename.split(".")[-1] if f".{file_type}" not in VALID_DATASET_FILE_EXTENSIONS: @@ -115,7 +115,7 @@ def resolve_datastore_settings(datastore_settings: DatastoreSettings | dict | No def upload_to_hf_hub( - dataset_path: Union[str, Path], + dataset_path: str | Path, filename: str, repo_id: str, datastore_settings: DatastoreSettings, @@ -171,7 +171,7 @@ def _extract_single_file_path_from_glob_pattern_if_present( return matching_files[0] -def _validate_dataset_path(dataset_path: Union[str, Path], allow_glob_pattern: bool = False) -> Path: +def _validate_dataset_path(dataset_path: str | Path, allow_glob_pattern: bool = False) -> Path: if allow_glob_pattern and "*" in str(dataset_path): parts = str(dataset_path).split("*.") file_path = parts[0] diff --git a/src/data_designer/config/default_model_settings.py b/src/data_designer/config/default_model_settings.py index c86c7f80..fd71da68 100644 --- a/src/data_designer/config/default_model_settings.py +++ b/src/data_designer/config/default_model_settings.py @@ -6,7 +6,7 @@ import os from functools import lru_cache from pathlib import Path -from typing import Any, Literal, Optional +from typing import Any, Literal from data_designer.config.models import ( ChatCompletionInferenceParams, @@ -85,7 +85,7 @@ def get_default_providers() -> list[ModelProvider]: return [] -def get_default_provider_name() -> Optional[str]: +def get_default_provider_name() -> str | None: return _get_default_providers_file_content(MODEL_PROVIDERS_FILE_PATH).get("default") diff --git a/src/data_designer/config/models.py b/src/data_designer/config/models.py index bb08fb58..24aadc08 100644 --- a/src/data_designer/config/models.py +++ b/src/data_designer/config/models.py @@ -5,7 +5,7 @@ from abc import ABC, abstractmethod from enum import Enum from pathlib import Path -from typing import Any, Generic, List, Literal, Optional, TypeVar, Union +from typing import Any, Generic, Literal, TypeVar import numpy as np from pydantic import BaseModel, Field, field_validator, model_validator @@ -74,7 +74,7 @@ class ImageContext(ModalityContext): """ modality: Modality = Modality.IMAGE - image_format: Optional[ImageFormat] = None + image_format: ImageFormat | None = None def get_context(self, record: dict) -> dict[str, Any]: """Get the context for the image modality. @@ -122,8 +122,8 @@ class ManualDistributionParams(ConfigBase): weights: Optional list of weights for each value. If not provided, all values have equal probability. """ - values: List[float] = Field(min_length=1) - weights: Optional[List[float]] = None + values: list[float] = Field(min_length=1) + weights: list[float] | None = None @model_validator(mode="after") def _normalize_weights(self) -> Self: @@ -149,7 +149,7 @@ class ManualDistribution(Distribution[ManualDistributionParams]): params: Distribution parameters (values, weights). """ - distribution_type: Optional[DistributionType] = "manual" + distribution_type: DistributionType | None = "manual" params: ManualDistributionParams def sample(self) -> float: @@ -190,7 +190,7 @@ class UniformDistribution(Distribution[UniformDistributionParams]): params: Distribution parameters (low, high). """ - distribution_type: Optional[DistributionType] = "uniform" + distribution_type: DistributionType | None = "uniform" params: UniformDistributionParams def sample(self) -> float: @@ -202,7 +202,7 @@ def sample(self) -> float: return float(np.random.uniform(low=self.params.low, high=self.params.high, size=1)[0]) -DistributionT: TypeAlias = Union[UniformDistribution, ManualDistribution] +DistributionT: TypeAlias = UniformDistribution | ManualDistribution class GenerationType(str, Enum): @@ -222,8 +222,8 @@ class BaseInferenceParams(ConfigBase, ABC): generation_type: GenerationType max_parallel_requests: int = Field(default=4, ge=1) - timeout: Optional[int] = Field(default=None, ge=1) - extra_body: Optional[dict[str, Any]] = None + timeout: int | None = Field(default=None, ge=1) + extra_body: dict[str, Any] | None = None @property def generate_kwargs(self) -> dict[str, Any]: @@ -282,9 +282,9 @@ class ChatCompletionInferenceParams(BaseInferenceParams): """ generation_type: Literal[GenerationType.CHAT_COMPLETION] = GenerationType.CHAT_COMPLETION - temperature: Optional[Union[float, DistributionT]] = None - top_p: Optional[Union[float, DistributionT]] = None - max_tokens: Optional[int] = Field(default=None, ge=1) + temperature: float | DistributionT | None = None + top_p: float | DistributionT | None = None + max_tokens: int | None = Field(default=None, ge=1) @property def generate_kwargs(self) -> dict[str, Any]: @@ -319,7 +319,7 @@ def _validate_top_p(self) -> Self: def _run_validation( self, - value: Union[float, DistributionT, None], + value: float | DistributionT | None, param_name: str, min_value: float, max_value: float, @@ -383,10 +383,10 @@ class EmbeddingInferenceParams(BaseInferenceParams): generation_type: Literal[GenerationType.EMBEDDING] = GenerationType.EMBEDDING encoding_format: Literal["float", "base64"] = "float" - dimensions: Optional[int] = None + dimensions: int | None = None @property - def generate_kwargs(self) -> dict[str, Union[float, int]]: + def generate_kwargs(self) -> dict[str, float | int]: result = super().generate_kwargs if self.encoding_format is not None: result["encoding_format"] = self.encoding_format @@ -395,7 +395,7 @@ def generate_kwargs(self) -> dict[str, Union[float, int]]: return result -InferenceParamsT: TypeAlias = Union[ChatCompletionInferenceParams, EmbeddingInferenceParams, InferenceParameters] +InferenceParamsT: TypeAlias = ChatCompletionInferenceParams | EmbeddingInferenceParams | InferenceParameters class ModelConfig(ConfigBase): @@ -412,7 +412,7 @@ class ModelConfig(ConfigBase): alias: str model: str inference_parameters: InferenceParamsT = Field(default_factory=ChatCompletionInferenceParams) - provider: Optional[str] = None + provider: str | None = None @property def generation_type(self) -> GenerationType: @@ -446,11 +446,11 @@ class ModelProvider(ConfigBase): name: str endpoint: str provider_type: str = "openai" - api_key: Optional[str] = None - extra_body: Optional[dict[str, Any]] = None + api_key: str | None = None + extra_body: dict[str, Any] | None = None -def load_model_configs(model_configs: Union[list[ModelConfig], str, Path]) -> list[ModelConfig]: +def load_model_configs(model_configs: list[ModelConfig] | str | Path) -> list[ModelConfig]: if isinstance(model_configs, list) and all(isinstance(mc, ModelConfig) for mc in model_configs): return model_configs json_config = smart_load_yaml(model_configs) diff --git a/src/data_designer/config/preview_results.py b/src/data_designer/config/preview_results.py index ba983b5f..e132a84a 100644 --- a/src/data_designer/config/preview_results.py +++ b/src/data_designer/config/preview_results.py @@ -3,8 +3,6 @@ from __future__ import annotations -from typing import Optional, Union - import pandas as pd from data_designer.config.analysis.dataset_profiler import DatasetProfilerResults @@ -17,9 +15,9 @@ def __init__( self, *, config_builder: DataDesignerConfigBuilder, - dataset: Optional[pd.DataFrame] = None, - analysis: Optional[DatasetProfilerResults] = None, - processor_artifacts: Optional[dict[str, Union[list[str], str]]] = None, + dataset: pd.DataFrame | None = None, + analysis: DatasetProfilerResults | None = None, + processor_artifacts: dict[str, list[str] | str] | None = None, ): """Creates a new instance with results from a Data Designer preview run. @@ -29,7 +27,7 @@ def __init__( analysis: Analysis of the preview run. processor_artifacts: Artifacts generated by the processors. """ - self.dataset: Optional[pd.DataFrame] = dataset - self.analysis: Optional[DatasetProfilerResults] = analysis - self.processor_artifacts: Optional[dict[str, Union[list[str], str]]] = processor_artifacts + self.dataset: pd.DataFrame | None = dataset + self.analysis: DatasetProfilerResults | None = analysis + self.processor_artifacts: dict[str, list[str] | str] | None = processor_artifacts self._config_builder = config_builder diff --git a/src/data_designer/config/processors.py b/src/data_designer/config/processors.py index 17d2ff7b..cb45d94d 100644 --- a/src/data_designer/config/processors.py +++ b/src/data_designer/config/processors.py @@ -4,7 +4,7 @@ import json from abc import ABC from enum import Enum -from typing import Any, Literal, Union +from typing import Any, Literal from pydantic import Field, field_validator from typing_extensions import TypeAlias @@ -143,7 +143,4 @@ def validate_template(cls, v: dict[str, Any]) -> dict[str, Any]: return v -ProcessorConfigT: TypeAlias = Union[ - DropColumnsProcessorConfig, - SchemaTransformProcessorConfig, -] +ProcessorConfigT: TypeAlias = DropColumnsProcessorConfig | SchemaTransformProcessorConfig diff --git a/src/data_designer/config/sampler_constraints.py b/src/data_designer/config/sampler_constraints.py index e6ea65c0..fb048293 100644 --- a/src/data_designer/config/sampler_constraints.py +++ b/src/data_designer/config/sampler_constraints.py @@ -3,7 +3,6 @@ from abc import ABC, abstractmethod from enum import Enum -from typing import Union from typing_extensions import TypeAlias @@ -48,4 +47,4 @@ def constraint_type(self) -> ConstraintType: return ConstraintType.COLUMN_INEQUALITY -ColumnConstraintT: TypeAlias = Union[ScalarInequalityConstraint, ColumnInequalityConstraint] +ColumnConstraintT: TypeAlias = ScalarInequalityConstraint | ColumnInequalityConstraint diff --git a/src/data_designer/config/sampler_params.py b/src/data_designer/config/sampler_params.py index 51ba3058..21ecdd9e 100644 --- a/src/data_designer/config/sampler_params.py +++ b/src/data_designer/config/sampler_params.py @@ -2,7 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 from enum import Enum -from typing import Literal, Optional, Union +from typing import Literal import pandas as pd from pydantic import Field, field_validator, model_validator @@ -54,12 +54,12 @@ class CategorySamplerParams(ConfigBase): Larger weights result in higher sampling probability for the corresponding value. """ - values: list[Union[str, int, float]] = Field( + values: list[str | int | float] = Field( ..., min_length=1, description="List of possible categorical values that can be sampled from.", ) - weights: Optional[list[float]] = Field( + weights: list[float] | None = Field( default=None, description=( "List of unnormalized probability weights to assigned to each value, in order. " @@ -134,7 +134,7 @@ class SubcategorySamplerParams(ConfigBase): """ category: str = Field(..., description="Name of parent category to this subcategory.") - values: dict[str, list[Union[str, int, float]]] = Field( + values: dict[str, list[str | int | float]] = Field( ..., description="Mapping from each value of parent category to a list of subcategory values.", ) @@ -214,7 +214,7 @@ class UUIDSamplerParams(ConfigBase): lowercase UUIDs. """ - prefix: Optional[str] = Field(default=None, description="String prepended to the front of the UUID.") + prefix: str | None = Field(default=None, description="String prepended to the front of the UUID.") short_form: bool = Field( default=False, description="If true, all UUIDs sampled will be truncated at 8 characters.", @@ -259,7 +259,7 @@ class ScipySamplerParams(ConfigBase): ..., description="Parameters of the scipy.stats distribution given in `dist_name`.", ) - decimal_places: Optional[int] = Field( + decimal_places: int | None = Field( default=None, description="Number of decimal places to round the sampled values to." ) sampler_type: Literal[SamplerType.SCIPY] = SamplerType.SCIPY @@ -356,7 +356,7 @@ class GaussianSamplerParams(ConfigBase): mean: float = Field(..., description="Mean of the Gaussian distribution") stddev: float = Field(..., description="Standard deviation of the Gaussian distribution") - decimal_places: Optional[int] = Field( + decimal_places: int | None = Field( default=None, description="Number of decimal places to round the sampled values to." ) sampler_type: Literal[SamplerType.GAUSSIAN] = SamplerType.GAUSSIAN @@ -398,7 +398,7 @@ class UniformSamplerParams(ConfigBase): low: float = Field(..., description="Lower bound of the uniform distribution, inclusive.") high: float = Field(..., description="Upper bound of the uniform distribution, inclusive.") - decimal_places: Optional[int] = Field( + decimal_places: int | None = Field( default=None, description="Number of decimal places to round the sampled values to." ) sampler_type: Literal[SamplerType.UNIFORM] = SamplerType.UNIFORM @@ -447,11 +447,11 @@ class PersonSamplerParams(ConfigBase): f"{', '.join(LOCALES_WITH_MANAGED_DATASETS)}." ), ) - sex: Optional[SexT] = Field( + sex: SexT | None = Field( default=None, description="If specified, then only synthetic people of the specified sex will be sampled.", ) - city: Optional[Union[str, list[str]]] = Field( + city: str | list[str] | None = Field( default=None, description="If specified, then only synthetic people from these cities will be sampled.", ) @@ -461,7 +461,7 @@ class PersonSamplerParams(ConfigBase): min_length=2, max_length=2, ) - select_field_values: Optional[dict[str, list[str]]] = Field( + select_field_values: dict[str, list[str]] | None = Field( default=None, description=( "Sample synthetic people with the specified field values. This is meant to be a flexible argument for " @@ -529,11 +529,11 @@ class PersonFromFakerSamplerParams(ConfigBase): "that a synthetic person will be sampled from. E.g, en_US, en_GB, fr_FR, ..." ), ) - sex: Optional[SexT] = Field( + sex: SexT | None = Field( default=None, description="If specified, then only synthetic people of the specified sex will be sampled.", ) - city: Optional[Union[str, list[str]]] = Field( + city: str | list[str] | None = Field( default=None, description="If specified, then only synthetic people from these cities will be sampled.", ) @@ -585,22 +585,22 @@ def _validate_locale(cls, value: str) -> str: return value -SamplerParamsT: TypeAlias = Union[ - SubcategorySamplerParams, - CategorySamplerParams, - DatetimeSamplerParams, - PersonSamplerParams, - PersonFromFakerSamplerParams, - TimeDeltaSamplerParams, - UUIDSamplerParams, - BernoulliSamplerParams, - BernoulliMixtureSamplerParams, - BinomialSamplerParams, - GaussianSamplerParams, - PoissonSamplerParams, - UniformSamplerParams, - ScipySamplerParams, -] +SamplerParamsT: TypeAlias = ( + SubcategorySamplerParams + | CategorySamplerParams + | DatetimeSamplerParams + | PersonSamplerParams + | PersonFromFakerSamplerParams + | TimeDeltaSamplerParams + | UUIDSamplerParams + | BernoulliSamplerParams + | BernoulliMixtureSamplerParams + | BinomialSamplerParams + | GaussianSamplerParams + | PoissonSamplerParams + | UniformSamplerParams + | ScipySamplerParams +) def is_numerical_sampler_type(sampler_type: SamplerType) -> bool: diff --git a/src/data_designer/config/seed.py b/src/data_designer/config/seed.py index 0012fb8f..a49f73e2 100644 --- a/src/data_designer/config/seed.py +++ b/src/data_designer/config/seed.py @@ -3,7 +3,6 @@ from abc import ABC from enum import Enum -from typing import Optional, Union from pydantic import Field, field_validator, model_validator from typing_extensions import Self @@ -112,7 +111,7 @@ class SeedConfig(ConfigBase): dataset: str sampling_strategy: SamplingStrategy = SamplingStrategy.ORDERED - selection_strategy: Optional[Union[IndexRange, PartitionBlock]] = None + selection_strategy: IndexRange | PartitionBlock | None = None class SeedDatasetReference(ABC, ConfigBase): diff --git a/src/data_designer/config/utils/code_lang.py b/src/data_designer/config/utils/code_lang.py index c0621d36..4f1af4c8 100644 --- a/src/data_designer/config/utils/code_lang.py +++ b/src/data_designer/config/utils/code_lang.py @@ -4,7 +4,6 @@ from __future__ import annotations from enum import Enum -from typing import Union class CodeLang(str, Enum): @@ -26,17 +25,17 @@ class CodeLang(str, Enum): SQL_ANSI = "sql:ansi" @staticmethod - def parse(value: Union[str, CodeLang]) -> tuple[str, Union[str, None]]: + def parse(value: str | CodeLang) -> tuple[str, str | None]: value = value.value if isinstance(value, CodeLang) else value split_vals = value.split(":") return (split_vals[0], split_vals[1] if len(split_vals) > 1 else None) @staticmethod - def parse_lang(value: Union[str, CodeLang]) -> str: + def parse_lang(value: str | CodeLang) -> str: return CodeLang.parse(value)[0] @staticmethod - def parse_dialect(value: Union[str, CodeLang]) -> Union[str, None]: + def parse_dialect(value: str | CodeLang) -> str | None: return CodeLang.parse(value)[1] @staticmethod @@ -58,7 +57,7 @@ def supported_values() -> set[str]: ########################################################## -def code_lang_to_syntax_lexer(code_lang: Union[CodeLang, str]) -> str: +def code_lang_to_syntax_lexer(code_lang: CodeLang | str) -> str: """Convert the code language to a syntax lexer for Pygments. Reference: https://pygments.org/docs/lexers/ diff --git a/src/data_designer/config/utils/io_helpers.py b/src/data_designer/config/utils/io_helpers.py index a1ade7cc..57a0c9c2 100644 --- a/src/data_designer/config/utils/io_helpers.py +++ b/src/data_designer/config/utils/io_helpers.py @@ -8,7 +8,7 @@ from decimal import Decimal from numbers import Number from pathlib import Path -from typing import Any, Union +from typing import Any import numpy as np import pandas as pd @@ -128,7 +128,7 @@ def write_seed_dataset(dataframe: pd.DataFrame, file_path: Path) -> None: dataframe.to_json(file_path, orient="records", lines=True) -def validate_dataset_file_path(file_path: Union[str, Path], should_exist: bool = True) -> Path: +def validate_dataset_file_path(file_path: str | Path, should_exist: bool = True) -> Path: """Validate that a dataset file path has a valid extension and optionally exists. Args: @@ -165,7 +165,7 @@ def validate_path_contains_files_of_type(path: str | Path, file_extension: str) raise InvalidFilePathError(f"🛑 Path {path!r} does not contain files of type {file_extension!r}.") -def smart_load_dataframe(dataframe: Union[str, Path, pd.DataFrame]) -> pd.DataFrame: +def smart_load_dataframe(dataframe: str | Path | pd.DataFrame) -> pd.DataFrame: """Load a dataframe from file if a path is given, otherwise return the dataframe. Args: @@ -197,7 +197,7 @@ def smart_load_dataframe(dataframe: Union[str, Path, pd.DataFrame]) -> pd.DataFr raise ValueError(f"Unsupported file format: {dataframe}") -def smart_load_yaml(yaml_in: Union[str, Path, dict]) -> dict: +def smart_load_yaml(yaml_in: str | Path | dict) -> dict: """Return the yaml config as a dict given flexible input types. Args: @@ -227,7 +227,7 @@ def smart_load_yaml(yaml_in: Union[str, Path, dict]) -> dict: return yaml_out -def serialize_data(data: Union[dict, list, str, Number], **kwargs) -> str: +def serialize_data(data: dict | list | str | Number, **kwargs) -> str: if isinstance(data, dict): return json.dumps(data, ensure_ascii=False, default=_convert_to_serializable, **kwargs) elif isinstance(data, list): diff --git a/src/data_designer/config/utils/misc.py b/src/data_designer/config/utils/misc.py index 748455e8..302668b4 100644 --- a/src/data_designer/config/utils/misc.py +++ b/src/data_designer/config/utils/misc.py @@ -5,7 +5,6 @@ import json from contextlib import contextmanager -from typing import Optional, Union from jinja2 import TemplateSyntaxError, meta from jinja2.sandbox import ImmutableSandboxedEnvironment @@ -58,9 +57,7 @@ def get_prompt_template_keywords(template: str) -> set[str]: return keywords -def json_indent_list_of_strings( - column_names: list[str], *, indent: Optional[Union[int, str]] = None -) -> Optional[Union[list[str], str]]: +def json_indent_list_of_strings(column_names: list[str], *, indent: int | str | None = None) -> list[str] | str | None: """Convert a list of column names to a JSON string if the list is long. This function helps keep Data Designer's __repr__ output clean and readable. diff --git a/src/data_designer/config/utils/numerical_helpers.py b/src/data_designer/config/utils/numerical_helpers.py index fbdf8c25..558bf516 100644 --- a/src/data_designer/config/utils/numerical_helpers.py +++ b/src/data_designer/config/utils/numerical_helpers.py @@ -3,7 +3,7 @@ import numbers from numbers import Number -from typing import Any, Type +from typing import Any from data_designer.config.utils.constants import REPORTING_PRECISION @@ -18,7 +18,7 @@ def is_float(val: Any) -> bool: def prepare_number_for_reporting( value: Number, - target_type: Type[Number], + target_type: type[Number], precision: int = REPORTING_PRECISION, ) -> Number: """Ensure native python types and round to `precision` decimal digits.""" diff --git a/src/data_designer/config/utils/type_helpers.py b/src/data_designer/config/utils/type_helpers.py index 321462b9..d1174a51 100644 --- a/src/data_designer/config/utils/type_helpers.py +++ b/src/data_designer/config/utils/type_helpers.py @@ -3,7 +3,7 @@ import inspect from enum import Enum -from typing import Any, Literal, Type, get_args, get_origin +from typing import Any, Literal, get_args, get_origin from pydantic import BaseModel @@ -56,7 +56,7 @@ def create_str_enum_from_discriminated_type_union( return StrEnum(enum_name, {v.replace("-", "_").upper(): v for v in set(discriminator_field_values)}) -def get_sampler_params() -> dict[str, Type[BaseModel]]: +def get_sampler_params() -> dict[str, type[BaseModel]]: """Returns a dictionary of sampler parameter classes.""" params_cls_list = [ params_cls @@ -83,7 +83,7 @@ def get_sampler_params() -> dict[str, Type[BaseModel]]: return params_cls_dict -def resolve_string_enum(enum_instance: Any, enum_type: Type[Enum]) -> Enum: +def resolve_string_enum(enum_instance: Any, enum_type: type[Enum]) -> Enum: if not issubclass(enum_type, Enum): raise InvalidEnumValueError(f"🛑 `enum_type` must be a subclass of Enum. You provided: {enum_type}") invalid_enum_value_error = InvalidEnumValueError( diff --git a/src/data_designer/config/utils/validation.py b/src/data_designer/config/utils/validation.py index dc1ca2e3..4e951e65 100644 --- a/src/data_designer/config/utils/validation.py +++ b/src/data_designer/config/utils/validation.py @@ -5,7 +5,6 @@ from enum import Enum from string import Formatter -from typing import Optional from jinja2 import meta from jinja2.sandbox import ImmutableSandboxedEnvironment @@ -45,7 +44,7 @@ class ViolationLevel(str, Enum): class Violation(BaseModel): - column: Optional[str] = None + column: str | None = None type: ViolationType message: str level: ViolationLevel diff --git a/src/data_designer/config/utils/visualization.py b/src/data_designer/config/utils/visualization.py index 85a230a9..308e39e4 100644 --- a/src/data_designer/config/utils/visualization.py +++ b/src/data_designer/config/utils/visualization.py @@ -8,7 +8,7 @@ from collections import OrderedDict from enum import Enum from functools import cached_property -from typing import TYPE_CHECKING, Any, Optional, Union +from typing import TYPE_CHECKING, Any import numpy as np import pandas as pd @@ -36,11 +36,11 @@ console = Console() -def get_nvidia_api_key() -> Optional[str]: +def get_nvidia_api_key() -> str | None: return os.getenv(NVIDIA_API_KEY_ENV_VAR_NAME) -def get_openai_api_key() -> Optional[str]: +def get_openai_api_key() -> str | None: return os.getenv(OPENAI_API_KEY_ENV_VAR_NAME) @@ -77,12 +77,12 @@ def _has_processor_artifacts(self) -> bool: def display_sample_record( self, - index: Optional[int] = None, + index: int | None = None, *, hide_seed_columns: bool = False, syntax_highlighting_theme: str = "dracula", - background_color: Optional[str] = None, - processors_to_display: Optional[list[str]] = None, + background_color: str | None = None, + processors_to_display: list[str] | None = None, ) -> None: """Display a sample record from the Data Designer dataset preview. @@ -134,11 +134,11 @@ def display_sample_record( def create_rich_histogram_table( - data: dict[str, Union[int, float]], + data: dict[str, int | float], column_names: tuple[int, int], name_style: str = ColorPalette.BLUE.value, value_style: str = ColorPalette.TEAL.value, - title: Optional[str] = None, + title: str | None = None, **kwargs, ) -> Table: table = Table(title=title, **kwargs) @@ -154,12 +154,12 @@ def create_rich_histogram_table( def display_sample_record( - record: Union[dict, pd.Series, pd.DataFrame], + record: dict | pd.Series | pd.DataFrame, config_builder: DataDesignerConfigBuilder, - processor_data_to_display: Optional[dict[str, Union[list[str], str]]] = None, - background_color: Optional[str] = None, + processor_data_to_display: dict[str, list[str] | str] | None = None, + background_color: str | None = None, syntax_highlighting_theme: str = "dracula", - record_index: Optional[int] = None, + record_index: int | None = None, hide_seed_columns: bool = False, ): if isinstance(record, (dict, pd.Series)): @@ -286,7 +286,7 @@ def get_truncated_list_as_string(long_list: list[Any], max_items: int = 2) -> st def display_sampler_table( sampler_params: dict[SamplerType, ConfigBase], - title: Optional[str] = None, + title: str | None = None, ) -> None: table = Table(expand=True) table.add_column("Type") diff --git a/src/data_designer/config/validator_params.py b/src/data_designer/config/validator_params.py index 3aaff6d9..5944bbaa 100644 --- a/src/data_designer/config/validator_params.py +++ b/src/data_designer/config/validator_params.py @@ -2,7 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 from enum import Enum -from typing import Any, Optional, Union +from typing import Any from pydantic import Field, field_serializer, model_validator from typing_extensions import Self, TypeAlias @@ -51,7 +51,7 @@ class LocalCallableValidatorParams(ConfigBase): validation_function: Any = Field( description="Function (Callable[[pd.DataFrame], pd.DataFrame]) to validate the data" ) - output_schema: Optional[dict[str, Any]] = Field( + output_schema: dict[str, Any] | None = Field( default=None, description="Expected schema for local callable validator's output" ) @@ -80,7 +80,7 @@ class RemoteValidatorParams(ConfigBase): """ endpoint_url: str = Field(description="URL of the remote endpoint") - output_schema: Optional[dict[str, Any]] = Field( + output_schema: dict[str, Any] | None = Field( default=None, description="Expected schema for remote validator's output" ) timeout: float = Field(default=30.0, gt=0, description="The timeout for the HTTP request") @@ -89,8 +89,4 @@ class RemoteValidatorParams(ConfigBase): max_parallel_requests: int = Field(default=4, ge=1, description="The maximum number of parallel requests to make") -ValidatorParamsT: TypeAlias = Union[ - CodeValidatorParams, - LocalCallableValidatorParams, - RemoteValidatorParams, -] +ValidatorParamsT: TypeAlias = CodeValidatorParams | LocalCallableValidatorParams | RemoteValidatorParams diff --git a/src/data_designer/engine/analysis/column_profilers/judge_score_profiler.py b/src/data_designer/engine/analysis/column_profilers/judge_score_profiler.py index e117f2c1..3a411f23 100644 --- a/src/data_designer/engine/analysis/column_profilers/judge_score_profiler.py +++ b/src/data_designer/engine/analysis/column_profilers/judge_score_profiler.py @@ -5,7 +5,6 @@ import logging import random -from typing import Union from data_designer.config.analysis.column_profilers import ( JudgeScoreProfilerConfig, @@ -96,7 +95,7 @@ def _summarize_score_sample( name: str, sample: list[JudgeScoreSample], histogram: CategoricalHistogramData, - distribution: Union[CategoricalDistribution, NumericalDistribution, MissingValue], + distribution: CategoricalDistribution | NumericalDistribution | MissingValue, distribution_type: ColumnDistributionType, ) -> JudgeScoreSummary: if isinstance(distribution, MissingValue) or not sample: diff --git a/src/data_designer/engine/analysis/column_statistics.py b/src/data_designer/engine/analysis/column_statistics.py index e3471ae9..8795b495 100644 --- a/src/data_designer/engine/analysis/column_statistics.py +++ b/src/data_designer/engine/analysis/column_statistics.py @@ -4,7 +4,7 @@ from __future__ import annotations import logging -from typing import Any, Type, TypeAlias, Union +from typing import Any, TypeAlias import pandas as pd from pydantic import BaseModel @@ -41,7 +41,7 @@ def df(self) -> pd.DataFrame: return self.column_config_with_df.df @property - def column_statistics_type(self) -> Type[ColumnStatisticsT]: + def column_statistics_type(self) -> type[ColumnStatisticsT]: return DEFAULT_COLUMN_STATISTICS_MAP.get(self.column_config.column_type, GeneralColumnStatistics) def calculate(self) -> Self: @@ -115,17 +115,17 @@ def calculate_validation_column_info(self) -> dict[str, Any]: class ExpressionColumnStatisticsCalculator(GeneralColumnStatisticsCalculator): ... -ColumnStatisticsCalculatorT: TypeAlias = Union[ - ExpressionColumnStatisticsCalculator, - ValidationColumnStatisticsCalculator, - GeneralColumnStatisticsCalculator, - LLMCodeColumnStatisticsCalculator, - LLMJudgedColumnStatisticsCalculator, - LLMStructuredColumnStatisticsCalculator, - LLMTextColumnStatisticsCalculator, - SamplerColumnStatisticsCalculator, - SeedDatasetColumnStatisticsCalculator, -] +ColumnStatisticsCalculatorT: TypeAlias = ( + ExpressionColumnStatisticsCalculator + | ValidationColumnStatisticsCalculator + | GeneralColumnStatisticsCalculator + | LLMCodeColumnStatisticsCalculator + | LLMJudgedColumnStatisticsCalculator + | LLMStructuredColumnStatisticsCalculator + | LLMTextColumnStatisticsCalculator + | SamplerColumnStatisticsCalculator + | SeedDatasetColumnStatisticsCalculator +) DEFAULT_COLUMN_STATISTICS_CALCULATOR_MAP = { DataDesignerColumnType.EXPRESSION: ExpressionColumnStatisticsCalculator, DataDesignerColumnType.VALIDATION: ValidationColumnStatisticsCalculator, diff --git a/src/data_designer/engine/analysis/utils/judge_score_processing.py b/src/data_designer/engine/analysis/utils/judge_score_processing.py index 694f4142..7c0f44bf 100644 --- a/src/data_designer/engine/analysis/utils/judge_score_processing.py +++ b/src/data_designer/engine/analysis/utils/judge_score_processing.py @@ -3,7 +3,7 @@ import logging from collections import defaultdict -from typing import Any, Optional, Union +from typing import Any import pandas as pd @@ -21,7 +21,7 @@ def extract_judge_score_distributions( column_config: LLMJudgeColumnConfig, df: pd.DataFrame -) -> Union[JudgeScoreDistributions, MissingValue]: +) -> JudgeScoreDistributions | MissingValue: scores = defaultdict(list) reasoning = defaultdict(list) @@ -79,10 +79,10 @@ def extract_judge_score_distributions( def sample_scores_and_reasoning( - scores: list[Union[int, str]], + scores: list[int | str], reasoning: list[str], num_samples: int, - random_seed: Optional[int] = None, + random_seed: int | None = None, ) -> list[JudgeScoreSample]: if len(scores) != len(reasoning): raise ValueError("scores and reasoning must have the same length") diff --git a/src/data_designer/engine/column_generators/utils/judge_score_factory.py b/src/data_designer/engine/column_generators/utils/judge_score_factory.py index df31d340..9154aaea 100644 --- a/src/data_designer/engine/column_generators/utils/judge_score_factory.py +++ b/src/data_designer/engine/column_generators/utils/judge_score_factory.py @@ -2,7 +2,6 @@ # SPDX-License-Identifier: Apache-2.0 from enum import Enum -from typing import Type from pydantic import BaseModel, ConfigDict, Field, create_model @@ -19,7 +18,7 @@ class BaseJudgeResponse(BaseModel): reasoning: str = Field(..., description="Reasoning for the assigned score.") -def _stringify_scoring(options: dict, enum_type: Type[Enum]) -> str: +def _stringify_scoring(options: dict, enum_type: type[Enum]) -> str: """Convert score descriptions into a single text block.""" list_block = "\n".join( [SCORING_FORMAT.format(score=score, description=description) for score, description in options.items()] @@ -27,7 +26,7 @@ def _stringify_scoring(options: dict, enum_type: Type[Enum]) -> str: return SCORE_FIELD_DESCRIPTION_FORMAT.format(enum_name=enum_type.__name__, scoring=list_block) -def create_judge_response_model(score: Score) -> Type[BaseJudgeResponse]: +def create_judge_response_model(score: Score) -> type[BaseJudgeResponse]: """Create a JudgeResponse data type.""" enum_members = {} for option in score.options.keys(): @@ -46,8 +45,8 @@ def create_judge_response_model(score: Score) -> Type[BaseJudgeResponse]: def create_judge_structured_output_model( - judge_responses: list[Type[BaseJudgeResponse]], -) -> Type[BaseModel]: + judge_responses: list[type[BaseJudgeResponse]], +) -> type[BaseModel]: """Create a JudgeStructuredOutput class dynamically.""" return create_model( "JudgeStructuredOutput", diff --git a/src/data_designer/engine/configurable_task.py b/src/data_designer/engine/configurable_task.py index ee2dff46..38b2748f 100644 --- a/src/data_designer/engine/configurable_task.py +++ b/src/data_designer/engine/configurable_task.py @@ -3,7 +3,7 @@ from abc import ABC, abstractmethod from pathlib import Path -from typing import Generic, Type, TypeVar, get_origin +from typing import Generic, TypeVar, get_origin import pandas as pd @@ -30,7 +30,7 @@ def __init__(self, config: TaskConfigT, *, resource_provider: ResourceProvider | self._initialize() @classmethod - def get_config_type(cls) -> Type[TaskConfigT]: + def get_config_type(cls) -> type[TaskConfigT]: for base in cls.__orig_bases__: if hasattr(base, "__args__") and len(base.__args__) == 1: arg = base.__args__[0] diff --git a/src/data_designer/engine/dataset_builders/artifact_storage.py b/src/data_designer/engine/dataset_builders/artifact_storage.py index 152ac13d..ac5cc254 100644 --- a/src/data_designer/engine/dataset_builders/artifact_storage.py +++ b/src/data_designer/engine/dataset_builders/artifact_storage.py @@ -7,7 +7,6 @@ from datetime import datetime from functools import cached_property from pathlib import Path -from typing import Union import pandas as pd from pydantic import BaseModel, field_validator, model_validator @@ -77,7 +76,7 @@ def processors_outputs_path(self) -> Path: return self.base_dataset_path / self.processors_outputs_folder_name @field_validator("artifact_path") - def validate_artifact_path(cls, v: Union[Path, str]) -> Path: + def validate_artifact_path(cls, v: Path | str) -> Path: v = Path(v) if not v.is_dir(): raise ArtifactStorageError("Artifact path must exist and be a directory") diff --git a/src/data_designer/engine/dataset_builders/utils/concurrency.py b/src/data_designer/engine/dataset_builders/utils/concurrency.py index be04378d..95e8f3fc 100644 --- a/src/data_designer/engine/dataset_builders/utils/concurrency.py +++ b/src/data_designer/engine/dataset_builders/utils/concurrency.py @@ -8,7 +8,7 @@ import logging from concurrent.futures import Future, ThreadPoolExecutor from threading import Lock, Semaphore -from typing import Any, Optional, Protocol +from typing import Any, Protocol from pydantic import BaseModel, Field @@ -46,13 +46,13 @@ def is_error_rate_exceeded(self, window: int) -> bool: class CallbackWithContext(Protocol): """Executor callback functions must accept a context kw argument.""" - def __call__(self, result: Any, *, context: Optional[dict] = None) -> Any: ... + def __call__(self, result: Any, *, context: dict | None = None) -> Any: ... class ErrorCallbackWithContext(Protocol): """Error callbacks take the Exception instance and context.""" - def __call__(self, exc: Exception, *, context: Optional[dict] = None) -> Any: ... + def __call__(self, exc: Exception, *, context: dict | None = None) -> Any: ... class ConcurrentThreadExecutor: @@ -92,8 +92,8 @@ def __init__( *, max_workers: int, column_name: str, - result_callback: Optional[CallbackWithContext] = None, - error_callback: Optional[ErrorCallbackWithContext] = None, + result_callback: CallbackWithContext | None = None, + error_callback: ErrorCallbackWithContext | None = None, shutdown_error_rate: float = 0.50, shutdown_error_window: int = 10, ): @@ -136,7 +136,7 @@ def _raise_task_error(self): ) ) - def submit(self, fn, *args, context: Optional[dict] = None, **kwargs) -> None: + def submit(self, fn, *args, context: dict | None = None, **kwargs) -> None: if self._executor is None: raise RuntimeError("Executor is not initialized, this class should be used as a context manager.") diff --git a/src/data_designer/engine/models/litellm_overrides.py b/src/data_designer/engine/models/litellm_overrides.py index 41208a0d..a5c4981f 100644 --- a/src/data_designer/engine/models/litellm_overrides.py +++ b/src/data_designer/engine/models/litellm_overrides.py @@ -5,7 +5,6 @@ import random import threading -from typing import Optional, Union import httpx import litellm @@ -90,7 +89,7 @@ def __init__( self._initial_retry_after_s = initial_retry_after_s self._jitter_pct = jitter_pct - def _extract_retry_delay_from_headers(self, e: Exception) -> Optional[Union[int, float]]: + def _extract_retry_delay_from_headers(self, e: Exception) -> int | float | None: """ Most of this code logic was extracted directly from the parent `Router`'s `_time_to_sleep_before_retry` function. Our override @@ -99,7 +98,7 @@ def _extract_retry_delay_from_headers(self, e: Exception) -> Optional[Union[int, return this info, we'll simply use that retry value returned here. """ - response_headers: Optional[httpx.Headers] = None + response_headers: httpx.Headers | None = None if hasattr(e, "response") and hasattr(e.response, "headers"): # type: ignore response_headers = e.response.headers # type: ignore if hasattr(e, "litellm_response_headers"): @@ -119,9 +118,9 @@ def _time_to_sleep_before_retry( e: Exception, remaining_retries: int, num_retries: int, - healthy_deployments: Optional[list] = None, - all_deployments: Optional[list] = None, - ) -> Union[int, float]: + healthy_deployments: list | None = None, + all_deployments: list | None = None, + ) -> int | float: """ Implements exponential backoff for retries. diff --git a/src/data_designer/engine/models/parsers/errors.py b/src/data_designer/engine/models/parsers/errors.py index cf0fd411..087e8a83 100644 --- a/src/data_designer/engine/models/parsers/errors.py +++ b/src/data_designer/engine/models/parsers/errors.py @@ -1,8 +1,6 @@ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 -from typing import Optional - class ParserException(Exception): """Identifies errors resulting from generic parser errors. @@ -12,7 +10,7 @@ class ParserException(Exception): attempted to parse. """ - source: Optional[str] + source: str | None @staticmethod def _log_format(source: str) -> str: @@ -24,7 +22,7 @@ def _log_format(source: str) -> str: # return f"{source}" return "" - def __init__(self, msg: Optional[str] = None, source: Optional[str] = None): + def __init__(self, msg: str | None = None, source: str | None = None): msg = "" if msg is None else msg.strip() if source is not None: diff --git a/src/data_designer/engine/models/parsers/parser.py b/src/data_designer/engine/models/parsers/parser.py index e147cfdc..50a21494 100644 --- a/src/data_designer/engine/models/parsers/parser.py +++ b/src/data_designer/engine/models/parsers/parser.py @@ -2,7 +2,6 @@ # SPDX-License-Identifier: Apache-2.0 from functools import reduce -from typing import Optional import marko from lxml import etree @@ -105,8 +104,8 @@ def __call__(self, element: _Element) -> CodeBlock: def __init__( self, - tag_parsers: Optional[dict[str, TagParser]] = None, - postprocessors: Optional[list[PostProcessor]] = None, + tag_parsers: dict[str, TagParser] | None = None, + postprocessors: list[PostProcessor] | None = None, ): """ Initializes the LLMResponseParser with optional tag parsers and post-processors. diff --git a/src/data_designer/engine/models/parsers/postprocessors.py b/src/data_designer/engine/models/parsers/postprocessors.py index d7959505..1cce5290 100644 --- a/src/data_designer/engine/models/parsers/postprocessors.py +++ b/src/data_designer/engine/models/parsers/postprocessors.py @@ -1,7 +1,6 @@ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 -from typing import Optional, Type import json_repair from pydantic import BaseModel, ValidationError @@ -60,12 +59,12 @@ def deserialize_json_code( class RealizePydanticTypes: - types: list[Type[BaseModel]] + types: list[type[BaseModel]] - def __init__(self, types: list[Type[BaseModel]]): + def __init__(self, types: list[type[BaseModel]]): self.types = types - def _fit_types(self, obj: dict) -> Optional[BaseModel]: + def _fit_types(self, obj: dict) -> BaseModel | None: final_obj = None for t in self.types: diff --git a/src/data_designer/engine/models/parsers/types.py b/src/data_designer/engine/models/parsers/types.py index 17be38a2..aacb54b1 100644 --- a/src/data_designer/engine/models/parsers/types.py +++ b/src/data_designer/engine/models/parsers/types.py @@ -1,7 +1,7 @@ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 -from typing import Any, Optional, Protocol, Type, runtime_checkable +from typing import Any, Protocol, runtime_checkable from lxml.etree import _Element from pydantic import BaseModel, Field @@ -30,7 +30,7 @@ def tail(self, n: int) -> Self: out.parsed = out.parsed[-n:] return out - def filter(self, block_types: list[Type[BaseModel]]) -> Self: + def filter(self, block_types: list[type[BaseModel]]) -> Self: out = self.model_copy() out.parsed = [b for b in out.parsed if isinstance(b, tuple(block_types))] return out @@ -44,7 +44,7 @@ class TagParser(Protocol): element, do some computation, and return some kind of structured output, represented as a subclass of Pydantic `BaseModel`. This protocol implementation can cover both classes as well - as curried fuctions as parsers (e.g. `partial`). + as curried functions as parsers (e.g. `partial`). """ def __call__(self, element: _Element) -> BaseModel: ... @@ -69,7 +69,7 @@ class TextBlock(BaseModel): class CodeBlock(BaseModel): code: str - code_lang: Optional[str] = None + code_lang: str | None = None class StructuredDataBlock(BaseModel): diff --git a/src/data_designer/engine/processing/ginja/ast.py b/src/data_designer/engine/processing/ginja/ast.py index 2d1fecb3..9171365f 100644 --- a/src/data_designer/engine/processing/ginja/ast.py +++ b/src/data_designer/engine/processing/ginja/ast.py @@ -2,7 +2,6 @@ # SPDX-License-Identifier: Apache-2.0 from collections import deque -from typing import Optional, Type from jinja2 import nodes as j_nodes @@ -33,7 +32,7 @@ def ast_max_depth(node: j_nodes.Node) -> int: return max_depth -def ast_descendant_count(ast: j_nodes.Node, only_type: Optional[Type[j_nodes.Node]] = None) -> int: +def ast_descendant_count(ast: j_nodes.Node, only_type: type[j_nodes.Node] | None = None) -> int: """Count the number of nodes which descend from the given node. Args: diff --git a/src/data_designer/engine/processing/utils.py b/src/data_designer/engine/processing/utils.py index 5d42c40e..00d1d0ca 100644 --- a/src/data_designer/engine/processing/utils.py +++ b/src/data_designer/engine/processing/utils.py @@ -5,7 +5,7 @@ import json import logging import re -from typing import Any, TypeVar, Union, overload +from typing import Any, TypeVar, overload import pandas as pd @@ -27,7 +27,7 @@ def concat_datasets(datasets: list[pd.DataFrame]) -> pd.DataFrame: # Overloads to help static type checker better understand # the input/output types of the deserialize_json_values function. @overload -def deserialize_json_values(data: str) -> Union[dict[str, Any], list[Any], Any]: ... +def deserialize_json_values(data: str) -> dict[str, Any] | list[Any] | Any: ... @overload diff --git a/src/data_designer/engine/registry/base.py b/src/data_designer/engine/registry/base.py index a22b4a50..2e8b069e 100644 --- a/src/data_designer/engine/registry/base.py +++ b/src/data_designer/engine/registry/base.py @@ -2,7 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 import threading -from typing import Any, Generic, Type, TypeVar +from typing import Any, Generic, TypeVar from data_designer.config.base import ConfigBase from data_designer.config.utils.type_helpers import StrEnum @@ -16,14 +16,14 @@ class TaskRegistry(Generic[EnumNameT, TaskT, TaskConfigT]): # registered type name -> type - _registry: dict[EnumNameT, Type[TaskT]] = {} + _registry: dict[EnumNameT, type[TaskT]] = {} # type -> registered type name - _reverse_registry: dict[Type[TaskT], EnumNameT] = {} + _reverse_registry: dict[type[TaskT], EnumNameT] = {} # registered type name -> config type - _config_registry: dict[EnumNameT, Type[TaskConfigT]] = {} + _config_registry: dict[EnumNameT, type[TaskConfigT]] = {} # config type -> registered type name - _reverse_config_registry: dict[Type[TaskConfigT], EnumNameT] = {} + _reverse_config_registry: dict[type[TaskConfigT], EnumNameT] = {} # all registries are singletons _instance = None @@ -33,8 +33,8 @@ class TaskRegistry(Generic[EnumNameT, TaskT, TaskConfigT]): def register( cls, name: EnumNameT, - task: Type[TaskT], - config: Type[TaskConfigT], + task: type[TaskT], + config: type[TaskConfigT], raise_on_collision: bool = False, ) -> None: if cls._has_been_registered(name): @@ -52,22 +52,22 @@ def register( cls._reverse_config_registry[config] = name @classmethod - def get_task_type(cls, name: EnumNameT) -> Type[TaskT]: + def get_task_type(cls, name: EnumNameT) -> type[TaskT]: cls._raise_if_not_registered(name, cls._registry) return cls._registry[name] @classmethod - def get_config_type(cls, name: EnumNameT) -> Type[TaskConfigT]: + def get_config_type(cls, name: EnumNameT) -> type[TaskConfigT]: cls._raise_if_not_registered(name, cls._config_registry) return cls._config_registry[name] @classmethod - def get_registered_name(cls, task: Type[TaskT]) -> EnumNameT: + def get_registered_name(cls, task: type[TaskT]) -> EnumNameT: cls._raise_if_not_registered(task, cls._reverse_registry) return cls._reverse_registry[task] @classmethod - def get_for_config_type(cls, config: Type[TaskConfigT]) -> Type[TaskT]: + def get_for_config_type(cls, config: type[TaskConfigT]) -> type[TaskT]: cls._raise_if_not_registered(config, cls._reverse_config_registry) name = cls._reverse_config_registry[config] return cls.get_task_type(name) @@ -77,7 +77,7 @@ def _has_been_registered(cls, name: EnumNameT) -> bool: return name in cls._registry @classmethod - def _raise_if_not_registered(cls, key: EnumNameT | Type[TaskT] | Type[TaskConfigT], mapping: dict) -> None: + def _raise_if_not_registered(cls, key: EnumNameT | type[TaskT] | type[TaskConfigT], mapping: dict) -> None: if not (isinstance(key, StrEnum) or isinstance(key, str)): cls._raise_if_not_type(key) if key not in mapping: diff --git a/src/data_designer/engine/sampling_gen/constraints.py b/src/data_designer/engine/sampling_gen/constraints.py index 053199e8..e7713049 100644 --- a/src/data_designer/engine/sampling_gen/constraints.py +++ b/src/data_designer/engine/sampling_gen/constraints.py @@ -2,7 +2,6 @@ # SPDX-License-Identifier: Apache-2.0 from abc import ABC, abstractmethod -from typing import Type import numpy as np import pandas as pd @@ -91,5 +90,5 @@ def check(self, dataframe: pd.DataFrame) -> NDArray[np.bool_]: } -def get_constraint_checker(constraint_type: ConstraintType) -> Type[ConstraintChecker]: +def get_constraint_checker(constraint_type: ConstraintType) -> type[ConstraintChecker]: return CONSTRAINT_TYPE_TO_CHECKER[ConstraintType(constraint_type)] diff --git a/src/data_designer/engine/sampling_gen/data_sources/base.py b/src/data_designer/engine/sampling_gen/data_sources/base.py index c1d3bd05..3758eefc 100644 --- a/src/data_designer/engine/sampling_gen/data_sources/base.py +++ b/src/data_designer/engine/sampling_gen/data_sources/base.py @@ -2,7 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 from abc import ABC, abstractmethod -from typing import Any, Generic, Optional, Type, TypeVar, Union +from typing import Any, Generic, TypeVar import numpy as np import pandas as pd @@ -45,7 +45,7 @@ def postproc(series: pd.Series, convert_to: str) -> pd.Series: return series @staticmethod - def validate_data_conversion(convert_to: Optional[str]) -> None: + def validate_data_conversion(convert_to: str | None) -> None: pass @@ -71,7 +71,7 @@ def preproc(series: pd.Series, convert_to: str) -> pd.Series: return series @staticmethod - def postproc(series: pd.Series, convert_to: Optional[str]) -> pd.Series: + def postproc(series: pd.Series, convert_to: str | None) -> pd.Series: if convert_to is not None: if convert_to == "int": series = series.round() @@ -79,18 +79,18 @@ def postproc(series: pd.Series, convert_to: Optional[str]) -> pd.Series: return series @staticmethod - def validate_data_conversion(convert_to: Optional[str]) -> None: + def validate_data_conversion(convert_to: str | None) -> None: if convert_to is not None and convert_to not in ["float", "int", "str"]: raise ValueError(f"Invalid `convert_to` value: {convert_to}. Must be one of: [float, int, str]") class DatetimeFormatMixin: @staticmethod - def preproc(series: pd.Series, convert_to: Optional[str]) -> pd.Series: + def preproc(series: pd.Series, convert_to: str | None) -> pd.Series: return series @staticmethod - def postproc(series: pd.Series, convert_to: Optional[str]) -> pd.Series: + def postproc(series: pd.Series, convert_to: str | None) -> pd.Series: if convert_to is not None: return series.dt.strftime(convert_to) if series.dt.month.nunique() == 1: @@ -104,7 +104,7 @@ def postproc(series: pd.Series, convert_to: Optional[str]) -> pd.Series: return series.apply(lambda dt: dt.isoformat()).astype(str) @staticmethod - def validate_data_conversion(convert_to: Optional[str]) -> None: + def validate_data_conversion(convert_to: str | None) -> None: if convert_to is not None: try: pd.to_datetime(pd.to_datetime("2012-12-21").strftime(convert_to)) @@ -121,7 +121,7 @@ class DataSource(ABC, Generic[GenericParamsT]): def __init__( self, params: GenericParamsT, - random_state: Optional[RadomStateT] = None, + random_state: RadomStateT | None = None, **kwargs, ): self.rng = check_random_state(random_state) @@ -130,7 +130,7 @@ def __init__( self._validate() @classmethod - def get_param_type(cls) -> Type[GenericParamsT]: + def get_param_type(cls) -> type[GenericParamsT]: return cls.__orig_bases__[-1].__args__[0] @abstractmethod @@ -138,7 +138,7 @@ def inject_data_column( self, dataframe: pd.DataFrame, column_name: str, - index: Optional[list[int]] = None, + index: list[int] | None = None, ) -> pd.DataFrame: ... @staticmethod @@ -147,11 +147,11 @@ def preproc(series: pd.Series) -> pd.Series: ... @staticmethod @abstractmethod - def postproc(series: pd.Series, convert_to: Optional[str]) -> pd.Series: ... + def postproc(series: pd.Series, convert_to: str | None) -> pd.Series: ... @staticmethod @abstractmethod - def validate_data_conversion(convert_to: Optional[str]) -> None: ... + def validate_data_conversion(convert_to: str | None) -> None: ... def get_required_column_names(self) -> tuple[str, ...]: return tuple() @@ -182,7 +182,7 @@ def inject_data_column( self, dataframe: pd.DataFrame, column_name: str, - index: Optional[list[int]] = None, + index: list[int] | None = None, ) -> pd.DataFrame: index = slice(None) if index is None else index @@ -208,7 +208,7 @@ def sample(self, num_samples: int) -> NumpyArray1dT: ... class ScipyStatsSampler(Sampler[GenericParamsT], ABC): @property @abstractmethod - def distribution(self) -> Union[stats.rv_continuous, stats.rv_discrete]: ... + def distribution(self) -> stats.rv_continuous | stats.rv_discrete: ... def sample(self, num_samples: int) -> NumpyArray1dT: return self.distribution.rvs(size=num_samples, random_state=self.rng) diff --git a/src/data_designer/engine/sampling_gen/entities/phone_number.py b/src/data_designer/engine/sampling_gen/entities/phone_number.py index aac1ff4f..bb618939 100644 --- a/src/data_designer/engine/sampling_gen/entities/phone_number.py +++ b/src/data_designer/engine/sampling_gen/entities/phone_number.py @@ -3,7 +3,6 @@ import random from pathlib import Path -from typing import Optional import pandas as pd from pydantic import BaseModel, Field, field_validator @@ -13,7 +12,7 @@ ZIPCODE_POPULATION_MAP = dict(zip(ZIP_AREA_CODE_DATA["zipcode"], ZIP_AREA_CODE_DATA["count"])) -def get_area_code(zip_prefix: Optional[str] = None) -> str: +def get_area_code(zip_prefix: str | None = None) -> str: """ Sample an area code for the given ZIP code prefix, population-weighted. diff --git a/src/data_designer/engine/sampling_gen/people_gen.py b/src/data_designer/engine/sampling_gen/people_gen.py index b605fe66..67e81d25 100644 --- a/src/data_designer/engine/sampling_gen/people_gen.py +++ b/src/data_designer/engine/sampling_gen/people_gen.py @@ -8,7 +8,7 @@ from abc import ABC, abstractmethod from collections.abc import Callable from copy import deepcopy -from typing import TYPE_CHECKING, Any, Union +from typing import TYPE_CHECKING, Any, TypeAlias import pandas as pd from faker import Faker @@ -27,7 +27,7 @@ from data_designer.engine.sampling_gen.schema import DataSchema -EngineT = Union[Faker, ManagedDatasetGenerator] +EngineT: TypeAlias = Faker | ManagedDatasetGenerator class PeopleGen(ABC): diff --git a/src/data_designer/engine/validators/base.py b/src/data_designer/engine/validators/base.py index 902b3429..22d3ec58 100644 --- a/src/data_designer/engine/validators/base.py +++ b/src/data_designer/engine/validators/base.py @@ -2,14 +2,14 @@ # SPDX-License-Identifier: Apache-2.0 from abc import ABC, abstractmethod -from typing import Iterator, Optional +from typing import Iterator from pydantic import BaseModel, ConfigDict from typing_extensions import Self class ValidationOutput(BaseModel): - is_valid: Optional[bool] + is_valid: bool | None model_config = ConfigDict(extra="allow") diff --git a/src/data_designer/logging.py b/src/data_designer/logging.py index 85796c35..84f9f085 100644 --- a/src/data_designer/logging.py +++ b/src/data_designer/logging.py @@ -6,7 +6,7 @@ import sys from dataclasses import dataclass, field from pathlib import Path -from typing import TextIO, Union +from typing import TextIO from pythonjsonlogger import jsonlogger @@ -19,7 +19,7 @@ class LoggerConfig: @dataclass class OutputConfig: - destination: Union[TextIO, Path] + destination: TextIO | Path structured: bool diff --git a/src/data_designer/plugin_manager.py b/src/data_designer/plugin_manager.py index 891cc905..03eb603c 100644 --- a/src/data_designer/plugin_manager.py +++ b/src/data_designer/plugin_manager.py @@ -4,7 +4,7 @@ from __future__ import annotations from enum import Enum -from typing import TYPE_CHECKING, Type, TypeAlias +from typing import TYPE_CHECKING, TypeAlias from data_designer.plugins.plugin import PluginType from data_designer.plugins.registry import PluginRegistry @@ -37,7 +37,7 @@ def get_column_generator_plugin_if_exists(self, plugin_name: str) -> Plugin | No if self._plugin_registry.plugin_exists(plugin_name): return self._plugin_registry.get_plugin(plugin_name) - def get_plugin_column_types(self, enum_type: Type[Enum], required_resources: list[str] | None = None) -> list[Enum]: + def get_plugin_column_types(self, enum_type: type[Enum], required_resources: list[str] | None = None) -> list[Enum]: """Get a list of plugin column types. Args: @@ -56,7 +56,7 @@ def get_plugin_column_types(self, enum_type: Type[Enum], required_resources: lis type_list.append(enum_type(plugin.name)) return type_list - def inject_into_column_config_type_union(self, column_config_type: Type[TypeAlias]) -> Type[TypeAlias]: + def inject_into_column_config_type_union(self, column_config_type: type[TypeAlias]) -> type[TypeAlias]: """Inject plugins into the column config type. Args: diff --git a/src/data_designer/plugins/plugin.py b/src/data_designer/plugins/plugin.py index 886a2252..6553e45e 100644 --- a/src/data_designer/plugins/plugin.py +++ b/src/data_designer/plugins/plugin.py @@ -2,7 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 from enum import Enum -from typing import Literal, Type, get_origin +from typing import Literal, get_origin from pydantic import BaseModel, model_validator from typing_extensions import Self @@ -27,8 +27,8 @@ def display_name(self) -> str: class Plugin(BaseModel): - task_cls: Type[ConfigurableTask] - config_cls: Type[ConfigBase] + task_cls: type[ConfigurableTask] + config_cls: type[ConfigBase] plugin_type: PluginType emoji: str = "🔌" diff --git a/src/data_designer/plugins/registry.py b/src/data_designer/plugins/registry.py index 010fe1a7..cd32895b 100644 --- a/src/data_designer/plugins/registry.py +++ b/src/data_designer/plugins/registry.py @@ -5,7 +5,7 @@ import os import threading from importlib.metadata import entry_points -from typing import Type, TypeAlias +from typing import TypeAlias from typing_extensions import Self @@ -37,7 +37,7 @@ def reset(cls) -> None: cls._plugins_discovered = False cls._plugins = {} - def add_plugin_types_to_union(self, type_union: Type[TypeAlias], plugin_type: PluginType) -> Type[TypeAlias]: + def add_plugin_types_to_union(self, type_union: type[TypeAlias], plugin_type: PluginType) -> type[TypeAlias]: for plugin in self.get_plugins(plugin_type): if plugin.config_cls not in type_union.__args__: type_union |= plugin.config_cls diff --git a/tests/config/utils/test_type_helpers.py b/tests/config/utils/test_type_helpers.py index 7365e092..cd8821f9 100644 --- a/tests/config/utils/test_type_helpers.py +++ b/tests/config/utils/test_type_helpers.py @@ -2,7 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 from enum import Enum -from typing import Literal, Union +from typing import Literal import pytest from pydantic import BaseModel @@ -49,7 +49,7 @@ class NotAModel: def test_create_str_enum_from_type_union_basic() -> None: - type_union = Union[StubModelA, StubModelB] + type_union = StubModelA | StubModelB result = create_str_enum_from_discriminated_type_union("TestEnum", type_union, "column_type") assert issubclass(result, Enum) @@ -64,7 +64,7 @@ def test_create_str_enum_from_type_union_basic() -> None: def test_create_str_enum_from_type_union_with_dashes() -> None: - type_union = Union[StubModelC, StubModelA] + type_union = StubModelC | StubModelA result = create_str_enum_from_discriminated_type_union("TestEnum", type_union, "column_type") assert hasattr(result, "TYPE_C_WITH_DASHES") @@ -72,7 +72,7 @@ def test_create_str_enum_from_type_union_with_dashes() -> None: def test_create_str_enum_from_type_union_multiple_models() -> None: - type_union = Union[StubModelA, StubModelB, StubModelC] + type_union = StubModelA | StubModelB | StubModelC result = create_str_enum_from_discriminated_type_union("TestEnum", type_union, "column_type") assert len(result) == 4 @@ -87,7 +87,7 @@ class StubModelD(BaseModel): column_type: Literal["type-a"] = "type-a" extra: str - type_union = Union[StubModelA, StubModelD] + type_union = StubModelA | StubModelD result = create_str_enum_from_discriminated_type_union("TestEnum", type_union, "column_type") assert len(result) == 2 @@ -96,14 +96,14 @@ class StubModelD(BaseModel): def test_create_str_enum_from_type_union_not_pydantic_model() -> None: - type_union = Union[StubModelA, NotAModel] + type_union = StubModelA | NotAModel with pytest.raises(InvalidTypeUnionError, match="must be a subclass of pydantic.BaseModel"): create_str_enum_from_discriminated_type_union("TestEnum", type_union, "column_type") def test_create_str_enum_from_type_union_invalid_discriminator_field() -> None: - type_union = Union[StubModelA, StubModelWithoutDiscriminator] + type_union = StubModelA | StubModelWithoutDiscriminator with pytest.raises(InvalidDiscriminatorFieldError, match="'column_type' is not a field of"): create_str_enum_from_discriminated_type_union("TestEnum", type_union, "column_type") @@ -121,7 +121,7 @@ class StubModelF(BaseModel): type_field: Literal["another-type"] = "another-type" value: int - type_union = Union[StubModelE, StubModelF] + type_union = StubModelE | StubModelF result = create_str_enum_from_discriminated_type_union("TestEnum", type_union, "type_field") assert hasattr(result, "CUSTOM_TYPE") diff --git a/tests/engine/test_configurable_task.py b/tests/engine/test_configurable_task.py index b3306a13..54ec9547 100644 --- a/tests/engine/test_configurable_task.py +++ b/tests/engine/test_configurable_task.py @@ -1,7 +1,6 @@ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 -from typing import Type from unittest.mock import Mock import pandas as pd @@ -49,7 +48,7 @@ class TestConfig(ConfigBase): class TestTask(ConfigurableTask[TestConfig]): @classmethod - def get_config_type(cls) -> Type[TestConfig]: + def get_config_type(cls) -> type[TestConfig]: return TestConfig @classmethod @@ -83,7 +82,7 @@ class TestConfig(ConfigBase): class TestTask(ConfigurableTask[TestConfig]): @classmethod - def get_config_type(cls) -> Type[TestConfig]: + def get_config_type(cls) -> type[TestConfig]: return TestConfig @classmethod @@ -117,7 +116,7 @@ class TestConfig(ConfigBase): class TestTask(ConfigurableTask[TestConfig]): @classmethod - def get_config_type(cls) -> Type[TestConfig]: + def get_config_type(cls) -> type[TestConfig]: return TestConfig @classmethod diff --git a/tests/plugins/test_plugin_registry.py b/tests/plugins/test_plugin_registry.py index 3888487a..5a7feb5c 100644 --- a/tests/plugins/test_plugin_registry.py +++ b/tests/plugins/test_plugin_registry.py @@ -269,7 +269,6 @@ def test_plugin_registry_get_plugin_names(mock_plugin_discovery, mock_entry_poin def test_plugin_registry_update_type_union(mock_plugin_discovery, mock_entry_points: list[MagicMock]) -> None: """Test update_type_union() adds plugin config types to union.""" - from typing import Union from typing_extensions import TypeAlias @@ -280,7 +279,7 @@ class DummyConfig(ConfigBase): manager = PluginRegistry() # Create a Union with at least 2 types so it has __args__ - type_union: TypeAlias = Union[ConfigBase, DummyConfig] + type_union: TypeAlias = ConfigBase | DummyConfig updated_union = manager.add_plugin_types_to_union(type_union, PluginType.COLUMN_GENERATOR) assert StubPluginConfigA in updated_union.__args__