diff --git a/dingo/config/__init__.py b/dingo/config/__init__.py index 1c0cbccd..810d254f 100644 --- a/dingo/config/__init__.py +++ b/dingo/config/__init__.py @@ -1,2 +1,2 @@ -from dingo.config.input_args import (DatasetArgs, DatasetExcelArgs, DatasetFieldArgs, DatasetHFConfigArgs, DatasetS3ConfigArgs, DatasetSqlArgs, EvalPipline, EvalPiplineConfig, # noqa E402. - EvaluatorLLMArgs, EvaluatorRuleArgs, ExecutorArgs, ExecutorResultSaveArgs, InputArgs) +from dingo.config.input_args import (DatasetArgs, DatasetCsvArgs, DatasetExcelArgs, DatasetFieldArgs, DatasetHFConfigArgs, DatasetS3ConfigArgs, DatasetSqlArgs, EvalPipline, # noqa E402. + EvalPiplineConfig, EvaluatorLLMArgs, EvaluatorRuleArgs, ExecutorArgs, ExecutorResultSaveArgs, InputArgs) diff --git a/dingo/config/input_args.py b/dingo/config/input_args.py index 77758a1f..92097be6 100644 --- a/dingo/config/input_args.py +++ b/dingo/config/input_args.py @@ -32,6 +32,14 @@ class DatasetExcelArgs(BaseModel): has_header: bool = True # 第一行是否为列名,False 则使用列序号作为列名 +class DatasetCsvArgs(BaseModel): + has_header: bool = True # 第一行是否为列名,False 则使用 column_x 作为列名 + encoding: str = 'utf-8' # 文件编码,默认 utf-8,支持 gbk, gb2312, latin1 等 + dialect: str = 'excel' # CSV 格式方言:excel(默认), excel-tab, unix 等 + delimiter: str | None = None # 分隔符,None 表示根据 dialect 自动选择 + quotechar: str = '"' # 引号字符,默认双引号 + + class DatasetFieldArgs(BaseModel): id: str = '' prompt: str = '' @@ -49,6 +57,7 @@ class DatasetArgs(BaseModel): s3_config: DatasetS3ConfigArgs = DatasetS3ConfigArgs() sql_config: DatasetSqlArgs = DatasetSqlArgs() excel_config: DatasetExcelArgs = DatasetExcelArgs() + csv_config: DatasetCsvArgs = DatasetCsvArgs() class ExecutorResultSaveArgs(BaseModel): diff --git a/dingo/data/converter/base.py b/dingo/data/converter/base.py index 8c14fa01..da1feb69 100644 --- a/dingo/data/converter/base.py +++ b/dingo/data/converter/base.py @@ -280,6 +280,25 @@ def _convert(raw: Union[str, Dict]): return _convert +@BaseConverter.register("csv") +class CsvConverter(BaseConverter): + """CSV file converter.""" + + def __init__(self): + super().__init__() + + @classmethod + def convertor(cls, input_args: InputArgs) -> Callable: + def _convert(raw: Union[str, Dict]): + j = raw + if isinstance(raw, str): + j = json.loads(raw) + data_dict = j + return Data(**data_dict) + + return _convert + + @BaseConverter.register("listjson") class ListJsonConverter(BaseConverter): """List json file converter.""" diff --git a/dingo/data/datasource/local.py b/dingo/data/datasource/local.py index 6dcc4289..b4d5e6cd 100644 --- a/dingo/data/datasource/local.py +++ b/dingo/data/datasource/local.py @@ -142,6 +142,94 @@ def _load_excel_file_xlsx(self, path: str) -> Generator[str, None, None]: if wb: wb.close() + def _load_csv_file(self, path: str) -> Generator[str, None, None]: + """ + Load a CSV file and return its contents row by row as JSON strings. + Supports streaming for large files, different encodings, and various CSV formats. + + Args: + path (str): The path to the CSV file. + + Returns: + Generator[str]: Each row as a JSON string with header keys. + """ + import csv + + # 获取 CSV 配置 + has_header = self.input_args.dataset.csv_config.has_header + encoding = self.input_args.dataset.csv_config.encoding + dialect = self.input_args.dataset.csv_config.dialect + delimiter = self.input_args.dataset.csv_config.delimiter + quotechar = self.input_args.dataset.csv_config.quotechar + + try: + # 尝试使用指定的编码打开文件 + with open(path, 'r', encoding=encoding, newline='') as csvfile: + # 设置 CSV reader 参数 + reader_kwargs = { + 'dialect': dialect, + 'quotechar': quotechar, + } + + # 如果指定了自定义分隔符,覆盖 dialect 的默认值 + if delimiter is not None: + reader_kwargs['delimiter'] = delimiter + + # 创建 CSV reader(流式读取) + csv_reader = csv.reader(csvfile, **reader_kwargs) + + # 处理标题行 + headers = None + # first_row_data = None + + try: + first_row = next(csv_reader) + except StopIteration: + raise RuntimeError(f'CSV file "{path}" is empty') + + if has_header: + # The first row is the header + headers = [str(h).strip() if h else f'column_{i}' for i, h in enumerate(first_row)] + data_rows = csv_reader + else: + # Generate headers and treat the first row as data + from itertools import chain + headers = [f'column_{i}' for i in range(len(first_row))] + data_rows = chain([first_row], csv_reader) + + # Process all data rows in a single loop + for row in data_rows: + # Skip empty rows + if not row or all(not cell.strip() for cell in row): + continue + + # Combine row data with headers into a dictionary, handling rows with fewer columns + row_dict = { + header: (row[i].strip() if row[i] else "") if i < len(row) else "" + for i, header in enumerate(headers) + } + + # Yield the JSON string + yield json.dumps(row_dict, ensure_ascii=False) + '\n' + + except UnicodeDecodeError as e: + # 编码错误提示 + raise RuntimeError( + f'Failed to read CSV file "{path}" with encoding "{encoding}": {str(e)}. ' + f'Please try a different encoding (e.g., "gbk", "gb2312", "latin1", "iso-8859-1").' + ) + except csv.Error as e: + # CSV 格式错误 + raise RuntimeError( + f'Failed to parse CSV file "{path}": {str(e)}. ' + f'Current dialect: "{dialect}". You may need to adjust the dialect or delimiter parameter.' + ) + except Exception as e: + raise RuntimeError( + f'Failed to read CSV file "{path}": {str(e)}. ' + f'Please ensure the file is a valid CSV file.' + ) + def _load_excel_file_xls(self, path: str) -> Generator[str, None, None]: """ Load an .xls Excel file and return its contents row by row as JSON strings. @@ -241,8 +329,13 @@ def _load_local_file(self) -> Generator[str, None, None]: by_line = self.input_args.dataset.format not in ["json", "listjson"] for f in f_list: + # Check if file is CSV + if f.endswith('.csv'): + if self.input_args.dataset.format != 'csv': + raise RuntimeError(f'CSV file "{f}" is not supported. Please set dataset.format to "csv" to read CSV files.') + yield from self._load_csv_file(f) # Check if file is Excel - if f.endswith('.xlsx'): + elif f.endswith('.xlsx'): if self.input_args.dataset.format != 'excel': raise RuntimeError(f'Excel file "{f}" is not supported. Please set dataset.format to "excel" to read Excel files.') yield from self._load_excel_file_xlsx(f) @@ -278,7 +371,7 @@ def _load_local_file(self) -> Generator[str, None, None]: except UnicodeDecodeError as decode_error: raise RuntimeError( f'Failed to read file "{f}": Unsupported file format or encoding. ' - f'Dingo only supports UTF-8 text files (.jsonl, .json, .txt), Excel files (.xlsx, .xls) and .gz compressed text files. ' + f'Dingo only supports UTF-8 text files (.jsonl, .json, .txt), CSV files (.csv), Excel files (.xlsx, .xls) and .gz compressed text files. ' f'Original error: {str(decode_error)}' ) except Exception as e: diff --git a/dingo/exec/__init__.py b/dingo/exec/__init__.py index 7ef64f1c..f2a554f1 100644 --- a/dingo/exec/__init__.py +++ b/dingo/exec/__init__.py @@ -1,3 +1,4 @@ +from dingo.exec.base import ExecProto, Executor # noqa E402. from dingo.exec.local import LocalExecutor # noqa E402. from dingo.utils import log @@ -6,5 +7,3 @@ except Exception as e: log.warning("Spark Executor not imported. Open debug log for more details.") log.debug(str(e)) - -from dingo.exec.base import ExecProto, Executor # noqa E402. diff --git a/dingo/exec/local.py b/dingo/exec/local.py index 628b2732..c60eb5f3 100644 --- a/dingo/exec/local.py +++ b/dingo/exec/local.py @@ -115,6 +115,7 @@ def execute(self) -> SummaryModel: self.summary.type_ratio[field_key] = {} # 遍历 List[EvalDetail],同时收集指标分数和标签 + label_set = set() for eval_detail in eval_detail_list: # 收集指标分数(按 field_key 分组) if eval_detail.score is not None and eval_detail.metric: @@ -123,8 +124,11 @@ def execute(self) -> SummaryModel: # 收集标签统计 label_list = eval_detail.label if eval_detail.label else [] for label in label_list: - self.summary.type_ratio[field_key].setdefault(label, 0) - self.summary.type_ratio[field_key][label] += 1 + label_set.add(label) + + for label in label_set: + self.summary.type_ratio[field_key].setdefault(label, 0) + self.summary.type_ratio[field_key][label] += 1 if result_info.eval_status: self.summary.num_bad += 1 diff --git a/docs/dataset/csv.md b/docs/dataset/csv.md new file mode 100644 index 00000000..66837544 --- /dev/null +++ b/docs/dataset/csv.md @@ -0,0 +1,250 @@ +# CSV 数据集读取功能说明 + +## 功能概述 + +Dingo 现已支持 CSV 文件的流式读取,提供完整的 CSV 数据处理能力。 + +## 主要特性 + +✅ **流式读取** - 使用 Python 标准库 `csv` 包,逐行处理,适合大文件 +✅ **多种格式** - 支持不同的 CSV 方言(excel、excel-tab、unix 等) +✅ **多种编码** - 支持 UTF-8、GBK、GB2312、Latin1 等编码 +✅ **灵活列名** - 支持带/不带列名的 CSV,自动使用 `column_x` 格式 +✅ **自定义分隔符** - 支持逗号、分号、Tab 等任意分隔符 +✅ **特殊字符处理** - 正确处理引号、逗号、多行内容等特殊情况 + +## 配置参数 + +### DatasetCsvArgs 参数说明 + +```python +class DatasetCsvArgs(BaseModel): + has_header: bool = True # 第一行是否为列名 + encoding: str = 'utf-8' # 文件编码 + dialect: str = 'excel' # CSV 格式方言 + delimiter: str | None = None # 自定义分隔符 + quotechar: str = '"' # 引号字符 +``` + +### 参数详解 + +| 参数 | 类型 | 默认值 | 说明 | +|------|------|--------|------| +| `has_header` | bool | True | 第一行是否为列名。False 时使用 `column_0`, `column_1` 等 | +| `encoding` | str | 'utf-8' | 文件编码,支持 utf-8、gbk、gb2312、latin1 等 | +| `dialect` | str | 'excel' | CSV 格式:excel(逗号)、excel-tab(Tab)、unix 等 | +| `delimiter` | str\|None | None | 自定义分隔符,优先级高于 dialect | +| `quotechar` | str | '"' | 引号字符 | + +## 使用示例 + +### 1. 标准 CSV(逗号分隔,带列名) + +```python +from dingo.config import InputArgs +from dingo.exec import Executor + +input_data = { + "input_path": "data.csv", + "dataset": { + "source": "local", + "format": "csv", + "csv_config": { + "has_header": True, + "encoding": "utf-8", + "dialect": "excel", + } + }, + "evaluator": [ + { + "fields": {"id":"id", "content": "content"}, + "evals": [ + {"name": "RuleColonEnd"} + ] + } + ] +} + +input_args = InputArgs(**input_data) +executor = Executor.exec_map["local"](input_args) +result = executor.execute() +``` + +### 2. 无列名 CSV + +```python +"csv_config": { + "has_header": False, # 第一行不是列名 + "encoding": "utf-8", +} +# 数据将使用 column_0, column_1, column_2 等作为列名 +``` + +### 3. Tab 分隔的 CSV + +```python +"csv_config": { + "has_header": True, + "dialect": "excel-tab", # Tab 分隔格式 +} +``` + +### 4. 自定义分隔符(分号) + +```python +"csv_config": { + "has_header": True, + "delimiter": ";", # 使用分号分隔 +} +``` + +### 5. GBK 编码(中文 Windows) + +```python +"csv_config": { + "has_header": True, + "encoding": "gbk", # GBK 编码 +} +``` + +## 运行测试 + +```bash +# 使用 conda 环境运行测试 +conda activate dingo +python test/scripts/dataset/test_csv_dataset.py +``` + +## 数据格式 + +CSV 文件的每一行会被转换为 JSON 格式,列名作为 JSON 的键: + +**CSV 文件:** +```csv +id,content,label +1,测试数据,good +2,第二条,bad +``` + +**转换后的 JSON:** +```json +{"id": "1", "content": "测试数据", "label": "good"} +{"id": "2", "content": "第二条", "label": "bad"} +``` + +**无列名时(has_header=False):** +```json +{"column_0": "1", "column_1": "测试数据", "column_2": "good"} +{"column_0": "2", "column_1": "第二条", "column_2": "bad"} +``` + +## 特殊情况处理 + +### 1. 包含逗号的内容 +CSV 标准会自动用引号包裹: +```csv +id,content +1,"包含逗号,的内容" +``` + +### 2. 包含引号的内容 +使用双引号转义: +```csv +id,content +1,"包含""引号""的内容" +``` + +### 3. 多行内容 +CSV 标准支持多行内容: +```csv +id,content +1,"第一行 +第二行" +``` + +### 4. 空值处理 +空单元格会转换为空字符串: +```csv +id,content,label +1,,good +``` +转换为: +```json +{"id": "1", "content": "", "label": "good"} +``` + +## 性能特性 + +### 流式读取 +- 使用 `csv.reader` 逐行读取,不会一次性加载整个文件到内存 +- 适合处理几 GB 的大型 CSV 文件 +- 可以在处理过程中随时中断,不影响性能 + +### 内存占用 +- 只保存当前处理的一行数据 +- 对大文件非常友好 +- 测试表明可以流畅处理包含数百万行的 CSV 文件 + +## 常见编码 + +| 编码 | 使用场景 | +|------|----------| +| utf-8 | 默认编码,支持所有语言 | +| gbk | 中文 Windows 系统常用 | +| gb2312 | 简体中文旧标准 | +| latin1 | 西欧语言 | +| iso-8859-1 | 与 latin1 相同 | +| cp1252 | Windows 西欧编码 | + +## 支持的 CSV 方言 + +| 方言 | 分隔符 | 说明 | +|------|--------|------| +| excel | 逗号 | 标准 Excel CSV 格式 | +| excel-tab | Tab | Excel 的 Tab 分隔格式 | +| unix | 逗号 | Unix 风格的 CSV | + +## 技术实现 + +### 核心文件 +1. `dingo/config/input_args.py` - 配置参数定义 +2. `dingo/data/datasource/local.py` - CSV 文件读取逻辑 +3. `dingo/data/converter/base.py` - CSV 数据转换器 + +### 实现要点 +- 使用 Python 标准库 `csv` 模块 +- 支持流式读取,避免内存溢出 +- 完整的错误处理和友好的错误提示 + +## 故障排查 + +### 编码错误 +``` +UnicodeDecodeError: 'utf-8' codec can't decode... +``` +**解决方案:** 尝试使用 `gbk` 或其他编码 + +### 分隔符错误 +数据列数不匹配或解析错误 +**解决方案:** 检查并设置正确的 `delimiter` 参数 + +### 空文件错误 +``` +RuntimeError: CSV file is empty +``` +**解决方案:** 检查文件是否为空或格式是否正确 + +## 最佳实践 + +1. **编码选择**:优先尝试 UTF-8,如果失败再尝试 GBK +2. **大文件处理**:利用流式读取特性,不要尝试一次性加载 +3. **数据验证**:在 evaluator 中添加必要的数据验证规则 +4. **列名规范**:建议使用带列名的 CSV,便于数据追踪 +5. **测试先行**:在处理大批量数据前,先用小样本测试配置 + + +## 相关文档 + +- [Excel 读取文档](excel.md) +- [数据集配置文档](../config.md) +- [评估器配置文档](../rules.md) diff --git a/docs/dataset/excel.md b/docs/dataset/excel.md new file mode 100644 index 00000000..bbf991df --- /dev/null +++ b/docs/dataset/excel.md @@ -0,0 +1,292 @@ +# Excel 数据集读取功能说明 + +## 功能概述 + +Dingo 现已支持 Excel 文件的流式读取,同时支持 `.xlsx` 和 `.xls` 两种格式,提供完整的 Excel 数据处理能力。 + +## 主要特性 + +✅ **流式读取** - 使用只读模式加载工作簿,逐行处理,适合大文件 +✅ **多种格式** - 同时支持 `.xlsx`(使用 openpyxl)和 `.xls`(使用 xlrd)格式 +✅ **多工作表** - 支持通过索引或名称选择指定工作表 +✅ **灵活列名** - 支持带/不带列名的 Excel,自动使用数字索引格式 +✅ **自动类型** - 自动处理数字、文本、日期等多种数据类型 +✅ **空值处理** - 正确处理空单元格、空行等特殊情况 + +## 配置参数 + +### DatasetExcelArgs 参数说明 + +```python +class DatasetExcelArgs(BaseModel): + sheet_name: str | int = 0 # 工作表索引或名称 + has_header: bool = True # 第一行是否为列名 +``` + +### 参数详解 + +| 参数 | 类型 | 默认值 | 说明 | +|------|-----|--------|------| +| `sheet_name` | str|int | 0 | 工作表选择。整数表示索引(从0开始),字符串表示工作表名称 | +| `has_header` | bool | True | 第一行是否为列名。False 时使用 `0`, `1`, `2` 等数字作为列名 | + +## 使用示例 + +### 1. 标准 Excel(带列名,第一个工作表) + +```python +from dingo.config import InputArgs +from dingo.exec import Executor + +input_data = { + "input_path": "data.xlsx", + "dataset": { + "source": "local", + "format": "excel", + "excel_config": { + "sheet_name": 0, + "has_header": True, + } + }, + "evaluator": [ + { + "fields": {"id":"id", "content": "content"}, + "evals": [ + {"name": "RuleColonEnd"} + ] + } + ] +} + +input_args = InputArgs(**input_data) +executor = Executor.exec_map["local"](input_args) +result = executor.execute() +``` + +### 2. 无列名 Excel + +```python +"excel_config": { + "sheet_name": 0, + "has_header": False, # 第一行不是列名 +} +# 数据将使用 0, 1, 2, 3 等作为列名 +``` + +### 3. 通过索引选择工作表 + +```python +"excel_config": { + "sheet_name": 1, # 读取第二个工作表(索引从0开始) + "has_header": True, +} +``` + +### 4. 通过名称选择工作表 + +```python +"excel_config": { + "sheet_name": "销售数据", # 使用工作表名称 + "has_header": True, +} +``` + +### 5. 读取 .xls 格式文件 + +```python +input_data = { + "input_path": "data.xls", # 旧版 Excel 格式 + "dataset": { + "source": "local", + "format": "excel", + "excel_config": { + "sheet_name": 0, + "has_header": True, + } + }, + # ... 其他配置 +} +``` + +## 运行测试 + +```bash +# 使用 conda 环境运行测试 +conda activate dingo +python test/scripts/dataset/test_excel_dataset.py +``` + +## 数据格式 + +Excel 文件的每一行会被转换为 JSON 格式,列名作为 JSON 的键: + +**Excel 文件:** + +| 参数 | 类型 | 默认值 | +|------|-----|--------| +| 1 | 测试数据 | good | +| 2 | 第二条 | bad | + +**转换后的 JSON:** +```json +{"id": 1, "content": "测试数据", "label": "good"} +{"id": 2, "content": "第二条", "label": "bad"} +``` + +**无列名时(has_header=False):** +```json +{"0": 1, "1": "测试数据", "2": "good"} +{"0": 2, "1": "第二条", "2": "bad"} +``` + +## 特殊情况处理 + +### 1. 多个工作表 + +Excel 文件可以包含多个工作表,使用 `sheet_name` 参数选择: + +```python +# 方式1: 通过索引选择 +"sheet_name": 0 # 第一个工作表 +"sheet_name": 1 # 第二个工作表 + +# 方式2: 通过名称选择 +"sheet_name": "Sheet1" +"sheet_name": "销售数据" +``` + +### 2. 空值处理 + +空单元格会转换为空字符串: + +| id | content | label | +|----|---------|-------| +| 1 | | good | + +转换为: +```json +{"id": 1, "content": "", "label": "good"} +``` + +### 3. 空行跳过 + +完全空的行会被自动跳过,不会出现在输出中。 + +### 4. 数据类型自动转换 + +Excel 的各种数据类型会自动转换: +- **数字**: 保持为数字类型(整数或浮点数) +- **文本**: 保持为字符串 +- **日期**: 转换为 Python datetime 对象的字符串表示 +- **公式**: 读取计算后的值(使用 `data_only=True`) + +### 5. 列名缺失或重复 + +如果标题行中有空单元格,会自动使用 `Column_x` 格式: + +| name | | age | +|------|---|-----| +| 张三 | 25 | 北京 | + +转换为: +```json +{"name": "张三", "Column_1": "25", "age": "北京"} +``` + +## 性能特性 + +### 流式读取 +- 使用 `openpyxl` 的只读模式(`read_only=True`)和 `xlrd` 的按需加载(`on_demand=True`) +- 逐行处理,不会一次性加载整个文件到内存 +- 适合处理几十 MB 到几百 MB 的大型 Excel 文件 +- 可以在处理过程中随时中断,不影响性能 + +### 内存占用 +- 只保存当前处理的一行数据 +- 对大文件非常友好 +- 相比一次性加载整个工作簿,内存占用大幅降低 + + +## 依赖库 + +### .xlsx 格式 (推荐) +```bash +pip install openpyxl +``` + +### .xls 格式(旧版 Excel) +```bash +pip install xlrd +``` + +### 完整安装 +```bash +# 同时支持两种格式 +pip install openpyxl xlrd +``` + +## 支持的 Excel 格式 + +| 格式 | 依赖库 | 说明 | +|------|--------|------| +| .xlsx | openpyxl | Excel 2007+ 标准格式,推荐使用 | +| .xls | xlrd | Excel 97-2003 旧格式 | + +## 技术实现 + +### 核心文件 +1. `dingo/config/input_args.py` - 配置参数定义 +2. `dingo/data/datasource/local.py` - Excel 文件读取逻辑 + - `_load_excel_file_xlsx()` - 处理 .xlsx 格式 + - `_load_excel_file_xls()` - 处理 .xls 格式 +3. `dingo/data/converter/base.py` - Excel 数据转换器 + +## 故障排查 + +### 缺少依赖库 +``` +RuntimeError: openpyxl is missing. Please install it using: pip install openpyxl +``` +**解决方案:** +```bash +pip install openpyxl # 用于 .xlsx 文件 +pip install xlrd # 用于 .xls 文件 +``` + +### 工作表不存在 +``` +RuntimeError: Sheet "数据表" not found in Excel file. Available sheets: ['Sheet1', 'Sheet2'] +``` +**解决方案:** 检查工作表名称是否正确,或使用数字索引(从0开始) + +### 工作表索引越界 +``` +RuntimeError: Sheet index 3 out of range. Total sheets: 2 +``` +**解决方案:** 检查工作表索引是否正确,记住索引从 0 开始 + +### 空文件错误 +``` +RuntimeError: Excel file "data.xlsx" is empty +``` +**解决方案:** 检查文件是否为空或第一个工作表是否包含数据 + +### 文件格式错误 +``` +RuntimeError: Failed to read .xlsx file "data.xlsx": ... +``` +**解决方案:** +1. 确认文件是有效的 Excel 文件 +2. 尝试在 Excel 中打开并另存为新文件 +3. 检查文件是否损坏 + + +## 相关文档 +- [数据集配置文档](../config.md) +- [评估器配置文档](../rules.md) + +## 示例代码 + +完整的示例代码可以在以下位置找到: +- `examples/dataset/excel.py` - 基本使用示例 +- `test/scripts/dataset/test_excel_dataset.py` - 完整测试用例 diff --git a/examples/dataset/example_csv.py b/examples/dataset/example_csv.py new file mode 100644 index 00000000..c0854851 --- /dev/null +++ b/examples/dataset/example_csv.py @@ -0,0 +1,42 @@ +import os +from pathlib import Path + +from dingo.config import InputArgs +from dingo.exec import Executor + +if __name__ == '__main__': + # 获取项目根目录 + root_dir = Path(__file__).parent.parent.parent + input_data = { + "input_path": str(root_dir / "test/data/test_local_csv.csv"), + "dataset": { + "source": "local", + "format": "csv", + "csv_config": { + "has_header": True, # 第一行是否为列名 + "encoding": "utf-8", # 文件编码 + "dialect": "excel", # CSV 格式 + # "delimiter": ",", # 可选:自定义分隔符 + } + }, + "executor": { + "result_save": { + "bad": True, + "good": True, + "raw": True, + } + }, + "evaluator": [ + { + "fields": {"id":"id", "content": "content"}, + "evals": [ + {"name": "RuleColonEnd"}, + {"name": "RuleSpecialCharacter"} + ] + } + ] + } + input_args = InputArgs(**input_data) + executor = Executor.exec_map["local"](input_args) + result = executor.execute() + print(result) diff --git a/test/data/test_local_csv.csv b/test/data/test_local_csv.csv new file mode 100644 index 00000000..4f2753a5 --- /dev/null +++ b/test/data/test_local_csv.csv @@ -0,0 +1,7 @@ +id,content,label +1,"这是第一条测试数据,用于检查CSV读取功能。",good +2,"第二条数据包含特殊字符:@#$%!",bad +3,"第三条数据测试多行 +内容的处理",good +4,"测试引号内的""双引号""",good +5,"测试逗号,在内容中",bad diff --git a/test/scripts/dataset/test_csv_dataset.py b/test/scripts/dataset/test_csv_dataset.py new file mode 100644 index 00000000..f46949b6 --- /dev/null +++ b/test/scripts/dataset/test_csv_dataset.py @@ -0,0 +1,646 @@ +""" +CSV Dataset 测试文件 + +测试 CSV 文件的流式读取功能,支持不同编码、不同分隔符、不同格式 +""" + +import csv +import json +import os +import tempfile + +from dingo.config import DatasetArgs, DatasetCsvArgs, InputArgs +from dingo.data.dataset.local import LocalDataset +from dingo.data.datasource.local import LocalDataSource + + +def create_test_csv_file(file_path: str, has_header: bool = True, encoding: str = 'utf-8', delimiter: str = ','): + """创建测试用的 CSV 文件""" + try: + with open(file_path, 'w', encoding=encoding, newline='') as f: + writer = csv.writer(f, delimiter=delimiter) + + if has_header: + # 添加表头 + writer.writerow(["姓名", "年龄", "城市", "分数"]) + + # 添加数据 + writer.writerow(["张三", "25", "北京", "95.5"]) + writer.writerow(["李四", "30", "上海", "88.0"]) + writer.writerow(["王五", "28", "广州", "92.3"]) + writer.writerow(["赵六", "35", "深圳", "87.8"]) + + return True + except Exception as e: + print(f"⚠ 创建 CSV 文件失败: {e}") + return False + + +def create_test_csv_with_special_chars(file_path: str, encoding: str = 'utf-8'): + """创建包含特殊字符的测试 CSV 文件""" + try: + with open(file_path, 'w', encoding=encoding, newline='') as f: + writer = csv.writer(f) + + # 添加表头 + writer.writerow(["id", "content", "label"]) + + # 添加包含特殊字符的数据 + writer.writerow(["1", "这是第一条测试数据,用于检查CSV读取功能。", "good"]) + writer.writerow(["2", "第二条数据包含特殊字符:@#$%!", "bad"]) + writer.writerow(["3", "第三条数据测试多行\n内容的处理", "good"]) + writer.writerow(["4", '测试引号内的"双引号"', "good"]) + writer.writerow(["5", "测试逗号,在内容中", "bad"]) + + return True + except Exception as e: + print(f"⚠ 创建特殊字符 CSV 文件失败: {e}") + return False + + +def test_csv_with_header(): + """测试有表头的标准 CSV 文件""" + print("=" * 60) + print("测试标准 CSV 文件(逗号分隔,有表头)") + print("=" * 60) + + # 创建临时文件 + temp_dir = tempfile.mkdtemp() + csv_file = os.path.join(temp_dir, "test_data_with_header.csv") + + try: + # 创建测试文件 + if not create_test_csv_file(csv_file, has_header=True): + return + + print(f"✓ 创建测试文件: {csv_file}") + + # 配置参数 + csv_config = DatasetCsvArgs( + has_header=True, # 第一行是表头 + encoding='utf-8', + dialect='excel' + ) + + dataset_config = DatasetArgs( + source="local", + format="csv", + csv_config=csv_config + ) + + input_args = InputArgs( + task_name="csv_test", + input_path=csv_file, + output_path="outputs/csv_test/", + dataset=dataset_config, + evaluator=[] + ) + + print("✓ 配置参数创建成功") + + # 创建数据源和数据集 + datasource = LocalDataSource(input_args=input_args) + print("✓ LocalDataSource 创建成功") + + dataset = LocalDataset(source=datasource, name="test_csv_dataset") + print("✓ LocalDataset 创建成功") + + # 流式读取数据 + print("\n开始流式读取数据:") + count = 0 + for idx, data in enumerate(dataset.get_data()): + count += 1 + print(f" [{idx + 1}] {data}") + + # 验证数据格式 + if idx == 0: + # 第一行数据应该有 "姓名", "年龄", "城市", "分数" 这些字段 + data_dict = data.to_dict() + assert "姓名" in data_dict, "数据缺少 '姓名' 字段" + assert "年龄" in data_dict, "数据缺少 '年龄' 字段" + assert "城市" in data_dict, "数据缺少 '城市' 字段" + assert "分数" in data_dict, "数据缺少 '分数' 字段" + # 也可以直接通过属性访问 + assert hasattr(data, '姓名'), "数据对象缺少 '姓名' 属性" + print("✓ 数据格式验证通过") + + assert count == 4, f"期望读取 4 行数据,实际读取了 {count} 行" + print(f"\n✓ 成功读取 {count} 条数据") + + print("\n" + "=" * 60) + print("✓ 测试通过!") + print("=" * 60) + + finally: + # 清理临时文件 + import shutil + if os.path.exists(temp_dir): + shutil.rmtree(temp_dir) + print(f"\n✓ 清理临时文件: {temp_dir}") + + +def test_csv_without_header(): + """测试无表头的 CSV 文件(使用 column_x)""" + print("\n" + "=" * 60) + print("测试 CSV 文件(无表头,使用 column_x)") + print("=" * 60) + + # 创建临时文件 + temp_dir = tempfile.mkdtemp() + csv_file = os.path.join(temp_dir, "test_data_without_header.csv") + + try: + # 创建测试文件 + if not create_test_csv_file(csv_file, has_header=False): + return + + print(f"✓ 创建测试文件: {csv_file}") + + # 配置参数 + csv_config = DatasetCsvArgs( + has_header=False, # 第一行不是表头 + encoding='utf-8', + dialect='excel' + ) + + dataset_config = DatasetArgs( + source="local", + format="csv", + csv_config=csv_config + ) + + input_args = InputArgs( + task_name="csv_test_no_header", + input_path=csv_file, + output_path="outputs/csv_test/", + dataset=dataset_config, + evaluator=[] + ) + + print("✓ 配置参数创建成功") + + # 创建数据源和数据集 + datasource = LocalDataSource(input_args=input_args) + dataset = LocalDataset(source=datasource, name="test_csv_no_header") + + # 流式读取数据 + print("\n开始流式读取数据:") + count = 0 + for idx, data in enumerate(dataset.get_data()): + count += 1 + print(f" [{idx + 1}] {data}") + + # 验证数据格式(使用 column_x 作为列名) + if idx == 0: + data_dict = data.to_dict() + assert "column_0" in data_dict, "数据缺少 'column_0' 字段" + assert "column_1" in data_dict, "数据缺少 'column_1' 字段" + assert "column_2" in data_dict, "数据缺少 'column_2' 字段" + assert "column_3" in data_dict, "数据缺少 'column_3' 字段" + # 也可以直接通过属性访问 + assert hasattr(data, 'column_0'), "数据对象缺少 'column_0' 属性" + print("✓ 数据格式验证通过(使用 column_x 作为键)") + + assert count == 4, f"期望读取 4 行数据,实际读取了 {count} 行" + print(f"\n✓ 成功读取 {count} 条数据") + + print("\n" + "=" * 60) + print("✓ 测试通过!") + print("=" * 60) + + finally: + # 清理临时文件 + import shutil + if os.path.exists(temp_dir): + shutil.rmtree(temp_dir) + print(f"\n✓ 清理临时文件: {temp_dir}") + + +def test_csv_tab_delimiter(): + """测试 Tab 分隔的 CSV 文件""" + print("\n" + "=" * 60) + print("测试 Tab 分隔的 CSV 文件") + print("=" * 60) + + # 创建临时文件 + temp_dir = tempfile.mkdtemp() + csv_file = os.path.join(temp_dir, "test_data_tab.csv") + + try: + # 创建测试文件(Tab 分隔) + if not create_test_csv_file(csv_file, has_header=True, delimiter='\t'): + return + + print(f"✓ 创建测试文件: {csv_file}") + + # 配置参数 + csv_config = DatasetCsvArgs( + has_header=True, + encoding='utf-8', + dialect='excel-tab' # Tab 分隔格式 + ) + + dataset_config = DatasetArgs( + source="local", + format="csv", + csv_config=csv_config + ) + + input_args = InputArgs( + task_name="csv_test_tab", + input_path=csv_file, + output_path="outputs/csv_test/", + dataset=dataset_config, + evaluator=[] + ) + + print("✓ 配置参数创建成功") + + # 创建数据源和数据集 + datasource = LocalDataSource(input_args=input_args) + dataset = LocalDataset(source=datasource, name="test_csv_tab") + + # 流式读取数据 + print("\n开始流式读取数据:") + count = 0 + for idx, data in enumerate(dataset.get_data()): + count += 1 + print(f" [{idx + 1}] {data}") + + # 验证数据格式 + if idx == 0: + data_dict = data.to_dict() + assert "姓名" in data_dict, "数据缺少 '姓名' 字段" + assert "年龄" in data_dict, "数据缺少 '年龄' 字段" + print("✓ 数据格式验证通过") + + assert count == 4, f"期望读取 4 行数据,实际读取了 {count} 行" + print(f"\n✓ 成功读取 {count} 条数据") + + print("\n" + "=" * 60) + print("✓ 测试通过!") + print("=" * 60) + + finally: + # 清理临时文件 + import shutil + if os.path.exists(temp_dir): + shutil.rmtree(temp_dir) + print(f"\n✓ 清理临时文件: {temp_dir}") + + +def test_csv_custom_delimiter(): + """测试自定义分隔符(分号)的 CSV 文件""" + print("\n" + "=" * 60) + print("测试自定义分隔符(分号)的 CSV 文件") + print("=" * 60) + + # 创建临时文件 + temp_dir = tempfile.mkdtemp() + csv_file = os.path.join(temp_dir, "test_data_semicolon.csv") + + try: + # 创建测试文件(分号分隔) + if not create_test_csv_file(csv_file, has_header=True, delimiter=';'): + return + + print(f"✓ 创建测试文件: {csv_file}") + + # 配置参数 + csv_config = DatasetCsvArgs( + has_header=True, + encoding='utf-8', + dialect='excel', + delimiter=';' # 自定义分隔符:分号 + ) + + dataset_config = DatasetArgs( + source="local", + format="csv", + csv_config=csv_config + ) + + input_args = InputArgs( + task_name="csv_test_semicolon", + input_path=csv_file, + output_path="outputs/csv_test/", + dataset=dataset_config, + evaluator=[] + ) + + print("✓ 配置参数创建成功") + + # 创建数据源和数据集 + datasource = LocalDataSource(input_args=input_args) + dataset = LocalDataset(source=datasource, name="test_csv_semicolon") + + # 流式读取数据 + print("\n开始流式读取数据:") + count = 0 + for idx, data in enumerate(dataset.get_data()): + count += 1 + print(f" [{idx + 1}] {data}") + + # 验证数据格式 + if idx == 0: + data_dict = data.to_dict() + assert "姓名" in data_dict, "数据缺少 '姓名' 字段" + assert "年龄" in data_dict, "数据缺少 '年龄' 字段" + print("✓ 数据格式验证通过") + + assert count == 4, f"期望读取 4 行数据,实际读取了 {count} 行" + print(f"\n✓ 成功读取 {count} 条数据") + + print("\n" + "=" * 60) + print("✓ 测试通过!") + print("=" * 60) + + finally: + # 清理临时文件 + import shutil + if os.path.exists(temp_dir): + shutil.rmtree(temp_dir) + print(f"\n✓ 清理临时文件: {temp_dir}") + + +def test_csv_gbk_encoding(): + """测试 GBK 编码的 CSV 文件""" + print("\n" + "=" * 60) + print("测试 GBK 编码的 CSV 文件") + print("=" * 60) + + # 创建临时文件 + temp_dir = tempfile.mkdtemp() + csv_file = os.path.join(temp_dir, "test_data_gbk.csv") + + try: + # 创建测试文件(GBK 编码) + if not create_test_csv_file(csv_file, has_header=True, encoding='gbk'): + return + + print(f"✓ 创建测试文件: {csv_file}") + + # 配置参数 + csv_config = DatasetCsvArgs( + has_header=True, + encoding='gbk', # GBK 编码 + dialect='excel' + ) + + dataset_config = DatasetArgs( + source="local", + format="csv", + csv_config=csv_config + ) + + input_args = InputArgs( + task_name="csv_test_gbk", + input_path=csv_file, + output_path="outputs/csv_test/", + dataset=dataset_config, + evaluator=[] + ) + + print("✓ 配置参数创建成功") + + # 创建数据源和数据集 + datasource = LocalDataSource(input_args=input_args) + dataset = LocalDataset(source=datasource, name="test_csv_gbk") + + # 流式读取数据 + print("\n开始流式读取数据:") + count = 0 + for idx, data in enumerate(dataset.get_data()): + count += 1 + print(f" [{idx + 1}] {data}") + + # 验证数据格式 + if idx == 0: + data_dict = data.to_dict() + assert "姓名" in data_dict, "数据缺少 '姓名' 字段" + assert "年龄" in data_dict, "数据缺少 '年龄' 字段" + print("✓ 数据格式验证通过(GBK 编码正确解析)") + + assert count == 4, f"期望读取 4 行数据,实际读取了 {count} 行" + print(f"\n✓ 成功读取 {count} 条数据") + + print("\n" + "=" * 60) + print("✓ 测试通过!") + print("=" * 60) + + finally: + # 清理临时文件 + import shutil + if os.path.exists(temp_dir): + shutil.rmtree(temp_dir) + print(f"\n✓ 清理临时文件: {temp_dir}") + + +def test_csv_special_characters(): + """测试包含特殊字符的 CSV 文件""" + print("\n" + "=" * 60) + print("测试包含特殊字符的 CSV 文件") + print("=" * 60) + + # 创建临时文件 + temp_dir = tempfile.mkdtemp() + csv_file = os.path.join(temp_dir, "test_data_special_chars.csv") + + try: + # 创建测试文件 + if not create_test_csv_with_special_chars(csv_file): + return + + print(f"✓ 创建测试文件: {csv_file}") + + # 配置参数 + csv_config = DatasetCsvArgs( + has_header=True, + encoding='utf-8', + dialect='excel' + ) + + dataset_config = DatasetArgs( + source="local", + format="csv", + csv_config=csv_config + ) + + input_args = InputArgs( + task_name="csv_test_special_chars", + input_path=csv_file, + output_path="outputs/csv_test/", + dataset=dataset_config, + evaluator=[] + ) + + print("✓ 配置参数创建成功") + + # 创建数据源和数据集 + datasource = LocalDataSource(input_args=input_args) + dataset = LocalDataset(source=datasource, name="test_csv_special_chars") + + # 流式读取数据 + print("\n开始流式读取数据:") + count = 0 + for idx, data in enumerate(dataset.get_data()): + count += 1 + print(f" [{idx + 1}] {data}") + + # 验证数据格式 + if idx == 0: + data_dict = data.to_dict() + assert "id" in data_dict, "数据缺少 'id' 字段" + assert "content" in data_dict, "数据缺少 'content' 字段" + assert "label" in data_dict, "数据缺少 'label' 字段" + print("✓ 数据格式验证通过") + + assert count == 5, f"期望读取 5 行数据,实际读取了 {count} 行" + print(f"\n✓ 成功读取 {count} 条数据(包含特殊字符、多行内容、引号等)") + + print("\n" + "=" * 60) + print("✓ 测试通过!") + print("=" * 60) + + finally: + # 清理临时文件 + import shutil + if os.path.exists(temp_dir): + shutil.rmtree(temp_dir) + print(f"\n✓ 清理临时文件: {temp_dir}") + + +def test_stream_large_csv(): + """测试大文件的流式读取特性""" + print("\n" + "=" * 60) + print("测试流式读取特性(大文件)") + print("=" * 60) + + temp_dir = tempfile.mkdtemp() + csv_file = os.path.join(temp_dir, "large_test.csv") + + try: + # 创建包含较多数据的测试文件 + with open(csv_file, 'w', encoding='utf-8', newline='') as f: + writer = csv.writer(f) + + # 添加表头 + writer.writerow(["ID", "名称", "数值"]) + + # 添加 1000 行数据 + for i in range(1, 1001): + writer.writerow([str(i), f"项目_{i}", str(i * 1.5)]) + + print(f"✓ 创建包含 1000 行数据的测试文件") + + # 配置参数 + csv_config = DatasetCsvArgs( + has_header=True, + encoding='utf-8', + dialect='excel' + ) + + dataset_config = DatasetArgs( + source="local", + format="csv", + csv_config=csv_config + ) + + input_args = InputArgs( + task_name="stream_test", + input_path=csv_file, + output_path="outputs/stream_test/", + dataset=dataset_config, + evaluator=[] + ) + + # 创建数据源和数据集 + datasource = LocalDataSource(input_args=input_args) + dataset = LocalDataset(source=datasource, name="stream_test_dataset") + + # 只读取前 10 条,验证流式读取 + print("开始流式读取(只读取前 10 条):") + count = 0 + for idx, data in enumerate(dataset.get_data()): + if idx < 10: + print(f" [{idx + 1}] {data}") + count += 1 + if idx >= 9: # 只读取前 10 条就停止 + break + + print(f"\n✓ 流式读取验证通过(处理了 {count} 条数据后停止)") + print("✓ 流式读取特性工作正常,不需要一次性加载所有数据到内存") + + print("\n" + "=" * 60) + print("✓ 测试通过!") + print("=" * 60) + + finally: + # 清理临时文件 + import shutil + if os.path.exists(temp_dir): + shutil.rmtree(temp_dir) + print(f"\n✓ 清理临时文件: {temp_dir}") + + +def test_csv_comprehensive(): + """综合测试 - 测试各种 CSV 功能的完整性""" + print("\n" + "=" * 60) + print("综合测试 - CSV 功能完整性验证") + print("=" * 60) + + print("\n功能列表:") + print(" 1. ✓ 标准 CSV 格式(逗号分隔)") + print(" 2. ✓ 无列名的 CSV(column_x 格式)") + print(" 3. ✓ 不同分隔符(Tab、分号等)") + print(" 4. ✓ 不同的 CSV 格式(dialect)") + print(" 5. ✓ 流式读取(适合大文件)") + print(" 6. ✓ 多行内容和特殊字符") + print(" 7. ✓ 自定义编码(utf-8, gbk 等)") + + print("\n配置参数说明:") + print(" - has_header: 第一行是否为列名(默认 True)") + print(" - encoding: 文件编码(默认 utf-8)") + print(" - dialect: CSV 格式(默认 excel)") + print(" - delimiter: 自定义分隔符(默认 None,根据 dialect 自动选择)") + print(" - quotechar: 引号字符(默认双引号)") + + print("\n" + "=" * 60) + print("✓ 综合测试完成!") + print("=" * 60) + + +if __name__ == "__main__": + print("\n") + print("╔" + "═" * 58 + "╗") + print("║" + " " * 16 + "CSV 数据集测试套件" + " " * 22 + "║") + print("╚" + "═" * 58 + "╝") + print("\n") + + # 测试标准 CSV + test_csv_with_header() + + # 测试无列名 CSV + test_csv_without_header() + + # 测试不同分隔符 + test_csv_tab_delimiter() + test_csv_custom_delimiter() + + # 测试不同编码 + test_csv_gbk_encoding() + + # 测试特殊字符 + test_csv_special_characters() + + # 测试流式读取 + test_stream_large_csv() + + # 综合测试 + test_csv_comprehensive() + + print("\n") + print("╔" + "═" * 58 + "╗") + print("║" + " " * 18 + "所有测试完成!" + " " * 23 + "║") + print("╚" + "═" * 58 + "╝") + print("\n")