1818 MissingValue ,
1919 NumericalDistribution ,
2020)
21- from data_designer .config .column_configs import (
22- LLMTextColumnConfig ,
23- SingleColumnConfig ,
24- ValidationColumnConfig ,
25- )
21+ from data_designer .config .column_configs import LLMTextColumnConfig
2622from data_designer .engine .column_generators .generators .llm_generators import (
2723 PromptType ,
2824 RecordBasedPromptRenderer ,
3935
4036
def calculate_column_distribution(
    column_name: str, df: pd.DataFrame, distribution_type: ColumnDistributionType
) -> dict[str, CategoricalDistribution | NumericalDistribution | MissingValue | None]:
    """Build the distribution summary for a single DataFrame column.

    Args:
        column_name: Name of the column in ``df`` to summarize.
        df: DataFrame that holds the column.
        distribution_type: Requested kind of distribution. It is coerced with
            ``ColumnDistributionType(...)`` before the try block, so a failure
            of that coercion is not caught here.

    Returns:
        A dict with the keys ``"distribution_type"`` and ``"distribution"``.
        For the categorical and numerical kinds the distribution is produced
        by the matching class's ``from_series``. If that computation raises,
        a warning is logged and the dict carries
        ``ColumnDistributionType.UNKNOWN`` and
        ``MissingValue.CALCULATION_FAILED`` instead.
    """
    resolved_type = ColumnDistributionType(distribution_type)
    distribution_classes = {
        ColumnDistributionType.CATEGORICAL: CategoricalDistribution,
        ColumnDistributionType.NUMERICAL: NumericalDistribution,
    }
    distribution_cls = distribution_classes.get(resolved_type)
    if distribution_cls is None:
        # NOTE(review): any other distribution type falls through without a
        # dict result (None), which the return annotation does not describe.
        # Confirm that callers expect this.
        return None

    try:
        return {
            "distribution_type": resolved_type,
            "distribution": distribution_cls.from_series(df[column_name]),
        }
    except Exception as e:
        # Best-effort statistic: report the failure and return a sentinel
        # payload rather than propagating the error.
        logger.warning(f"{WARNING_PREFIX} failed to calculate column distribution for '{column_name}' {e}")
        return {
            "distribution_type": ColumnDistributionType.UNKNOWN,
            "distribution": MissingValue.CALCULATION_FAILED,
        }
6359
6460
65- def calculate_general_column_info (column_config : SingleColumnConfig , df : pd .DataFrame ) -> dict [str , Any ]:
61+ def calculate_general_column_info (column_name : str , df : pd .DataFrame ) -> dict [str , Any ]:
6662 try :
67- _df = pd .DataFrame (df [column_config .name ].apply (ensure_hashable ))
63+ _df = pd .DataFrame (df [column_name ].apply (ensure_hashable ))
64+
65+ if has_pyarrow_backend (df ):
66+ pyarrow_dtype = str (df [column_name ].dtype .pyarrow_dtype )
67+ simple_dtype = convert_pyarrow_dtype_to_simple_dtype (df [column_name ].dtype .pyarrow_dtype )
68+ else :
69+ # We do not log a warning at the column-level because it would be too noisy.
70+ # However, there is a logged warning at the dataset-profiler level.
71+ try :
72+ simple_dtype = get_column_data_type_from_first_non_null_value (column_name , df )
73+ except Exception :
74+ simple_dtype = MissingValue .CALCULATION_FAILED
75+ pyarrow_dtype = "n/a"
76+
6877 return {
69- "pyarrow_dtype" : str ( df [ column_config . name ]. dtype . pyarrow_dtype ) ,
70- "simple_dtype" : convert_pyarrow_dtype_to_simple_dtype ( df [ column_config . name ]. dtype . pyarrow_dtype ) ,
71- "num_records" : len (_df [column_config . name ]),
72- "num_null" : _df [column_config . name ].isnull ().sum (),
73- "num_unique" : _df [column_config . name ].nunique (),
78+ "pyarrow_dtype" : pyarrow_dtype ,
79+ "simple_dtype" : simple_dtype ,
80+ "num_records" : len (_df [column_name ]),
81+ "num_null" : _df [column_name ].isnull ().sum (),
82+ "num_unique" : _df [column_name ].nunique (),
7483 }
7584 except Exception as e :
76- logger .warning (f"{ WARNING_PREFIX } failed to calculate general column info for '{ column_config . name } ': { e } " )
85+ logger .warning (f"{ WARNING_PREFIX } failed to calculate general column info for '{ column_name } ': { e } " )
7786 return {
7887 "pyarrow_dtype" : MissingValue .CALCULATION_FAILED ,
7988 "simple_dtype" : MissingValue .CALCULATION_FAILED ,
@@ -115,11 +124,9 @@ def calculate_prompt_token_stats(
115124 }
116125
117126
118- def calculate_completion_token_stats (
119- column_config : LLMTextColumnConfig , df : pd .DataFrame
120- ) -> dict [str , float | MissingValue ]:
127+ def calculate_completion_token_stats (column_name : str , df : pd .DataFrame ) -> dict [str , float | MissingValue ]:
121128 try :
122- tokens_per_record = df [column_config . name ].apply (
129+ tokens_per_record = df [column_name ].apply (
123130 lambda value : len (TOKENIZER .encode (str (value ), disallowed_special = ()))
124131 )
125132 return {
@@ -128,9 +135,7 @@ def calculate_completion_token_stats(
128135 "completion_tokens_stddev" : tokens_per_record .std (),
129136 }
130137 except Exception as e :
131- logger .warning (
132- f"{ WARNING_PREFIX } failed to calculate completion token stats for column { column_config .name } : { e } "
133- )
138+ logger .warning (f"{ WARNING_PREFIX } failed to calculate completion token stats for column { column_name } : { e } " )
134139 return {
135140 "completion_tokens_mean" : MissingValue .CALCULATION_FAILED ,
136141 "completion_tokens_median" : MissingValue .CALCULATION_FAILED ,
@@ -141,16 +146,16 @@ def calculate_completion_token_stats(
def calculate_token_stats(column_config: LLMTextColumnConfig, df: pd.DataFrame) -> dict[str, float | MissingValue]:
    """Merge prompt-token and completion-token statistics into one dict.

    Args:
        column_config: Column configuration. It is passed whole to the prompt
            statistics helper; only its ``name`` is passed to the completion
            statistics helper.
        df: DataFrame forwarded unchanged to both helpers.

    Returns:
        The union of both helpers' result dicts.
    """
    prompt_stats = calculate_prompt_token_stats(column_config, df)
    completion_stats = calculate_completion_token_stats(column_config.name, df)
    # Completion stats are unpacked last, so they win on any key collision.
    return {**prompt_stats, **completion_stats}
146151
147152
def calculate_validation_column_info(column_name: str, df: pd.DataFrame) -> dict[str, Any]:
    """Count the records of a validation column that are flagged as valid.

    Args:
        column_name: Name of the column in ``df`` to inspect. Each value is
            indexed with ``"is_valid"``.
        df: DataFrame that holds the column.

    Returns:
        ``{"num_valid_records": <count>}``, where the count is the sum of the
        ``"is_valid"`` entries after ``ensure_boolean`` normalization. If
        anything raises, a warning is logged and the value is
        ``MissingValue.CALCULATION_FAILED``.
    """

    def _is_valid(record: Any) -> bool:
        return ensure_boolean(record["is_valid"])

    try:
        num_valid_records = df[column_name].apply(_is_valid).sum()
    except Exception as e:
        logger.warning(
            f"{WARNING_PREFIX} failed to calculate code validation column info for column {column_name}: {e}"
        )
        return {"num_valid_records": MissingValue.CALCULATION_FAILED}
    return {"num_valid_records": num_valid_records}
156161
@@ -160,22 +165,33 @@ def convert_pyarrow_dtype_to_simple_dtype(pyarrow_dtype: pa.DataType) -> str:
160165 return f"list[{ convert_pyarrow_dtype_to_simple_dtype (pyarrow_dtype .value_type )} ]"
161166 if isinstance (pyarrow_dtype , pa .StructType ):
162167 return "dict"
163- pyarrow_dtype_str = str (pyarrow_dtype )
164- if "int" in pyarrow_dtype_str :
168+ return convert_to_simple_dtype (str (pyarrow_dtype ))
169+
170+
def convert_to_simple_dtype(dtype: str) -> str:
    """Map a dtype name onto a small vocabulary of simple type labels.

    The mapping uses substring matching, so parameterized names such as
    ``"timestamp[ns]"`` or ``"date32[day]"`` are recognized. Checks run in
    order and the first match wins. The timestamp check precedes the time
    check because ``"timestamp"`` and ``"datetime"`` both contain ``"time"``.

    Args:
        dtype: A dtype name. Within this file it is either the string form of
            a pyarrow dtype or a Python type name from ``type(value).__name__``.

    Returns:
        One of ``"int"``, ``"float"``, ``"string"``, ``"timestamp"``,
        ``"time"`` or ``"date"``; the input is returned unchanged when
        nothing matches.
    """
    if "int" in dtype:
        return "int"
    if "double" in dtype or "float" in dtype:
        return "float"
    if "str" in dtype:
        return "string"
    # Python type names reach this function too: ``datetime`` would otherwise
    # be caught by the "time" check, and ``Timestamp`` (capital T) would match
    # nothing. Both denote the same thing as a pyarrow timestamp, so they get
    # the same label. Only this check is case-insensitive, to avoid creating
    # new accidental matches for the other labels.
    lowered = dtype.lower()
    if "timestamp" in lowered or "datetime" in lowered:
        return "timestamp"
    if "time" in dtype:
        return "time"
    if "date" in dtype:
        return "date"
    return dtype
187+
188+
def get_column_data_type_from_first_non_null_value(column_name: str, df: pd.DataFrame) -> str:
    """Infer a column's simple dtype from its first non-null value.

    Args:
        column_name: Name of the column in ``df`` to inspect.
        df: DataFrame that holds the column.

    Returns:
        The simple dtype label derived from the Python type name of the first
        non-null value, or ``MissingValue.CALCULATION_FAILED`` when the column
        has no non-null values.
    """
    non_null_values = df[column_name].dropna()
    if non_null_values.empty:
        # Nothing to inspect: every value is null, or the column has no rows.
        return MissingValue.CALCULATION_FAILED
    first_value = non_null_values.iloc[0]
    return convert_to_simple_dtype(type(first_value).__name__)
179195
180196
181197def ensure_hashable (x : Any ) -> str :
@@ -207,3 +223,7 @@ def ensure_boolean(v: bool | str | int | None) -> bool:
207223 if v is None :
208224 return False
209225 raise ValueError (f"Invalid boolean value: { v } " )
226+
227+
def has_pyarrow_backend(df: pd.DataFrame) -> bool:
    """Report whether every column of ``df`` uses a ``pd.ArrowDtype`` dtype.

    Args:
        df: DataFrame whose column dtypes are inspected.

    Returns:
        ``False`` as soon as one column has a non-Arrow dtype, ``True``
        otherwise. A DataFrame without columns therefore yields ``True``.
    """
    for column_dtype in df.dtypes:
        if not isinstance(column_dtype, pd.ArrowDtype):
            return False
    return True
0 commit comments