Skip to content

Commit 8540529

Browse files
authored
chore: update type hints to 3.10+ (#148)
* update pyproject to allow ruff to fix type hints * lets go type hints * move to select
1 parent abb5d62 commit 8540529

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

53 files changed

+311
-349
lines changed

pyproject.toml

Lines changed: 10 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -105,25 +105,19 @@ required-version = ">=0.7.10"
105105
[tool.ruff]
106106
line-length = 120
107107
indent-width = 4
108+
target-version = "py310"
108109

109110
[tool.ruff.lint]
110111
select = [
111-
# "E", # pycodestyle errors
112-
"W", # pycodestyle warnings
113-
"F", # pyflakes
114-
"I", # isort (import sorting)
115-
# "N", # pep8-naming
116-
# "UP", # pyupgrade (modern Python syntax)
117-
# "ANN", # flake8-annotations (enforce type hints)
118-
# "B", # flake8-bugbear (common bugs)
119-
# "C4", # flake8-comprehensions
120-
# "DTZ", # flake8-datetimez (timezone awareness)
121-
"ICN", # flake8-import-conventions
122-
"PIE", # flake8-pie (misc lints)
123-
# "RET", # flake8-return
124-
# "SIM", # flake8-simplify
125-
# "PTH", # flake8-use-pathlib
126-
"TID", # flake8-tidy-imports (ban relative imports)
112+
"W", # pycodestyle warnings
113+
"F", # pyflakes
114+
"I", # isort (import sorting)
115+
"ICN", # flake8-import-conventions
116+
"PIE", # flake8-pie (misc lints)
117+
"TID", # flake8-tidy-imports (ban relative imports)
118+
"UP006", # List[A] -> list[A]
119+
"UP007", # Union[A, B] -> A | B
120+
"UP045", # Optional[A] -> A | None
127121
]
128122
ignore = [
129123
"ANN401", # Dynamically typed expressions (Any)

src/data_designer/config/analysis/column_profilers.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33

44
from abc import ABC
55
from enum import Enum
6-
from typing import Optional, Union
76

87
from pydantic import BaseModel, Field
98
from rich.panel import Panel
@@ -61,7 +60,7 @@ class JudgeScoreProfilerConfig(ConfigBase):
6160
"""
6261

6362
model_alias: str
64-
summary_score_sample_size: Optional[int] = Field(default=20, ge=1)
63+
summary_score_sample_size: int | None = Field(default=20, ge=1)
6564

6665

6766
class JudgeScoreSample(BaseModel):
@@ -75,7 +74,7 @@ class JudgeScoreSample(BaseModel):
7574
reasoning: The reasoning or explanation provided by the judge for this score.
7675
"""
7776

78-
score: Union[int, str]
77+
score: int | str
7978
reasoning: str
8079

8180

@@ -94,11 +93,11 @@ class JudgeScoreDistributions(BaseModel):
9493
histograms: Mapping of each score dimension name to its histogram data.
9594
"""
9695

97-
scores: dict[str, list[Union[int, str]]]
96+
scores: dict[str, list[int | str]]
9897
reasoning: dict[str, list[str]]
9998
distribution_types: dict[str, ColumnDistributionType]
100-
distributions: dict[str, Union[CategoricalDistribution, NumericalDistribution, MissingValue]]
101-
histograms: dict[str, Union[CategoricalHistogramData, MissingValue]]
99+
distributions: dict[str, CategoricalDistribution | NumericalDistribution | MissingValue]
100+
histograms: dict[str, CategoricalHistogramData | MissingValue]
102101

103102

104103
class JudgeScoreSummary(BaseModel):
@@ -132,7 +131,7 @@ class JudgeScoreProfilerResults(ColumnProfilerResults):
132131

133132
column_name: str
134133
summaries: dict[str, JudgeScoreSummary]
135-
score_distributions: Union[JudgeScoreDistributions, MissingValue]
134+
score_distributions: JudgeScoreDistributions | MissingValue
136135

137136
def create_report_section(self) -> Panel:
138137
layout = Table.grid(Column(), expand=True, padding=(2, 0))

src/data_designer/config/analysis/column_statistics.py

Lines changed: 37 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
from abc import ABC, abstractmethod
77
from enum import Enum
8-
from typing import Any, Literal, Optional, Union
8+
from typing import Any, Literal
99

1010
from pandas import Series
1111
from pydantic import BaseModel, ConfigDict, create_model, field_validator, model_validator
@@ -69,27 +69,27 @@ class GeneralColumnStatistics(BaseColumnStatistics):
6969
"""
7070

7171
column_name: str
72-
num_records: Union[int, MissingValue]
73-
num_null: Union[int, MissingValue]
74-
num_unique: Union[int, MissingValue]
72+
num_records: int | MissingValue
73+
num_null: int | MissingValue
74+
num_unique: int | MissingValue
7575
pyarrow_dtype: str
7676
simple_dtype: str
7777
column_type: Literal["general"] = "general"
7878

7979
@field_validator("num_null", "num_unique", "num_records", mode="before")
80-
def general_statistics_ensure_python_integers(cls, v: Union[int, MissingValue]) -> Union[int, MissingValue]:
80+
def general_statistics_ensure_python_integers(cls, v: int | MissingValue) -> int | MissingValue:
8181
return v if isinstance(v, MissingValue) else prepare_number_for_reporting(v, int)
8282

8383
@property
84-
def percent_null(self) -> Union[float, MissingValue]:
84+
def percent_null(self) -> float | MissingValue:
8585
return (
8686
self.num_null
8787
if self._is_missing_value(self.num_null)
8888
else prepare_number_for_reporting(100 * self.num_null / (self.num_records + EPSILON), float)
8989
)
9090

9191
@property
92-
def percent_unique(self) -> Union[float, MissingValue]:
92+
def percent_unique(self) -> float | MissingValue:
9393
return (
9494
self.num_unique
9595
if self._is_missing_value(self.num_unique)
@@ -108,7 +108,7 @@ def _general_display_row(self) -> dict[str, str]:
108108
def create_report_row_data(self) -> dict[str, str]:
109109
return self._general_display_row
110110

111-
def _is_missing_value(self, v: Union[float, int, MissingValue]) -> bool:
111+
def _is_missing_value(self, v: float | int | MissingValue) -> bool:
112112
return v in set(MissingValue)
113113

114114

@@ -128,12 +128,12 @@ class LLMTextColumnStatistics(GeneralColumnStatistics):
128128
column_type: Discriminator field, always "llm-text" for this statistics type.
129129
"""
130130

131-
output_tokens_mean: Union[float, MissingValue]
132-
output_tokens_median: Union[float, MissingValue]
133-
output_tokens_stddev: Union[float, MissingValue]
134-
input_tokens_mean: Union[float, MissingValue]
135-
input_tokens_median: Union[float, MissingValue]
136-
input_tokens_stddev: Union[float, MissingValue]
131+
output_tokens_mean: float | MissingValue
132+
output_tokens_median: float | MissingValue
133+
output_tokens_stddev: float | MissingValue
134+
input_tokens_mean: float | MissingValue
135+
input_tokens_median: float | MissingValue
136+
input_tokens_stddev: float | MissingValue
137137
column_type: Literal[DataDesignerColumnType.LLM_TEXT.value] = DataDesignerColumnType.LLM_TEXT.value
138138

139139
@field_validator(
@@ -145,7 +145,7 @@ class LLMTextColumnStatistics(GeneralColumnStatistics):
145145
"input_tokens_stddev",
146146
mode="before",
147147
)
148-
def llm_column_ensure_python_floats(cls, v: Union[float, int, MissingValue]) -> Union[float, int, MissingValue]:
148+
def llm_column_ensure_python_floats(cls, v: float | int | MissingValue) -> float | int | MissingValue:
149149
return v if isinstance(v, MissingValue) else prepare_number_for_reporting(v, float)
150150

151151
def create_report_row_data(self) -> dict[str, Any]:
@@ -225,7 +225,7 @@ class SamplerColumnStatistics(GeneralColumnStatistics):
225225

226226
sampler_type: SamplerType
227227
distribution_type: ColumnDistributionType
228-
distribution: Optional[Union[CategoricalDistribution, NumericalDistribution, MissingValue]]
228+
distribution: CategoricalDistribution | NumericalDistribution | MissingValue | None
229229
column_type: Literal[DataDesignerColumnType.SAMPLER.value] = DataDesignerColumnType.SAMPLER.value
230230

231231
def create_report_row_data(self) -> dict[str, str]:
@@ -273,15 +273,15 @@ class ValidationColumnStatistics(GeneralColumnStatistics):
273273
column_type: Discriminator field, always "validation" for this statistics type.
274274
"""
275275

276-
num_valid_records: Union[int, MissingValue]
276+
num_valid_records: int | MissingValue
277277
column_type: Literal[DataDesignerColumnType.VALIDATION.value] = DataDesignerColumnType.VALIDATION.value
278278

279279
@field_validator("num_valid_records", mode="before")
280-
def code_validation_column_ensure_python_integers(cls, v: Union[int, MissingValue]) -> Union[int, MissingValue]:
280+
def code_validation_column_ensure_python_integers(cls, v: int | MissingValue) -> int | MissingValue:
281281
return v if isinstance(v, MissingValue) else prepare_number_for_reporting(v, int)
282282

283283
@property
284-
def percent_valid(self) -> Union[float, MissingValue]:
284+
def percent_valid(self) -> float | MissingValue:
285285
return (
286286
self.num_valid_records
287287
if self._is_missing_value(self.num_valid_records)
@@ -303,7 +303,7 @@ class CategoricalHistogramData(BaseModel):
303303
counts: List of occurrence counts for each category.
304304
"""
305305

306-
categories: list[Union[float, int, str]]
306+
categories: list[float | int | str]
307307
counts: list[int]
308308

309309
@model_validator(mode="after")
@@ -328,12 +328,12 @@ class CategoricalDistribution(BaseModel):
328328
histogram: Complete frequency distribution showing all categories and their counts.
329329
"""
330330

331-
most_common_value: Union[str, int]
332-
least_common_value: Union[str, int]
331+
most_common_value: str | int
332+
least_common_value: str | int
333333
histogram: CategoricalHistogramData
334334

335335
@field_validator("most_common_value", "least_common_value", mode="before")
336-
def ensure_python_types(cls, v: Union[str, int]) -> Union[str, int]:
336+
def ensure_python_types(cls, v: str | int) -> str | int:
337337
return str(v) if not is_int(v) else prepare_number_for_reporting(v, int)
338338

339339
@classmethod
@@ -357,14 +357,14 @@ class NumericalDistribution(BaseModel):
357357
median: Median value of the distribution.
358358
"""
359359

360-
min: Union[float, int]
361-
max: Union[float, int]
360+
min: float | int
361+
max: float | int
362362
mean: float
363363
stddev: float
364364
median: float
365365

366366
@field_validator("min", "max", "mean", "stddev", "median", mode="before")
367-
def ensure_python_types(cls, v: Union[float, int]) -> Union[float, int]:
367+
def ensure_python_types(cls, v: float | int) -> float | int:
368368
return prepare_number_for_reporting(v, int if is_int(v) else float)
369369

370370
@classmethod
@@ -378,17 +378,17 @@ def from_series(cls, series: Series) -> Self:
378378
)
379379

380380

381-
ColumnStatisticsT: TypeAlias = Union[
382-
GeneralColumnStatistics,
383-
LLMTextColumnStatistics,
384-
LLMCodeColumnStatistics,
385-
LLMStructuredColumnStatistics,
386-
LLMJudgedColumnStatistics,
387-
SamplerColumnStatistics,
388-
SeedDatasetColumnStatistics,
389-
ValidationColumnStatistics,
390-
ExpressionColumnStatistics,
391-
]
381+
ColumnStatisticsT: TypeAlias = (
382+
GeneralColumnStatistics
383+
| LLMTextColumnStatistics
384+
| LLMCodeColumnStatistics
385+
| LLMStructuredColumnStatistics
386+
| LLMJudgedColumnStatistics
387+
| SamplerColumnStatistics
388+
| SeedDatasetColumnStatistics
389+
| ValidationColumnStatistics
390+
| ExpressionColumnStatistics
391+
)
392392

393393

394394
DEFAULT_COLUMN_STATISTICS_MAP = {

src/data_designer/config/analysis/dataset_profiler.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
from functools import cached_property
55
from pathlib import Path
6-
from typing import Annotated, Optional, Union
6+
from typing import Annotated
77

88
from pydantic import BaseModel, Field, field_validator
99

@@ -34,8 +34,8 @@ class DatasetProfilerResults(BaseModel):
3434
num_records: int
3535
target_num_records: int
3636
column_statistics: list[Annotated[ColumnStatisticsT, Field(discriminator="column_type")]] = Field(..., min_length=1)
37-
side_effect_column_names: Optional[list[str]] = None
38-
column_profiles: Optional[list[ColumnProfilerResultsT]] = None
37+
side_effect_column_names: list[str] | None = None
38+
column_profiles: list[ColumnProfilerResultsT] | None = None
3939

4040
@field_validator("num_records", "target_num_records", mode="before")
4141
def ensure_python_integers(cls, v: int) -> int:
@@ -61,8 +61,8 @@ def get_column_statistics_by_type(self, column_type: DataDesignerColumnType) ->
6161

6262
def to_report(
6363
self,
64-
save_path: Optional[Union[str, Path]] = None,
65-
include_sections: Optional[list[Union[ReportSection, DataDesignerColumnType]]] = None,
64+
save_path: str | Path | None = None,
65+
include_sections: list[ReportSection | DataDesignerColumnType] | None = None,
6666
) -> None:
6767
"""Generate and print an analysis report based on the dataset profiling results.
6868

src/data_designer/config/analysis/utils/reporting.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
from enum import Enum
77
from pathlib import Path
8-
from typing import TYPE_CHECKING, Optional, Union
8+
from typing import TYPE_CHECKING
99

1010
from rich.align import Align
1111
from rich.console import Console, Group
@@ -48,8 +48,8 @@ class ReportSection(str, Enum):
4848

4949
def generate_analysis_report(
5050
analysis: DatasetProfilerResults,
51-
save_path: Optional[Union[str, Path]] = None,
52-
include_sections: Optional[list[Union[ReportSection, DataDesignerColumnType]]] = None,
51+
save_path: str | Path | None = None,
52+
include_sections: list[ReportSection | DataDesignerColumnType] | None = None,
5353
) -> None:
5454
"""Generate an analysis report for dataset profiling results.
5555

src/data_designer/config/base.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from __future__ import annotations
55

66
from pathlib import Path
7-
from typing import Any, Optional, Union
7+
from typing import Any
88

99
import yaml
1010
from pydantic import BaseModel, ConfigDict
@@ -31,7 +31,7 @@ def to_dict(self) -> dict[str, Any]:
3131
"""
3232
return self.model_dump(mode="json")
3333

34-
def to_yaml(self, path: Optional[Union[str, Path]] = None, *, indent: Optional[int] = 2, **kwargs) -> Optional[str]:
34+
def to_yaml(self, path: str | Path | None = None, *, indent: int | None = 2, **kwargs) -> str | None:
3535
"""Convert the configuration to a YAML string or file.
3636
3737
Args:
@@ -49,7 +49,7 @@ def to_yaml(self, path: Optional[Union[str, Path]] = None, *, indent: Optional[i
4949
with open(path, "w") as f:
5050
f.write(yaml_str)
5151

52-
def to_json(self, path: Optional[Union[str, Path]] = None, *, indent: Optional[int] = 2, **kwargs) -> Optional[str]:
52+
def to_json(self, path: str | Path | None = None, *, indent: int | None = 2, **kwargs) -> str | None:
5353
"""Convert the configuration to a JSON string or file.
5454
5555
Args:

src/data_designer/config/column_configs.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
# SPDX-License-Identifier: Apache-2.0
33

44
from abc import ABC
5-
from typing import Annotated, Literal, Optional, Type, Union
5+
from typing import Annotated, Literal
66

77
from pydantic import BaseModel, Discriminator, Field, model_validator
88
from typing_extensions import Self
@@ -91,7 +91,7 @@ class SamplerColumnConfig(SingleColumnConfig):
9191
sampler_type: SamplerType
9292
params: Annotated[SamplerParamsT, Discriminator("sampler_type")]
9393
conditional_params: dict[str, Annotated[SamplerParamsT, Discriminator("sampler_type")]] = {}
94-
convert_to: Optional[str] = None
94+
convert_to: str | None = None
9595
column_type: Literal["sampler"] = "sampler"
9696

9797
@model_validator(mode="before")
@@ -146,8 +146,8 @@ class LLMTextColumnConfig(SingleColumnConfig):
146146

147147
prompt: str
148148
model_alias: str
149-
system_prompt: Optional[str] = None
150-
multi_modal_context: Optional[list[ImageContext]] = None
149+
system_prompt: str | None = None
150+
multi_modal_context: list[ImageContext] | None = None
151151
column_type: Literal["llm-text"] = "llm-text"
152152

153153
@property
@@ -222,7 +222,7 @@ class LLMStructuredColumnConfig(LLMTextColumnConfig):
222222
column_type: Discriminator field, always "llm-structured" for this configuration type.
223223
"""
224224

225-
output_format: Union[dict, Type[BaseModel]]
225+
output_format: dict | type[BaseModel]
226226
column_type: Literal["llm-structured"] = "llm-structured"
227227

228228
@model_validator(mode="after")
@@ -255,7 +255,7 @@ class Score(ConfigBase):
255255

256256
name: str = Field(..., description="A clear name for this score.")
257257
description: str = Field(..., description="An informative and detailed assessment guide for using this score.")
258-
options: dict[Union[int, str], str] = Field(..., description="Score options in the format of {score: description}.")
258+
options: dict[int | str, str] = Field(..., description="Score options in the format of {score: description}.")
259259

260260

261261
class LLMJudgeColumnConfig(LLMTextColumnConfig):

0 commit comments

Comments
 (0)