Skip to content

Commit 55c21ef

Browse files
authored
seed dataset statistics limited to general stats (#32)
1 parent 42b089e commit 55c21ef

File tree

6 files changed

+1
-144
lines changed

6 files changed

+1
-144
lines changed

src/data_designer/config/analysis/column_statistics.py

Lines changed: 0 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -148,13 +148,8 @@ def create_report_row_data(self) -> dict[str, str]:
148148

149149

150150
class SeedDatasetColumnStatistics(GeneralColumnStatistics):
151-
distribution_type: ColumnDistributionType
152-
distribution: Optional[Union[CategoricalDistribution, NumericalDistribution, MissingValue]]
153151
column_type: Literal[DataDesignerColumnType.SEED_DATASET.value] = DataDesignerColumnType.SEED_DATASET.value
154152

155-
def create_report_row_data(self) -> dict[str, str]:
156-
return self._general_display_row
157-
158153

159154
class ExpressionColumnStatistics(GeneralColumnStatistics):
160155
column_type: Literal[DataDesignerColumnType.EXPRESSION.value] = DataDesignerColumnType.EXPRESSION.value

src/data_designer/engine/analysis/column_statistics.py

Lines changed: 1 addition & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -24,7 +24,6 @@
2424
calculate_general_column_info,
2525
calculate_token_stats,
2626
calculate_validation_column_info,
27-
determine_column_distribution_type,
2827
)
2928

3029
logger = logging.getLogger(__name__)
@@ -105,18 +104,7 @@ def calculate_sampler_distribution(self) -> dict[str, Any]:
105104
)
106105

107106

108-
class SeedDatasetColumnStatisticsCalculator(GeneralColumnStatisticsCalculator):
109-
def calculate_seed_dataset_distribution(self) -> dict[str, Any]:
110-
dist_type = determine_column_distribution_type(self.df[self.column_config.name])
111-
make_dist = dist_type in [ColumnDistributionType.CATEGORICAL, ColumnDistributionType.NUMERICAL]
112-
return (
113-
calculate_column_distribution(self.column_config, self.df, dist_type)
114-
if make_dist
115-
else {
116-
"distribution_type": dist_type,
117-
"distribution": None,
118-
}
119-
)
107+
class SeedDatasetColumnStatisticsCalculator(GeneralColumnStatisticsCalculator): ...
120108

121109

122110
class ValidationColumnStatisticsCalculator(GeneralColumnStatisticsCalculator):

src/data_designer/engine/analysis/utils/column_statistics_calculations.py

Lines changed: 0 additions & 63 deletions
Original file line number | Diff line number | Diff line change
@@ -9,8 +9,6 @@
99

1010
import numpy as np
1111
import pandas as pd
12-
from pandas import Series
13-
from pandas.core.dtypes.common import is_integer_dtype, is_numeric_dtype
1412
import pyarrow as pa
1513
import tiktoken
1614

@@ -180,67 +178,6 @@ def convert_pyarrow_dtype_to_simple_dtype(pyarrow_dtype: pa.DataType) -> str:
180178
return pyarrow_dtype_str
181179

182180

