Skip to content

Commit 1de03d7

Browse files
authored
Merge pull request #310 from shijinpjlab/dev_csv
feat: support csv
2 parents 8bf0093 + 5a01ae8 commit 1de03d7

File tree

11 files changed

+1369
-8
lines changed

11 files changed

+1369
-8
lines changed

dingo/config/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
1-
from dingo.config.input_args import (DatasetArgs, DatasetExcelArgs, DatasetFieldArgs, DatasetHFConfigArgs, DatasetS3ConfigArgs, DatasetSqlArgs, EvalPipline, EvalPiplineConfig, # noqa E402.
2-
EvaluatorLLMArgs, EvaluatorRuleArgs, ExecutorArgs, ExecutorResultSaveArgs, InputArgs)
1+
from dingo.config.input_args import (DatasetArgs, DatasetCsvArgs, DatasetExcelArgs, DatasetFieldArgs, DatasetHFConfigArgs, DatasetS3ConfigArgs, DatasetSqlArgs, EvalPipline, # noqa E402.
2+
EvalPiplineConfig, EvaluatorLLMArgs, EvaluatorRuleArgs, ExecutorArgs, ExecutorResultSaveArgs, InputArgs)

dingo/config/input_args.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,14 @@ class DatasetExcelArgs(BaseModel):
3232
has_header: bool = True # 第一行是否为列名,False 则使用列序号作为列名
3333

3434

35+
class DatasetCsvArgs(BaseModel):
    # Configuration for reading CSV datasets.
    has_header: bool = True  # Whether the first row holds column names; if False, column_x is used as the column name
    encoding: str = 'utf-8'  # File encoding; defaults to utf-8, also supports gbk, gb2312, latin1, etc.
    dialect: str = 'excel'  # CSV dialect: excel (default), excel-tab, unix, etc.
    delimiter: str | None = None  # Field delimiter; None means use the dialect's default
    quotechar: str = '"'  # Quote character; defaults to a double quote
41+
42+
3543
class DatasetFieldArgs(BaseModel):
3644
id: str = ''
3745
prompt: str = ''
@@ -49,6 +57,7 @@ class DatasetArgs(BaseModel):
4957
s3_config: DatasetS3ConfigArgs = DatasetS3ConfigArgs()
5058
sql_config: DatasetSqlArgs = DatasetSqlArgs()
5159
excel_config: DatasetExcelArgs = DatasetExcelArgs()
60+
csv_config: DatasetCsvArgs = DatasetCsvArgs()
5261

5362

5463
class ExecutorResultSaveArgs(BaseModel):

dingo/data/converter/base.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,25 @@ def _convert(raw: Union[str, Dict]):
280280
return _convert
281281

282282

283+
@BaseConverter.register("csv")
284+
class CsvConverter(BaseConverter):
285+
"""CSV file converter."""
286+
287+
def __init__(self):
288+
super().__init__()
289+
290+
@classmethod
291+
def convertor(cls, input_args: InputArgs) -> Callable:
292+
def _convert(raw: Union[str, Dict]):
293+
j = raw
294+
if isinstance(raw, str):
295+
j = json.loads(raw)
296+
data_dict = j
297+
return Data(**data_dict)
298+
299+
return _convert
300+
301+
283302
@BaseConverter.register("listjson")
284303
class ListJsonConverter(BaseConverter):
285304
"""List json file converter."""

dingo/data/datasource/local.py

Lines changed: 95 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,94 @@ def _load_excel_file_xlsx(self, path: str) -> Generator[str, None, None]:
142142
if wb:
143143
wb.close()
144144

