Skip to content

Commit 7c88230

Browse files
authored
chore: porting nmp (#11)
* porting nmp
* remove unused seed dataset datastore tests
* remove load_dataset
1 parent 62d80c0 commit 7c88230

File tree

103 files changed

+884
-937
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

103 files changed

+884
-937
lines changed

src/data_designer/config/analysis/__init__.py

Lines changed: 0 additions & 3 deletions
This file was deleted.

src/data_designer/config/analysis/column_profilers.py

Lines changed: 8 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -3,11 +3,12 @@
33

44
from abc import ABC
55
from enum import Enum
6-
from typing import TypeAlias
6+
from typing import Optional, Union
77

88
from pydantic import BaseModel, Field
99
from rich.panel import Panel
1010
from rich.table import Column, Table
11+
from typing_extensions import TypeAlias
1112

1213
from ..base import ConfigBase
1314
from ..utils.visualization import ColorPalette
@@ -37,20 +38,20 @@ def create_report_section(self) -> Panel:
3738

3839
class JudgeScoreProfilerConfig(ConfigBase):
3940
model_alias: str
40-
summary_score_sample_size: int | None = Field(default=20, ge=1)
41+
summary_score_sample_size: Optional[int] = Field(default=20, ge=1)
4142

4243

4344
class JudgeScoreSample(BaseModel):
44-
score: int | str
45+
score: Union[int, str]
4546
reasoning: str
4647

4748

4849
class JudgeScoreDistributions(BaseModel):
49-
scores: dict[str, list[int | str]]
50+
scores: dict[str, list[Union[int, str]]]
5051
reasoning: dict[str, list[str]]
5152
distribution_types: dict[str, ColumnDistributionType]
52-
distributions: dict[str, CategoricalDistribution | NumericalDistribution | MissingValue]
53-
histograms: dict[str, CategoricalHistogramData | MissingValue]
53+
distributions: dict[str, Union[CategoricalDistribution, NumericalDistribution, MissingValue]]
54+
histograms: dict[str, Union[CategoricalHistogramData, MissingValue]]
5455

5556

5657
class JudgeScoreSummary(BaseModel):
@@ -62,7 +63,7 @@ class JudgeScoreSummary(BaseModel):
6263
class JudgeScoreProfilerResults(ColumnProfilerResults):
6364
column_name: str
6465
summaries: dict[str, JudgeScoreSummary]
65-
score_distributions: JudgeScoreDistributions | MissingValue
66+
score_distributions: Union[JudgeScoreDistributions, MissingValue]
6667

6768
def create_report_section(self) -> Panel:
6869
layout = Table.grid(Column(), expand=True, padding=(2, 0))

src/data_designer/config/analysis/column_statistics.py

Lines changed: 39 additions & 37 deletions
Original file line number | Diff line number | Diff line change
@@ -5,11 +5,11 @@
55

66
from abc import ABC, abstractmethod
77
from enum import Enum
8-
from typing import Annotated, Any, Literal, TypeAlias
8+
from typing import Annotated, Any, Literal, Optional, Union
99

1010
from pandas import Series
1111
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
12-
from typing_extensions import Self
12+
from typing_extensions import Self, TypeAlias
1313

1414
from ..columns import DataDesignerColumnType
1515
from ..sampler_params import SamplerType
@@ -39,27 +39,27 @@ def create_report_row_data(self) -> dict[str, str]: ...
3939

4040
class GeneralColumnStatistics(BaseColumnStatistics):
4141
column_name: str
42-
num_records: int | MissingValue
43-
num_null: int | MissingValue
44-
num_unique: int | MissingValue
42+
num_records: Union[int, MissingValue]
43+
num_null: Union[int, MissingValue]
44+
num_unique: Union[int, MissingValue]
4545
pyarrow_dtype: str
4646
simple_dtype: str
4747
column_type: Literal["general"] = "general"
4848

4949
@field_validator("num_null", "num_unique", "num_records", mode="before")
50-
def general_statistics_ensure_python_integers(cls, v: int | MissingValue) -> int | MissingValue:
50+
def general_statistics_ensure_python_integers(cls, v: Union[int, MissingValue]) -> Union[int, MissingValue]:
5151
return v if isinstance(v, MissingValue) else prepare_number_for_reporting(v, int)
5252

5353
@property
54-
def percent_null(self) -> float | MissingValue:
54+
def percent_null(self) -> Union[float, MissingValue]:
5555
return (
5656
self.num_null
5757
if self._is_missing_value(self.num_null)
5858
else prepare_number_for_reporting(100 * self.num_null / (self.num_records + EPSILON), float)
5959
)
6060

6161
@property
62-
def percent_unique(self) -> float | MissingValue:
62+
def percent_unique(self) -> Union[float, MissingValue]:
6363
return (
6464
self.num_unique
6565
if self._is_missing_value(self.num_unique)
@@ -78,17 +78,17 @@ def _general_display_row(self) -> dict[str, str]:
7878
def create_report_row_data(self) -> dict[str, str]:
7979
return self._general_display_row
8080

81-
def _is_missing_value(self, v: float | int | MissingValue) -> bool:
81+
def _is_missing_value(self, v: Union[float, int, MissingValue]) -> bool:
8282
return v in set(MissingValue)
8383

8484

8585
class LLMTextColumnStatistics(GeneralColumnStatistics):
86-
completion_tokens_mean: float | MissingValue
87-
completion_tokens_median: float | MissingValue
88-
completion_tokens_stddev: float | MissingValue
89-
prompt_tokens_mean: float | MissingValue
90-
prompt_tokens_median: float | MissingValue
91-
prompt_tokens_stddev: float | MissingValue
86+
completion_tokens_mean: Union[float, MissingValue]
87+
completion_tokens_median: Union[float, MissingValue]
88+
completion_tokens_stddev: Union[float, MissingValue]
89+
prompt_tokens_mean: Union[float, MissingValue]
90+
prompt_tokens_median: Union[float, MissingValue]
91+
prompt_tokens_stddev: Union[float, MissingValue]
9292
column_type: Literal[DataDesignerColumnType.LLM_TEXT.value] = DataDesignerColumnType.LLM_TEXT.value
9393

9494
@field_validator(
@@ -100,7 +100,7 @@ class LLMTextColumnStatistics(GeneralColumnStatistics):
100100
"prompt_tokens_stddev",
101101
mode="before",
102102
)
103-
def llm_column_ensure_python_floats(cls, v: float | int | MissingValue) -> float | int | MissingValue:
103+
def llm_column_ensure_python_floats(cls, v: Union[float, int, MissingValue]) -> Union[float, int, MissingValue]:
104104
return v if isinstance(v, MissingValue) else prepare_number_for_reporting(v, float)
105105

106106
def create_report_row_data(self) -> dict[str, Any]:
@@ -136,7 +136,7 @@ class LLMJudgedColumnStatistics(LLMTextColumnStatistics):
136136
class SamplerColumnStatistics(GeneralColumnStatistics):
137137
sampler_type: SamplerType
138138
distribution_type: ColumnDistributionType
139-
distribution: CategoricalDistribution | NumericalDistribution | MissingValue | None
139+
distribution: Optional[Union[CategoricalDistribution, NumericalDistribution, MissingValue]]
140140
column_type: Literal[DataDesignerColumnType.SAMPLER.value] = DataDesignerColumnType.SAMPLER.value
141141

142142
def create_report_row_data(self) -> dict[str, str]:
@@ -148,7 +148,7 @@ def create_report_row_data(self) -> dict[str, str]:
148148

149149
class SeedDatasetColumnStatistics(GeneralColumnStatistics):
150150
distribution_type: ColumnDistributionType
151-
distribution: CategoricalDistribution | NumericalDistribution | MissingValue | None
151+
distribution: Optional[Union[CategoricalDistribution, NumericalDistribution, MissingValue]]
152152
column_type: Literal[DataDesignerColumnType.SEED_DATASET.value] = DataDesignerColumnType.SEED_DATASET.value
153153

154154
def create_report_row_data(self) -> dict[str, str]:
@@ -160,15 +160,15 @@ class ExpressionColumnStatistics(GeneralColumnStatistics):
160160

161161

162162
class ValidationColumnStatistics(GeneralColumnStatistics):
163-
num_valid_records: int | MissingValue
163+
num_valid_records: Union[int, MissingValue]
164164
column_type: Literal[DataDesignerColumnType.VALIDATION.value] = DataDesignerColumnType.VALIDATION.value
165165

166166
@field_validator("num_valid_records", mode="before")
167-
def code_validation_column_ensure_python_integers(cls, v: int | MissingValue) -> int | MissingValue:
167+
def code_validation_column_ensure_python_integers(cls, v: Union[int, MissingValue]) -> Union[int, MissingValue]:
168168
return v if isinstance(v, MissingValue) else prepare_number_for_reporting(v, int)
169169

170170
@property
171-
def percent_valid(self) -> float | MissingValue:
171+
def percent_valid(self) -> Union[float, MissingValue]:
172172
return (
173173
self.num_valid_records
174174
if self._is_missing_value(self.num_valid_records)
@@ -181,7 +181,7 @@ def create_report_row_data(self) -> dict[str, str]:
181181

182182

183183
class CategoricalHistogramData(BaseModel):
184-
categories: list[float | int | str]
184+
categories: list[Union[float, int, str]]
185185
counts: list[int]
186186

187187
@model_validator(mode="after")
@@ -198,12 +198,12 @@ def from_series(cls, series: Series) -> Self:
198198

199199

200200
class CategoricalDistribution(BaseModel):
201-
most_common_value: str | int
202-
least_common_value: str | int
201+
most_common_value: Union[str, int]
202+
least_common_value: Union[str, int]
203203
histogram: CategoricalHistogramData
204204

205205
@field_validator("most_common_value", "least_common_value", mode="before")
206-
def ensure_python_types(cls, v: str | int) -> str | int:
206+
def ensure_python_types(cls, v: Union[str, int]) -> Union[str, int]:
207207
return str(v) if not is_int(v) else prepare_number_for_reporting(v, int)
208208

209209
@classmethod
@@ -217,14 +217,14 @@ def from_series(cls, series: Series) -> Self:
217217

218218

219219
class NumericalDistribution(BaseModel):
220-
min: float | int
221-
max: float | int
220+
min: Union[float, int]
221+
max: Union[float, int]
222222
mean: float
223223
stddev: float
224224
median: float
225225

226226
@field_validator("min", "max", "mean", "stddev", "median", mode="before")
227-
def ensure_python_types(cls, v: float | int) -> float | int:
227+
def ensure_python_types(cls, v: Union[float, int]) -> Union[float, int]:
228228
return prepare_number_for_reporting(v, int if is_int(v) else float)
229229

230230
@classmethod
@@ -239,14 +239,16 @@ def from_series(cls, series: Series) -> Self:
239239

240240

241241
ColumnStatisticsT: TypeAlias = Annotated[
242-
GeneralColumnStatistics
243-
| LLMTextColumnStatistics
244-
| LLMCodeColumnStatistics
245-
| LLMStructuredColumnStatistics
246-
| LLMJudgedColumnStatistics
247-
| SamplerColumnStatistics
248-
| SeedDatasetColumnStatistics
249-
| ValidationColumnStatistics
250-
| ExpressionColumnStatistics,
242+
Union[
243+
GeneralColumnStatistics,
244+
LLMTextColumnStatistics,
245+
LLMCodeColumnStatistics,
246+
LLMStructuredColumnStatistics,
247+
LLMJudgedColumnStatistics,
248+
SamplerColumnStatistics,
249+
SeedDatasetColumnStatistics,
250+
ValidationColumnStatistics,
251+
ExpressionColumnStatistics,
252+
],
251253
Field(discriminator="column_type"),
252254
]

src/data_designer/config/analysis/dataset_profiler.py

Lines changed: 5 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -3,6 +3,7 @@
33

44
from functools import cached_property
55
from pathlib import Path
6+
from typing import Optional, Union
67

78
from pydantic import BaseModel, Field, field_validator
89

@@ -18,8 +19,8 @@ class DatasetProfilerResults(BaseModel):
1819
num_records: int
1920
target_num_records: int
2021
column_statistics: list[ColumnStatisticsT] = Field(..., min_length=1)
21-
side_effect_column_names: list[str] | None = None
22-
column_profiles: list[ColumnProfilerResultsT] | None = None
22+
side_effect_column_names: Optional[list[str]] = None
23+
column_profiles: Optional[list[ColumnProfilerResultsT]] = None
2324

2425
@field_validator("num_records", "target_num_records", mode="before")
2526
def ensure_python_integers(cls, v: int) -> int:
@@ -42,8 +43,8 @@ def get_column_statistics_by_type(self, column_type: DataDesignerColumnType) ->
4243

4344
def to_report(
4445
self,
45-
save_path: str | Path | None = None,
46-
include_sections: list[ReportSection | DataDesignerColumnType] | None = None,
46+
save_path: Optional[Union[str, Path]] = None,
47+
include_sections: Optional[list[Union[ReportSection, DataDesignerColumnType]]] = None,
4748
) -> None:
4849
"""Generate and print an analysis report based on the dataset profiling results.
4950

src/data_designer/config/analysis/utils/__init__.py

Lines changed: 0 additions & 3 deletions
This file was deleted.

src/data_designer/config/analysis/utils/reporting.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,7 @@
55

66
from enum import Enum
77
from pathlib import Path
8-
from typing import TYPE_CHECKING
8+
from typing import TYPE_CHECKING, Optional, Union
99

1010
from rich.align import Align
1111
from rich.console import Console, Group
@@ -49,8 +49,8 @@ class ReportSection(str, Enum):
4949

5050
def generate_analysis_report(
5151
analysis: DatasetProfilerResults,
52-
save_path: str | Path | None = None,
53-
include_sections: list[ReportSection | DataDesignerColumnType] | None = None,
52+
save_path: Optional[Union[str, Path]] = None,
53+
include_sections: Optional[list[Union[ReportSection, DataDesignerColumnType]]] = None,
5454
) -> None:
5555
"""Generate an analysis report for dataset profiling results.
5656
@@ -166,7 +166,7 @@ def create_judge_score_summary_table(
166166
layout = Table.grid(Column(), Column(), expand=True, padding=(0, 2))
167167

168168
histogram_table = create_rich_histogram_table(
169-
{str(s): c for s, c in zip(histogram.categories, histogram.counts, strict=False)},
169+
{str(s): c for s, c in zip(histogram.categories, histogram.counts)},
170170
("score", "count"),
171171
name_style=HIST_NAME_STYLE,
172172
value_style=HIST_VALUE_STYLE,

src/data_designer/config/base.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,7 @@
55

66
from abc import ABC, abstractmethod
77
from pathlib import Path
8-
from typing import TYPE_CHECKING, Any, Generic, Protocol, TypeVar
8+
from typing import TYPE_CHECKING, Any, Generic, Optional, Protocol, TypeVar, Union
99

1010
import pandas as pd
1111
from pydantic import BaseModel, ConfigDict
@@ -14,9 +14,9 @@
1414
from .utils.io_helpers import serialize_data
1515

1616
if TYPE_CHECKING:
17-
from ..client.results.preview import PreviewResults
1817
from .analysis.dataset_profiler import DatasetProfilerResults
1918
from .config_builder import DataDesignerConfigBuilder
19+
from .preview_results import PreviewResults
2020

2121
DEFAULT_NUM_RECORDS = 10
2222

@@ -66,7 +66,7 @@ def to_dict(self) -> dict[str, Any]:
6666
"""
6767
return self.model_dump(mode="json")
6868

69-
def to_yaml(self, path: str | Path | None = None, *, indent: int | None = 2, **kwargs) -> str | None:
69+
def to_yaml(self, path: Optional[Union[str, Path]] = None, *, indent: Optional[int] = 2, **kwargs) -> Optional[str]:
7070
"""Convert the configuration to a YAML string or file.
7171
7272
Args:
@@ -84,7 +84,7 @@ def to_yaml(self, path: str | Path | None = None, *, indent: int | None = 2, **k
8484
with open(path, "w") as f:
8585
f.write(yaml_str)
8686

87-
def to_json(self, path: str | Path | None = None, *, indent: int | None = 2, **kwargs) -> str | None:
87+
def to_json(self, path: Optional[Union[str, Path]] = None, *, indent: Optional[int] = 2, **kwargs) -> Optional[str]:
8888
"""Convert the configuration to a JSON string or file.
8989
9090
Args:

0 commit comments

Comments
 (0)