183-
def determine_column_distribution_type(column: Series) -> ColumnDistributionType:
184-
"""Based on the logic used by Gretel's SQS report to determine column data type."""
185-
if len(column) == 0:
186-
return ColumnDistributionType.OTHER
187-
188-
if isinstance(column.iloc[0], np.ndarray):
189-
return ColumnDistributionType.OTHER
190-
191-
if isinstance(column.iloc[0], dict):
192-
return ColumnDistributionType.OTHER
193-
194-
try:
195-
non_na_data = column.dropna()
196-
non_na_count = int(non_na_data.count())
197-
unique_count = int(non_na_data.nunique())
198-
except Exception:
199-
column = column.astype(str)
200-
non_na_data = column.dropna()
201-
non_na_count = int(non_na_data.count())
202-
unique_count = int(non_na_data.nunique())
203-
204-
if non_na_count == 0:
205-
return ColumnDistributionType.OTHER
206-
207-
if is_numeric_dtype(non_na_data.dtype):
208-
# Float values that are within 1e-8 of an integer are considered integers
209-
# Floats are considered numerical.
210-
if not np.allclose(non_na_data, non_na_data.astype(int), atol=1e-8):
211-
return ColumnDistributionType.NUMERICAL
212-
# We can visualize numeric data with histograms, but we will not use it for diversity calculations
213-
min_value = int(non_na_data.min())
214-
if unique_count <= 10 and min_value >= 0:
215-
return ColumnDistributionType.CATEGORICAL
216-
if unique_count == non_na_count and is_integer_dtype(non_na_data.dtype):
217-
# All unique integer values, potentially an ID column
218-
return ColumnDistributionType.OTHER
219-
return ColumnDistributionType.NUMERICAL
220-
221-
# Check if the column is a date-like column before checking for categorical or text columns.
222-
try:
223-
pd.to_datetime(non_na_data, format="%Y-%m-%d")
224-
return ColumnDistributionType.OTHER
225-
except Exception:
226-
pass
227-
228-
diff = non_na_count - unique_count
229-
diff_percent = diff / non_na_count
230-
if diff_percent >= 0.9 or (diff_percent >= 0.7 and len(non_na_data) <= 50):
231-
return ColumnDistributionType.CATEGORICAL
232-
233-
space_count = sum(str(entry).strip().count(" ") for entry in non_na_data)
234-
if space_count / non_na_count > TEXT_FIELD_AVG_SPACE_COUNT_THRESHOLD:
235-
return ColumnDistributionType.TEXT
236-
237-
if pd.api.types.is_string_dtype(non_na_data.dtype) and unique_count <= 10:
238-
# Check for string columns with a small number of unique values (categorical)
239-
return ColumnDistributionType.CATEGORICAL
240-
241-
return ColumnDistributionType.OTHER
242-
243-
244181
def ensure_hashable(x: Any) -> str:
245182
"""
246183
Makes a best effort turn known unhashable types to a hashable

tests/config/analysis/test_column_statistics.py

Lines changed: 0 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -19,7 +19,6 @@
1919
NumericalDistribution,
2020
SamplerColumnStatistics,
2121
SamplerType,
22-
SeedDatasetColumnStatistics,
2322
ValidationColumnStatistics,
2423
)
2524

@@ -190,22 +189,6 @@ def test_sampler_column_statistics(stub_general_stats_args_with_valid_values, st
190189
}
191190

192191

193-
def test_seed_dataset_column_statistics(stub_general_stats_args_with_valid_values, stub_categorical_distribution):
194-
seed_dataset_column_statistics = SeedDatasetColumnStatistics(
195-
**stub_general_stats_args_with_valid_values,
196-
distribution_type=ColumnDistributionType.CATEGORICAL,
197-
distribution=stub_categorical_distribution,
198-
)
199-
assert seed_dataset_column_statistics.column_type == "seed-dataset"
200-
assert seed_dataset_column_statistics.distribution_type == ColumnDistributionType.CATEGORICAL
201-
assert isinstance(seed_dataset_column_statistics.distribution, CategoricalDistribution)
202-
assert seed_dataset_column_statistics.create_report_row_data() == {
203-
"column name": "test",
204-
"number unique values": "10 (10.0%)",
205-
"data type": "str",
206-
}
207-
208-
209192
def test_validation_column_statistics_with_missing_values(stub_general_stats_args_with_missing_values):
210193
validation_column_statistics = ValidationColumnStatistics(
211194
**stub_general_stats_args_with_missing_values,

tests/engine/analysis/test_column_statistics_calculator.py

Lines changed: 0 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -71,25 +71,3 @@ def test_sampler_column_statistics(stub_df, column_configs):
7171
assert isinstance(stats.distribution.mean, float)
7272
assert isinstance(stats.distribution.stddev, float)
7373
assert isinstance(stats.distribution.median, float)
74-
75-
76-
def test_seed_dataset_column_statistics(stub_df, column_configs):
77-
for column_config in column_configs:
78-
if column_config.column_type == DataDesignerColumnType.SEED_DATASET:
79-
column_config_with_df = ColumnConfigWithDataFrame(column_config=column_config, df=stub_df)
80-
stats = get_column_statistics_calculator(column_config.column_type)(
81-
column_config_with_df=column_config_with_df
82-
).calculate()
83-
assert stats.column_name == column_config.name
84-
assert stats.column_type == column_config.column_type
85-
if stats.distribution_type == ColumnDistributionType.CATEGORICAL:
86-
assert hasattr(stats.distribution, "histogram")
87-
assert isinstance(stats.distribution.most_common_value, (int, str))
88-
assert isinstance(stats.distribution.least_common_value, (int, str))
89-
elif stats.distribution_type == ColumnDistributionType.NUMERICAL:
90-
assert not hasattr(stats.distribution, "histogram")
91-
assert isinstance(stats.distribution.min, (int, float))
92-
assert isinstance(stats.distribution.max, (int, float))
93-
assert isinstance(stats.distribution.mean, float)
94-
assert isinstance(stats.distribution.stddev, float)
95-
assert isinstance(stats.distribution.median, float)

tests/engine/analysis/utils/test_column_statistics_calculations.py

Lines changed: 0 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -25,7 +25,6 @@
2525
calculate_prompt_token_stats,
2626
calculate_validation_column_info,
2727
convert_pyarrow_dtype_to_simple_dtype,
28-
determine_column_distribution_type,
2928
ensure_boolean,
3029
ensure_hashable,
3130
)
@@ -244,29 +243,6 @@ def test_convert_pyarrow_dtype_to_simple_dtype():
244243
assert convert_pyarrow_dtype_to_simple_dtype(unknown_type) == str(unknown_type)
245244

246245

247-
def test_determine_column_distribution_type():
248-
assert determine_column_distribution_type(pd.Series([])) == ColumnDistributionType.OTHER
249-
assert determine_column_distribution_type(pd.Series([{"a": 1}, {"b": 2}])) == ColumnDistributionType.OTHER
250-
assert (
251-
determine_column_distribution_type(pd.Series([np.array([1, 2, 3]), np.array([4, 5, 6])]))
252-
== ColumnDistributionType.OTHER
253-
)
254-
assert determine_column_distribution_type(pd.Series([1, 2, 1, 3, 1, 2])) == ColumnDistributionType.CATEGORICAL
255-
assert determine_column_distribution_type(pd.Series([1.1, 2.2, 3.3, 4.4, 5.5])) == ColumnDistributionType.NUMERICAL
256-
assert (
257-
determine_column_distribution_type(pd.Series(["A", "A", "C", "C", "A", "B"]))
258-
== ColumnDistributionType.CATEGORICAL
259-
)
260-
assert (
261-
determine_column_distribution_type(pd.Series(["This is a long text", "Another long text with spaces"]))
262-
== ColumnDistributionType.TEXT
263-
)
264-
assert (
265-
determine_column_distribution_type(pd.Series(["2023-01-01", "2023-01-02", "2023-01-03"]))
266-
== ColumnDistributionType.OTHER
267-
)
268-
269-
270246
def test_prepare_number_for_reporting():
271247
assert prepare_number_for_reporting(5, int) == 5
272248
assert isinstance(prepare_number_for_reporting(5, int), int)

0 commit comments

Comments
 (0)