Skip to content

Commit 8b20315

Browse files
committed
Unify Arrow-based type inference for CSV and Parquet
1 parent 70865ac commit 8b20315

File tree

10 files changed

+857
-900
lines changed

10 files changed

+857
-900
lines changed

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,10 @@ authors = [
2020
]
2121

2222
# Core runtime dependencies needed for the tool to function.
23+
# Note: pandas is not a direct dependency — it is available transitively via wfdb.
2324
dependencies = [
2425
"typer >= 0.9.0",
25-
"mlcroissant >= 1.0.0",
26-
"pandas >= 1.3.0",
26+
"mlcroissant >= 1.0.20",
2727
"rich >= 13.0.0",
2828
"wfdb >= 4.0.0",
2929
"pyarrow >= 15.0.0",
Lines changed: 67 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,15 @@
11
"""CSV file handler for tabular data processing."""
22

33
from pathlib import Path
4-
import pandas as pd
4+
5+
import pyarrow as pa
6+
import pyarrow.csv as pa_csv
57

68
from croissant_maker.handlers.base_handler import FileTypeHandler
7-
from croissant_maker.handlers.utils import analyze_data_sample, compute_file_hash
9+
from croissant_maker.handlers.utils import (
10+
compute_file_hash,
11+
infer_column_types_from_arrow_schema,
12+
)
813

914

1015
class CSVHandler(FileTypeHandler):
@@ -16,11 +21,27 @@ class CSVHandler(FileTypeHandler):
1621
- Gzip-compressed CSV files (.csv.gz)
1722
- Bzip2-compressed CSV files (.csv.bz2)
1823
- XZ-compressed CSV files (.csv.xz)
19-
- Automatic column type detection using pandas
24+
- Automatic column type detection using PyArrow
2025
- SHA256 hash computation for file integrity
21-
- Sample data extraction for preview
26+
27+
Uses PyArrow's CSV reader which:
28+
- Auto-detects compressed formats from filename extension
29+
- Infers precise types (timestamp[s], date32, int64, float64, etc.)
30+
- Reads multi-threaded by default for performance
2231
"""
2332

33+
# Common timestamp formats for medical/clinical data beyond ISO-8601.
34+
# PyArrow uses ISO8601 by default; these cover additional patterns found
35+
# in datasets like MIMIC, eICU, and OMOP.
36+
_TIMESTAMP_PARSERS = [
37+
pa_csv.ISO8601,
38+
"%Y-%m-%d %H:%M:%S",
39+
"%m/%d/%Y",
40+
"%d/%m/%Y",
41+
"%m/%d/%Y %H:%M:%S",
42+
"%Y-%m-%dT%H:%M:%S",
43+
]
44+
2445
def can_handle(self, file_path: Path) -> bool:
2546
"""
2647
Check if the file is a CSV or compressed CSV file.
@@ -43,19 +64,17 @@ def extract_metadata(self, file_path: Path) -> dict:
4364
"""
4465
Extract comprehensive metadata from a CSV file.
4566
46-
Reads a sample of the CSV file to infer column types, extracts
47-
file statistics, computes integrity hashes, and prepares all
48-
metadata needed for Croissant generation.
67+
Uses PyArrow to read the CSV with automatic type inference,
68+
including timestamp detection and precise numeric types.
4969
5070
Args:
5171
file_path: Path to the CSV file
5272
5373
Returns:
5474
Dictionary containing:
5575
- Basic file info (path, name, size, hash)
56-
- Format information (encoding, compression)
76+
- Format information (encoding)
5777
- Data structure (columns, types, row count)
58-
- Sample data for preview
5978
6079
Raises:
6180
ValueError: If the CSV file cannot be read or processed
@@ -64,45 +83,46 @@ def extract_metadata(self, file_path: Path) -> dict:
6483
if not file_path.exists():
6584
raise FileNotFoundError(f"CSV file not found: {file_path}")
6685

86+
# Parse CSV — only this call needs error translation
6787
try:
68-
# Read a sample for type inference (1000 rows for good accuracy)
69-
df = pd.read_csv(file_path, nrows=1000)
70-
71-
if df.empty:
72-
raise ValueError(f"CSV file is empty: {file_path}")
73-
74-
# Extract file properties
75-
file_size = file_path.stat().st_size
76-
sha256_hash = compute_file_hash(file_path)
77-
78-
# Analyze the data structure and types
79-
data_analysis = analyze_data_sample(df)
80-
81-
# Determine encoding format based on file extension
82-
name_lower = file_path.name.lower()
83-
if name_lower.endswith(".csv.gz"):
84-
encoding_format = "application/gzip"
85-
elif name_lower.endswith(".csv.bz2"):
86-
encoding_format = "application/x-bzip2"
87-
elif name_lower.endswith(".csv.xz"):
88-
encoding_format = "application/x-xz"
89-
else:
90-
encoding_format = "text/csv"
91-
92-
return {
93-
"file_path": str(file_path),
94-
"file_name": file_path.name,
95-
"file_size": file_size,
96-
"sha256": sha256_hash,
97-
"encoding_format": encoding_format,
98-
**data_analysis, # Includes column_types, num_rows, columns, sample_data
99-
}
100-
101-
except pd.errors.EmptyDataError:
102-
raise ValueError(f"CSV file contains no data: {file_path}")
103-
except pd.errors.ParserError as e:
88+
convert_options = pa_csv.ConvertOptions(
89+
timestamp_parsers=self._TIMESTAMP_PARSERS,
90+
)
91+
table = pa_csv.read_csv(str(file_path), convert_options=convert_options)
92+
except pa.lib.ArrowInvalid as e:
10493
raise ValueError(f"Failed to parse CSV file {file_path}: {e}")
10594
except UnicodeDecodeError as e:
10695
raise ValueError(f"Encoding error in CSV file {file_path}: {e}")
107-
except Exception as e:
108-
raise ValueError(f"Failed to process CSV file {file_path}: {e}")
96+
97+
if table.num_rows == 0:
98+
raise ValueError(f"CSV file is empty: {file_path}")
99+
100+
# Infer types from the Arrow schema (shared with Parquet handler)
101+
column_types = infer_column_types_from_arrow_schema(table.schema)
102+
103+
# Extract file properties
104+
file_size = file_path.stat().st_size
105+
sha256_hash = compute_file_hash(file_path)
106+
107+
# Determine encoding format based on file extension
108+
name_lower = file_path.name.lower()
109+
if name_lower.endswith(".csv.gz"):
110+
encoding_format = "application/gzip"
111+
elif name_lower.endswith(".csv.bz2"):
112+
encoding_format = "application/x-bzip2"
113+
elif name_lower.endswith(".csv.xz"):
114+
encoding_format = "application/x-xz"
115+
else:
116+
encoding_format = "text/csv"
117+
118+
return {
119+
"file_path": str(file_path),
120+
"file_name": file_path.name,
121+
"file_size": file_size,
122+
"sha256": sha256_hash,
123+
"encoding_format": encoding_format,
124+
"column_types": column_types,
125+
"num_rows": table.num_rows,
126+
"num_columns": table.num_columns,
127+
"columns": table.column_names,
128+
}
Lines changed: 10 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,22 @@
11
"""Parquet file handler for tabular event streams (e.g., MEDS)."""
22

33
from pathlib import Path
4-
from typing import Dict
54

6-
from croissant_maker.handlers.base_handler import FileTypeHandler
7-
from croissant_maker.handlers.utils import compute_file_hash
85
from pyarrow.parquet import ParquetFile
9-
import pyarrow.types as patypes
6+
7+
from croissant_maker.handlers.base_handler import FileTypeHandler
8+
from croissant_maker.handlers.utils import (
9+
compute_file_hash,
10+
infer_column_types_from_arrow_schema,
11+
)
1012

1113

1214
class ParquetHandler(FileTypeHandler):
1315
"""
1416
Handler for Parquet files (.parquet) with schema-based type inference.
1517
1618
- Uses pyarrow to read schema and row count without loading full data
17-
- Emits Croissant-compatible column types
19+
- Emits Croissant-compatible column types via shared infer_column_types_from_arrow_schema()
1820
- Computes SHA256 for reproducibility
1921
- Keeps memory usage minimal (schema-only)
2022
"""
@@ -32,13 +34,9 @@ def extract_metadata(self, file_path: Path) -> dict:
3234
schema = pq.schema_arrow
3335
num_rows = pq.metadata.num_rows if pq.metadata is not None else 0
3436

35-
column_types: Dict[str, str] = {}
36-
columns = []
37-
for field in schema:
38-
columns.append(field.name)
39-
column_types[field.name] = self._map_arrow_type_to_croissant(
40-
field.type, patypes
41-
)
37+
# Use the shared Arrow type mapper (same as CSV handler)
38+
column_types = infer_column_types_from_arrow_schema(schema)
39+
columns = [field.name for field in schema]
4240

4341
file_size = file_path.stat().st_size
4442
sha256_hash = compute_file_hash(file_path)
@@ -56,26 +54,3 @@ def extract_metadata(self, file_path: Path) -> dict:
5654
}
5755
except Exception as e:
5856
raise ValueError(f"Failed to process Parquet file {file_path}: {e}") from e
59-
60-
@staticmethod
61-
def _map_arrow_type_to_croissant(arrow_type, patypes) -> str:
62-
"""Map pyarrow types to Croissant schema.org data types."""
63-
try:
64-
if patypes.is_integer(arrow_type):
65-
return "sc:Integer"
66-
if patypes.is_floating(arrow_type) or patypes.is_decimal(arrow_type):
67-
return "sc:Float"
68-
if patypes.is_boolean(arrow_type):
69-
return "sc:Boolean"
70-
if patypes.is_timestamp(arrow_type):
71-
return "sc:Date"
72-
if patypes.is_date(arrow_type):
73-
return "sc:Date"
74-
if patypes.is_string(arrow_type) or patypes.is_large_string(arrow_type):
75-
return "sc:Text"
76-
if patypes.is_binary(arrow_type) or patypes.is_large_binary(arrow_type):
77-
return "sc:Text"
78-
except Exception:
79-
# Fallback to text for any exotic or extension types
80-
pass
81-
return "sc:Text"

0 commit comments

Comments (0)