55
66from abc import ABC , abstractmethod
77from enum import Enum
8- from typing import Annotated , Any , Literal , TypeAlias
8+ from typing import Annotated , Any , Literal , Optional , Union
99
1010from pandas import Series
1111from pydantic import BaseModel , ConfigDict , Field , field_validator , model_validator
12- from typing_extensions import Self
12+ from typing_extensions import Self , TypeAlias
1313
1414from ..columns import DataDesignerColumnType
1515from ..sampler_params import SamplerType
@@ -39,27 +39,27 @@ def create_report_row_data(self) -> dict[str, str]: ...
3939
4040class GeneralColumnStatistics (BaseColumnStatistics ):
4141 column_name : str
42- num_records : int | MissingValue
43- num_null : int | MissingValue
44- num_unique : int | MissingValue
42+ num_records : Union [ int , MissingValue ]
43+ num_null : Union [ int , MissingValue ]
44+ num_unique : Union [ int , MissingValue ]
4545 pyarrow_dtype : str
4646 simple_dtype : str
4747 column_type : Literal ["general" ] = "general"
4848
4949 @field_validator ("num_null" , "num_unique" , "num_records" , mode = "before" )
50- def general_statistics_ensure_python_integers (cls , v : int | MissingValue ) -> int | MissingValue :
50+ def general_statistics_ensure_python_integers (cls , v : Union [ int , MissingValue ] ) -> Union [ int , MissingValue ] :
5151 return v if isinstance (v , MissingValue ) else prepare_number_for_reporting (v , int )
5252
5353 @property
54- def percent_null (self ) -> float | MissingValue :
54+ def percent_null (self ) -> Union [ float , MissingValue ] :
5555 return (
5656 self .num_null
5757 if self ._is_missing_value (self .num_null )
5858 else prepare_number_for_reporting (100 * self .num_null / (self .num_records + EPSILON ), float )
5959 )
6060
6161 @property
62- def percent_unique (self ) -> float | MissingValue :
62+ def percent_unique (self ) -> Union [ float , MissingValue ] :
6363 return (
6464 self .num_unique
6565 if self ._is_missing_value (self .num_unique )
@@ -78,17 +78,17 @@ def _general_display_row(self) -> dict[str, str]:
7878 def create_report_row_data (self ) -> dict [str , str ]:
7979 return self ._general_display_row
8080
81- def _is_missing_value (self , v : float | int | MissingValue ) -> bool :
81+ def _is_missing_value (self , v : Union [ float , int , MissingValue ] ) -> bool :
8282 return v in set (MissingValue )
8383
8484
8585class LLMTextColumnStatistics (GeneralColumnStatistics ):
86- completion_tokens_mean : float | MissingValue
87- completion_tokens_median : float | MissingValue
88- completion_tokens_stddev : float | MissingValue
89- prompt_tokens_mean : float | MissingValue
90- prompt_tokens_median : float | MissingValue
91- prompt_tokens_stddev : float | MissingValue
86+ completion_tokens_mean : Union [ float , MissingValue ]
87+ completion_tokens_median : Union [ float , MissingValue ]
88+ completion_tokens_stddev : Union [ float , MissingValue ]
89+ prompt_tokens_mean : Union [ float , MissingValue ]
90+ prompt_tokens_median : Union [ float , MissingValue ]
91+ prompt_tokens_stddev : Union [ float , MissingValue ]
9292 column_type : Literal [DataDesignerColumnType .LLM_TEXT .value ] = DataDesignerColumnType .LLM_TEXT .value
9393
9494 @field_validator (
@@ -100,7 +100,7 @@ class LLMTextColumnStatistics(GeneralColumnStatistics):
100100 "prompt_tokens_stddev" ,
101101 mode = "before" ,
102102 )
103- def llm_column_ensure_python_floats (cls , v : float | int | MissingValue ) -> float | int | MissingValue :
103+ def llm_column_ensure_python_floats (cls , v : Union [ float , int , MissingValue ] ) -> Union [ float , int , MissingValue ] :
104104 return v if isinstance (v , MissingValue ) else prepare_number_for_reporting (v , float )
105105
106106 def create_report_row_data (self ) -> dict [str , Any ]:
@@ -136,7 +136,7 @@ class LLMJudgedColumnStatistics(LLMTextColumnStatistics):
136136class SamplerColumnStatistics (GeneralColumnStatistics ):
137137 sampler_type : SamplerType
138138 distribution_type : ColumnDistributionType
139- distribution : CategoricalDistribution | NumericalDistribution | MissingValue | None
139+ distribution : Optional [ Union [ CategoricalDistribution , NumericalDistribution , MissingValue ]]
140140 column_type : Literal [DataDesignerColumnType .SAMPLER .value ] = DataDesignerColumnType .SAMPLER .value
141141
142142 def create_report_row_data (self ) -> dict [str , str ]:
@@ -148,7 +148,7 @@ def create_report_row_data(self) -> dict[str, str]:
148148
149149class SeedDatasetColumnStatistics (GeneralColumnStatistics ):
150150 distribution_type : ColumnDistributionType
151- distribution : CategoricalDistribution | NumericalDistribution | MissingValue | None
151+ distribution : Optional [ Union [ CategoricalDistribution , NumericalDistribution , MissingValue ]]
152152 column_type : Literal [DataDesignerColumnType .SEED_DATASET .value ] = DataDesignerColumnType .SEED_DATASET .value
153153
154154 def create_report_row_data (self ) -> dict [str , str ]:
@@ -160,15 +160,15 @@ class ExpressionColumnStatistics(GeneralColumnStatistics):
160160
161161
162162class ValidationColumnStatistics (GeneralColumnStatistics ):
163- num_valid_records : int | MissingValue
163+ num_valid_records : Union [ int , MissingValue ]
164164 column_type : Literal [DataDesignerColumnType .VALIDATION .value ] = DataDesignerColumnType .VALIDATION .value
165165
166166 @field_validator ("num_valid_records" , mode = "before" )
167- def code_validation_column_ensure_python_integers (cls , v : int | MissingValue ) -> int | MissingValue :
167+ def code_validation_column_ensure_python_integers (cls , v : Union [ int , MissingValue ] ) -> Union [ int , MissingValue ] :
168168 return v if isinstance (v , MissingValue ) else prepare_number_for_reporting (v , int )
169169
170170 @property
171- def percent_valid (self ) -> float | MissingValue :
171+ def percent_valid (self ) -> Union [ float , MissingValue ] :
172172 return (
173173 self .num_valid_records
174174 if self ._is_missing_value (self .num_valid_records )
@@ -181,7 +181,7 @@ def create_report_row_data(self) -> dict[str, str]:
181181
182182
183183class CategoricalHistogramData (BaseModel ):
184- categories : list [float | int | str ]
184+ categories : list [Union [ float , int , str ] ]
185185 counts : list [int ]
186186
187187 @model_validator (mode = "after" )
@@ -198,12 +198,12 @@ def from_series(cls, series: Series) -> Self:
198198
199199
200200class CategoricalDistribution (BaseModel ):
201- most_common_value : str | int
202- least_common_value : str | int
201+ most_common_value : Union [ str , int ]
202+ least_common_value : Union [ str , int ]
203203 histogram : CategoricalHistogramData
204204
205205 @field_validator ("most_common_value" , "least_common_value" , mode = "before" )
206- def ensure_python_types (cls , v : str | int ) -> str | int :
206+ def ensure_python_types (cls , v : Union [ str , int ] ) -> Union [ str , int ] :
207207 return str (v ) if not is_int (v ) else prepare_number_for_reporting (v , int )
208208
209209 @classmethod
@@ -217,14 +217,14 @@ def from_series(cls, series: Series) -> Self:
217217
218218
219219class NumericalDistribution (BaseModel ):
220- min : float | int
221- max : float | int
220+ min : Union [ float , int ]
221+ max : Union [ float , int ]
222222 mean : float
223223 stddev : float
224224 median : float
225225
226226 @field_validator ("min" , "max" , "mean" , "stddev" , "median" , mode = "before" )
227- def ensure_python_types (cls , v : float | int ) -> float | int :
227+ def ensure_python_types (cls , v : Union [ float , int ] ) -> Union [ float , int ] :
228228 return prepare_number_for_reporting (v , int if is_int (v ) else float )
229229
230230 @classmethod
@@ -239,14 +239,16 @@ def from_series(cls, series: Series) -> Self:
239239
240240
241241ColumnStatisticsT : TypeAlias = Annotated [
242- GeneralColumnStatistics
243- | LLMTextColumnStatistics
244- | LLMCodeColumnStatistics
245- | LLMStructuredColumnStatistics
246- | LLMJudgedColumnStatistics
247- | SamplerColumnStatistics
248- | SeedDatasetColumnStatistics
249- | ValidationColumnStatistics
250- | ExpressionColumnStatistics ,
242+ Union [
243+ GeneralColumnStatistics ,
244+ LLMTextColumnStatistics ,
245+ LLMCodeColumnStatistics ,
246+ LLMStructuredColumnStatistics ,
247+ LLMJudgedColumnStatistics ,
248+ SamplerColumnStatistics ,
249+ SeedDatasetColumnStatistics ,
250+ ValidationColumnStatistics ,
251+ ExpressionColumnStatistics ,
252+ ],
251253 Field (discriminator = "column_type" ),
252254]
0 commit comments