145+
def _load_csv_file(self, path: str) -> Generator[str, None, None]:
146+
"""
147+
Load a CSV file and return its contents row by row as JSON strings.
148+
Supports streaming for large files, different encodings, and various CSV formats.
149+
150+
Args:
151+
path (str): The path to the CSV file.
152+
153+
Returns:
154+
Generator[str]: Each row as a JSON string with header keys.
155+
"""
156+
import csv
157+
158+
# 获取 CSV 配置
159+
has_header = self.input_args.dataset.csv_config.has_header
160+
encoding = self.input_args.dataset.csv_config.encoding
161+
dialect = self.input_args.dataset.csv_config.dialect
162+
delimiter = self.input_args.dataset.csv_config.delimiter
163+
quotechar = self.input_args.dataset.csv_config.quotechar
164+
165+
try:
166+
# 尝试使用指定的编码打开文件
167+
with open(path, 'r', encoding=encoding, newline='') as csvfile:
168+
# 设置 CSV reader 参数
169+
reader_kwargs = {
170+
'dialect': dialect,
171+
'quotechar': quotechar,
172+
}
173+
174+
# 如果指定了自定义分隔符,覆盖 dialect 的默认值
175+
if delimiter is not None:
176+
reader_kwargs['delimiter'] = delimiter
177+
178+
# 创建 CSV reader(流式读取)
179+
csv_reader = csv.reader(csvfile, **reader_kwargs)
180+
181+
# 处理标题行
182+
headers = None
183+
# first_row_data = None
184+
185+
try:
186+
first_row = next(csv_reader)
187+
except StopIteration:
188+
raise RuntimeError(f'CSV file "{path}" is empty')
189+
190+
if has_header:
191+
# The first row is the header
192+
headers = [str(h).strip() if h else f'column_{i}' for i, h in enumerate(first_row)]
193+
data_rows = csv_reader
194+
else:
195+
# Generate headers and treat the first row as data
196+
from itertools import chain
197+
headers = [f'column_{i}' for i in range(len(first_row))]
198+
data_rows = chain([first_row], csv_reader)
199+
200+
# Process all data rows in a single loop
201+
for row in data_rows:
202+
# Skip empty rows
203+
if not row or all(not cell.strip() for cell in row):
204+
continue
205+
206+
# Combine row data with headers into a dictionary, handling rows with fewer columns
207+
row_dict = {
208+
header: (row[i].strip() if row[i] else "") if i < len(row) else ""
209+
for i, header in enumerate(headers)
210+
}
211+
212+
# Yield the JSON string
213+
yield json.dumps(row_dict, ensure_ascii=False) + '\n'
214+
215+
except UnicodeDecodeError as e:
216+
# 编码错误提示
217+
raise RuntimeError(
218+
f'Failed to read CSV file "{path}" with encoding "{encoding}": {str(e)}. '
219+
f'Please try a different encoding (e.g., "gbk", "gb2312", "latin1", "iso-8859-1").'
220+
)
221+
except csv.Error as e:
222+
# CSV 格式错误
223+
raise RuntimeError(
224+
f'Failed to parse CSV file "{path}": {str(e)}. '
225+
f'Current dialect: "{dialect}". You may need to adjust the dialect or delimiter parameter.'
226+
)
227+
except Exception as e:
228+
raise RuntimeError(
229+
f'Failed to read CSV file "{path}": {str(e)}. '
230+
f'Please ensure the file is a valid CSV file.'
231+
)
232+
145233
def _load_excel_file_xls(self, path: str) -> Generator[str, None, None]:
146234
"""
147235
Load an .xls Excel file and return its contents row by row as JSON strings.
@@ -241,8 +329,13 @@ def _load_local_file(self) -> Generator[str, None, None]:
241329
by_line = self.input_args.dataset.format not in ["json", "listjson"]
242330

243331
for f in f_list:
332+
# Check if file is CSV
333+
if f.endswith('.csv'):
334+
if self.input_args.dataset.format != 'csv':
335+
raise RuntimeError(f'CSV file "{f}" is not supported. Please set dataset.format to "csv" to read CSV files.')
336+
yield from self._load_csv_file(f)
244337
# Check if file is Excel
245-
if f.endswith('.xlsx'):
338+
elif f.endswith('.xlsx'):
246339
if self.input_args.dataset.format != 'excel':
247340
raise RuntimeError(f'Excel file "{f}" is not supported. Please set dataset.format to "excel" to read Excel files.')
248341
yield from self._load_excel_file_xlsx(f)
@@ -278,7 +371,7 @@ def _load_local_file(self) -> Generator[str, None, None]:
278371
except UnicodeDecodeError as decode_error:
279372
raise RuntimeError(
280373
f'Failed to read file "{f}": Unsupported file format or encoding. '
281-
f'Dingo only supports UTF-8 text files (.jsonl, .json, .txt), Excel files (.xlsx, .xls) and .gz compressed text files. '
374+
f'Dingo only supports UTF-8 text files (.jsonl, .json, .txt), CSV files (.csv), Excel files (.xlsx, .xls) and .gz compressed text files. '
282375
f'Original error: {str(decode_error)}'
283376
)
284377
except Exception as e:

dingo/exec/__init__.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from dingo.exec.base import ExecProto, Executor # noqa E402.
12
from dingo.exec.local import LocalExecutor # noqa E402.
23
from dingo.utils import log
34

@@ -6,5 +7,3 @@
67
except Exception as e:
78
log.warning("Spark Executor not imported. Open debug log for more details.")
89
log.debug(str(e))
9-
10-
from dingo.exec.base import ExecProto, Executor # noqa E402.

dingo/exec/local.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@ def execute(self) -> SummaryModel:
115115
self.summary.type_ratio[field_key] = {}
116116

117117
# 遍历 List[EvalDetail],同时收集指标分数和标签
118+
label_set = set()
118119
for eval_detail in eval_detail_list:
119120
# 收集指标分数(按 field_key 分组)
120121
if eval_detail.score is not None and eval_detail.metric:
@@ -123,8 +124,11 @@ def execute(self) -> SummaryModel:
123124
# 收集标签统计
124125
label_list = eval_detail.label if eval_detail.label else []
125126
for label in label_list:
126-
self.summary.type_ratio[field_key].setdefault(label, 0)
127-
self.summary.type_ratio[field_key][label] += 1
127+
label_set.add(label)
128+
129+
for label in label_set:
130+
self.summary.type_ratio[field_key].setdefault(label, 0)
131+
self.summary.type_ratio[field_key][label] += 1
128132

129133
if result_info.eval_status:
130134
self.summary.num_bad += 1

0 commit comments

Comments
 (0)