55
66from abc import ABC , abstractmethod
77from enum import Enum
8- from typing import Any , Literal , Optional , Union
8+ from typing import Any , Literal
99
1010from pandas import Series
1111from pydantic import BaseModel , ConfigDict , create_model , field_validator , model_validator
@@ -69,27 +69,27 @@ class GeneralColumnStatistics(BaseColumnStatistics):
6969 """
7070
7171 column_name : str
72- num_records : Union [ int , MissingValue ]
73- num_null : Union [ int , MissingValue ]
74- num_unique : Union [ int , MissingValue ]
72+ num_records : int | MissingValue
73+ num_null : int | MissingValue
74+ num_unique : int | MissingValue
7575 pyarrow_dtype : str
7676 simple_dtype : str
7777 column_type : Literal ["general" ] = "general"
7878
7979 @field_validator ("num_null" , "num_unique" , "num_records" , mode = "before" )
80- def general_statistics_ensure_python_integers (cls , v : Union [ int , MissingValue ] ) -> Union [ int , MissingValue ] :
80+ def general_statistics_ensure_python_integers (cls , v : int | MissingValue ) -> int | MissingValue :
8181 return v if isinstance (v , MissingValue ) else prepare_number_for_reporting (v , int )
8282
8383 @property
84- def percent_null (self ) -> Union [ float , MissingValue ] :
84+ def percent_null (self ) -> float | MissingValue :
8585 return (
8686 self .num_null
8787 if self ._is_missing_value (self .num_null )
8888 else prepare_number_for_reporting (100 * self .num_null / (self .num_records + EPSILON ), float )
8989 )
9090
9191 @property
92- def percent_unique (self ) -> Union [ float , MissingValue ] :
92+ def percent_unique (self ) -> float | MissingValue :
9393 return (
9494 self .num_unique
9595 if self ._is_missing_value (self .num_unique )
@@ -108,7 +108,7 @@ def _general_display_row(self) -> dict[str, str]:
108108 def create_report_row_data (self ) -> dict [str , str ]:
109109 return self ._general_display_row
110110
111- def _is_missing_value (self , v : Union [ float , int , MissingValue ] ) -> bool :
111+ def _is_missing_value (self , v : float | int | MissingValue ) -> bool :
112112 return v in set (MissingValue )
113113
114114
@@ -128,12 +128,12 @@ class LLMTextColumnStatistics(GeneralColumnStatistics):
128128 column_type: Discriminator field, always "llm-text" for this statistics type.
129129 """
130130
131- output_tokens_mean : Union [ float , MissingValue ]
132- output_tokens_median : Union [ float , MissingValue ]
133- output_tokens_stddev : Union [ float , MissingValue ]
134- input_tokens_mean : Union [ float , MissingValue ]
135- input_tokens_median : Union [ float , MissingValue ]
136- input_tokens_stddev : Union [ float , MissingValue ]
131+ output_tokens_mean : float | MissingValue
132+ output_tokens_median : float | MissingValue
133+ output_tokens_stddev : float | MissingValue
134+ input_tokens_mean : float | MissingValue
135+ input_tokens_median : float | MissingValue
136+ input_tokens_stddev : float | MissingValue
137137 column_type : Literal [DataDesignerColumnType .LLM_TEXT .value ] = DataDesignerColumnType .LLM_TEXT .value
138138
139139 @field_validator (
@@ -145,7 +145,7 @@ class LLMTextColumnStatistics(GeneralColumnStatistics):
145145 "input_tokens_stddev" ,
146146 mode = "before" ,
147147 )
148- def llm_column_ensure_python_floats (cls , v : Union [ float , int , MissingValue ] ) -> Union [ float , int , MissingValue ] :
148+ def llm_column_ensure_python_floats (cls , v : float | int | MissingValue ) -> float | int | MissingValue :
149149 return v if isinstance (v , MissingValue ) else prepare_number_for_reporting (v , float )
150150
151151 def create_report_row_data (self ) -> dict [str , Any ]:
@@ -225,7 +225,7 @@ class SamplerColumnStatistics(GeneralColumnStatistics):
225225
226226 sampler_type : SamplerType
227227 distribution_type : ColumnDistributionType
228- distribution : Optional [ Union [ CategoricalDistribution , NumericalDistribution , MissingValue ]]
228+ distribution : CategoricalDistribution | NumericalDistribution | MissingValue | None
229229 column_type : Literal [DataDesignerColumnType .SAMPLER .value ] = DataDesignerColumnType .SAMPLER .value
230230
231231 def create_report_row_data (self ) -> dict [str , str ]:
@@ -273,15 +273,15 @@ class ValidationColumnStatistics(GeneralColumnStatistics):
273273 column_type: Discriminator field, always "validation" for this statistics type.
274274 """
275275
276- num_valid_records : Union [ int , MissingValue ]
276+ num_valid_records : int | MissingValue
277277 column_type : Literal [DataDesignerColumnType .VALIDATION .value ] = DataDesignerColumnType .VALIDATION .value
278278
279279 @field_validator ("num_valid_records" , mode = "before" )
280- def code_validation_column_ensure_python_integers (cls , v : Union [ int , MissingValue ] ) -> Union [ int , MissingValue ] :
280+ def code_validation_column_ensure_python_integers (cls , v : int | MissingValue ) -> int | MissingValue :
281281 return v if isinstance (v , MissingValue ) else prepare_number_for_reporting (v , int )
282282
283283 @property
284- def percent_valid (self ) -> Union [ float , MissingValue ] :
284+ def percent_valid (self ) -> float | MissingValue :
285285 return (
286286 self .num_valid_records
287287 if self ._is_missing_value (self .num_valid_records )
@@ -303,7 +303,7 @@ class CategoricalHistogramData(BaseModel):
303303 counts: List of occurrence counts for each category.
304304 """
305305
306- categories : list [Union [ float , int , str ] ]
306+ categories : list [float | int | str ]
307307 counts : list [int ]
308308
309309 @model_validator (mode = "after" )
@@ -328,12 +328,12 @@ class CategoricalDistribution(BaseModel):
328328 histogram: Complete frequency distribution showing all categories and their counts.
329329 """
330330
331- most_common_value : Union [ str , int ]
332- least_common_value : Union [ str , int ]
331+ most_common_value : str | int
332+ least_common_value : str | int
333333 histogram : CategoricalHistogramData
334334
335335 @field_validator ("most_common_value" , "least_common_value" , mode = "before" )
336- def ensure_python_types (cls , v : Union [ str , int ] ) -> Union [ str , int ] :
336+ def ensure_python_types (cls , v : str | int ) -> str | int :
337337 return str (v ) if not is_int (v ) else prepare_number_for_reporting (v , int )
338338
339339 @classmethod
@@ -357,14 +357,14 @@ class NumericalDistribution(BaseModel):
357357 median: Median value of the distribution.
358358 """
359359
360- min : Union [ float , int ]
361- max : Union [ float , int ]
360+ min : float | int
361+ max : float | int
362362 mean : float
363363 stddev : float
364364 median : float
365365
366366 @field_validator ("min" , "max" , "mean" , "stddev" , "median" , mode = "before" )
367- def ensure_python_types (cls , v : Union [ float , int ] ) -> Union [ float , int ] :
367+ def ensure_python_types (cls , v : float | int ) -> float | int :
368368 return prepare_number_for_reporting (v , int if is_int (v ) else float )
369369
370370 @classmethod
@@ -378,17 +378,17 @@ def from_series(cls, series: Series) -> Self:
378378 )
379379
380380
381- ColumnStatisticsT : TypeAlias = Union [
382- GeneralColumnStatistics ,
383- LLMTextColumnStatistics ,
384- LLMCodeColumnStatistics ,
385- LLMStructuredColumnStatistics ,
386- LLMJudgedColumnStatistics ,
387- SamplerColumnStatistics ,
388- SeedDatasetColumnStatistics ,
389- ValidationColumnStatistics ,
390- ExpressionColumnStatistics ,
391- ]
381+ ColumnStatisticsT : TypeAlias = (
382+ GeneralColumnStatistics
383+ | LLMTextColumnStatistics
384+ | LLMCodeColumnStatistics
385+ | LLMStructuredColumnStatistics
386+ | LLMJudgedColumnStatistics
387+ | SamplerColumnStatistics
388+ | SeedDatasetColumnStatistics
389+ | ValidationColumnStatistics
390+ | ExpressionColumnStatistics
391+ )
392392
393393
394394DEFAULT_COLUMN_STATISTICS_MAP = {
0 commit comments