Skip to content

Commit b77e803

Browse files
committed
add unit tests
1 parent 1859127 commit b77e803

File tree

2 files changed

+52
-1
lines changed

2 files changed

+52
-1
lines changed

src/data_designer/engine/analysis/utils/column_statistics_calculations.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,7 @@ def convert_to_simple_dtype(dtype: str) -> str:
175175
return "float"
176176
if "float" in dtype:
177177
return "float"
178-
if "string" in dtype:
178+
if "string" in dtype or dtype == "str":
179179
return "string"
180180
if "timestamp" in dtype:
181181
return "timestamp"

tests/engine/analysis/utils/test_column_statistics_calculations.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -297,3 +297,54 @@ def test_ensure_boolean():
297297
ensure_boolean(2)
298298
with pytest.raises(ValueError):
299299
ensure_boolean(1.5)
300+
301+
302+
def test_calculate_general_column_info_dtype_detection():
    """Check the dtypes reported for Arrow-backed int, string, and float columns."""
    column_values = {
        "int_col": [1, 2, 3],
        "str_col": ["a", "b", "c"],
        "float_col": [1.1, 2.2, 3.3],
    }
    # Build the frame through PyArrow so every column carries a pd.ArrowDtype.
    arrow_table = pa.Table.from_pydict(column_values)
    arrow_backed_df = arrow_table.to_pandas(types_mapper=pd.ArrowDtype)

    int_info = calculate_general_column_info("int_col", arrow_backed_df)
    assert int_info["simple_dtype"] == "int"
    assert int_info["pyarrow_dtype"] == "int64"

    str_info = calculate_general_column_info("str_col", arrow_backed_df)
    assert str_info["simple_dtype"] == "string"
    # Substring match rather than equality for the Arrow string type name
    # (presumably to tolerate variants such as "large_string" -- TODO confirm).
    assert "string" in str_info["pyarrow_dtype"]

    float_info = calculate_general_column_info("float_col", arrow_backed_df)
    assert float_info["simple_dtype"] == "float"
    assert float_info["pyarrow_dtype"] == "double"
319+
320+
321+
def test_calculate_general_column_info_dtype_detection_fallback():
    """Check the reported info for a plain pandas column holding mixed Python types."""
    mixed_values = [1, "two", 3.0, "four", 5]
    mixed_frame = pd.DataFrame({"mixed_col": mixed_values})

    column_info = calculate_general_column_info("mixed_col", mixed_frame)

    # Expected values, checked in the same order as the individual assertions
    # they replace.
    expected_fields = {
        "simple_dtype": "int",
        "pyarrow_dtype": "n/a",
        "num_records": 5,
        "num_unique": 5,
    }
    for field_name, expected_value in expected_fields.items():
        assert column_info[field_name] == expected_value
330+
331+
332+
def test_calculate_general_column_info_edge_cases():
    """Cover a partially-null column, an all-null column, and an empty column."""
    # Scenario 1: two nulls followed by three distinct floats.
    partial_null_frame = pd.DataFrame(
        {"col_with_nulls": [None, None, 42.0, 43.0, 44.0]}
    )
    partial_info = calculate_general_column_info("col_with_nulls", partial_null_frame)
    assert partial_info["simple_dtype"] == "float"
    assert partial_info["num_null"] == 2
    assert partial_info["num_unique"] == 3

    # Scenario 2: every value is null.
    all_null_frame = pd.DataFrame({"all_nulls": [None] * 3})
    all_null_info = calculate_general_column_info("all_nulls", all_null_frame)
    assert all_null_info["simple_dtype"] == MissingValue.CALCULATION_FAILED
    assert all_null_info["num_null"] == 3
    assert all_null_info["num_unique"] == 0

    # Scenario 3: the column exists but has no rows.
    empty_frame = pd.DataFrame({"empty_col": []})
    empty_info = calculate_general_column_info("empty_col", empty_frame)
    assert empty_info["num_records"] == 0
    assert empty_info["num_null"] == 0
    assert empty_info["simple_dtype"] == MissingValue.CALCULATION_FAILED

0 commit comments

Comments
 (0)