
Commit fbe6064

feat: Add support to handle Parquet Dataset files via data config (#401)
Signed-off-by: Abhishek <[email protected]>
Parent: 89db915

File tree: 7 files changed (+113, −1 lines)

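In effect, a .parquet path is now accepted anywhere a JSON/JSONL data file was, both directly via the data arguments and through a dataset config. A minimal sketch of the user-facing change, mirroring the new test cases below (the import path follows the test suite; the fixture path is illustrative):

    # Sketch only: Parquet paths are now valid training/validation data files.
    from tuning.config import configs

    data_args = configs.DataArguments(
        training_data_path="tests/artifacts/testdata/parquet/twitter_complaints_small.parquet",
        dataset_text_field="output",
        response_template="\n### Label:",
    )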

tests/artifacts/testdata/__init__.py

Lines changed: 10 additions & 0 deletions

@@ -19,20 +19,30 @@
 
 ### Constants used for data
 DATA_DIR = os.path.join(os.path.dirname(__file__))
+PARQUET_DATA_DIR = os.path.join(os.path.dirname(__file__), "parquet")
 TWITTER_COMPLAINTS_DATA_JSON = os.path.join(DATA_DIR, "twitter_complaints_small.json")
 TWITTER_COMPLAINTS_DATA_JSONL = os.path.join(DATA_DIR, "twitter_complaints_small.jsonl")
+TWITTER_COMPLAINTS_DATA_PARQUET = os.path.join(
+    PARQUET_DATA_DIR, "twitter_complaints_small.parquet"
+)
 TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON = os.path.join(
     DATA_DIR, "twitter_complaints_input_output.json"
 )
 TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL = os.path.join(
     DATA_DIR, "twitter_complaints_input_output.jsonl"
 )
+TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET = os.path.join(
+    PARQUET_DATA_DIR, "twitter_complaints_input_output.parquet"
+)
 TWITTER_COMPLAINTS_TOKENIZED_JSON = os.path.join(
     DATA_DIR, "twitter_complaints_tokenized_with_maykeye_tinyllama_v0.json"
 )
 TWITTER_COMPLAINTS_TOKENIZED_JSONL = os.path.join(
     DATA_DIR, "twitter_complaints_tokenized_with_maykeye_tinyllama_v0.jsonl"
 )
+TWITTER_COMPLAINTS_TOKENIZED_PARQUET = os.path.join(
+    PARQUET_DATA_DIR, "twitter_complaints_tokenized_with_maykeye_tinyllama_v0.parquet"
+)
 EMPTY_DATA = os.path.join(DATA_DIR, "empty_data.json")
 MALFORMATTED_DATA = os.path.join(DATA_DIR, "malformatted_data.json")
 MODEL_NAME = "Maykeye/TinyLLama-v0"
New binary test fixtures (contents not shown):
tests/artifacts/testdata/parquet/twitter_complaints_small.parquet
tests/artifacts/testdata/parquet/twitter_complaints_input_output.parquet
tests/artifacts/testdata/parquet/twitter_complaints_tokenized_with_maykeye_tinyllama_v0.parquet
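These fixtures cannot be rendered inline, and the commit ships no script for producing them. As a point of reference, a hypothetical sketch of how such Parquet files could be regenerated from the JSONL test data already in the repo (assumes pandas with a Parquet engine such as pyarrow installed):

    # Hypothetical helper, not part of this commit: rebuild the Parquet
    # fixtures from the existing JSONL test data.
    import pandas as pd

    SOURCES = {
        "twitter_complaints_small.jsonl": "twitter_complaints_small.parquet",
        "twitter_complaints_input_output.jsonl": "twitter_complaints_input_output.parquet",
        "twitter_complaints_tokenized_with_maykeye_tinyllama_v0.jsonl":
            "twitter_complaints_tokenized_with_maykeye_tinyllama_v0.parquet",
    }

    for src, dst in SOURCES.items():
        # lines=True parses JSON Lines input: one record per row.
        df = pd.read_json(f"tests/artifacts/testdata/{src}", lines=True)
        df.to_parquet(f"tests/artifacts/testdata/parquet/{dst}", index=False)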

tests/data/test_data_preprocessing_utils.py

Lines changed: 100 additions & 0 deletions

@@ -34,10 +34,13 @@
     MODEL_NAME,
     TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON,
     TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL,
+    TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET,
     TWITTER_COMPLAINTS_DATA_JSON,
     TWITTER_COMPLAINTS_DATA_JSONL,
+    TWITTER_COMPLAINTS_DATA_PARQUET,
     TWITTER_COMPLAINTS_TOKENIZED_JSON,
     TWITTER_COMPLAINTS_TOKENIZED_JSONL,
+    TWITTER_COMPLAINTS_TOKENIZED_PARQUET,
 )
 
 # Local
@@ -59,6 +62,10 @@
             TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL,
             set(["ID", "Label", "input", "output"]),
         ),
+        (
+            TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET,
+            set(["ID", "Label", "input", "output"]),
+        ),
         (
             TWITTER_COMPLAINTS_TOKENIZED_JSONL,
             set(
@@ -73,10 +80,28 @@
                 ]
             ),
         ),
+        (
+            TWITTER_COMPLAINTS_TOKENIZED_PARQUET,
+            set(
+                [
+                    "Tweet text",
+                    "ID",
+                    "Label",
+                    "text_label",
+                    "output",
+                    "input_ids",
+                    "labels",
+                ]
+            ),
+        ),
         (
             TWITTER_COMPLAINTS_DATA_JSONL,
             set(["Tweet text", "ID", "Label", "text_label", "output"]),
         ),
+        (
+            TWITTER_COMPLAINTS_DATA_PARQUET,
+            set(["Tweet text", "ID", "Label", "text_label", "output"]),
+        ),
     ],
 )
 def test_load_dataset_with_datafile(datafile, column_names):
@@ -98,6 +123,11 @@ def test_load_dataset_with_datafile(datafile, column_names):
             set(["ID", "Label", "input", "output"]),
             "text_dataset_input_output_masking",
         ),
+        (
+            TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET,
+            set(["ID", "Label", "input", "output"]),
+            "text_dataset_input_output_masking",
+        ),
         (
             TWITTER_COMPLAINTS_TOKENIZED_JSONL,
             set(
@@ -113,11 +143,31 @@ def test_load_dataset_with_datafile(datafile, column_names):
             ),
             "pretokenized_dataset",
         ),
+        (
+            TWITTER_COMPLAINTS_TOKENIZED_PARQUET,
+            set(
+                [
+                    "Tweet text",
+                    "ID",
+                    "Label",
+                    "text_label",
+                    "output",
+                    "input_ids",
+                    "labels",
+                ]
+            ),
+            "pretokenized_dataset",
+        ),
         (
             TWITTER_COMPLAINTS_DATA_JSONL,
             set(["Tweet text", "ID", "Label", "text_label", "output"]),
             "apply_custom_data_template",
         ),
+        (
+            TWITTER_COMPLAINTS_DATA_PARQUET,
+            set(["Tweet text", "ID", "Label", "text_label", "output"]),
+            "apply_custom_data_template",
+        ),
     ],
 )
 def test_load_dataset_with_datasetconfig(datafile, column_names, datasetconfigname):
@@ -139,8 +189,14 @@ def test_load_dataset_with_datasetconfig(datafile, column_names, datasetconfigname):
             TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL,
             "text_dataset_input_output_masking",
         ),
+        (
+            TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET,
+            "text_dataset_input_output_masking",
+        ),
         (TWITTER_COMPLAINTS_TOKENIZED_JSONL, "pretokenized_dataset"),
+        (TWITTER_COMPLAINTS_TOKENIZED_PARQUET, "pretokenized_dataset"),
         (TWITTER_COMPLAINTS_DATA_JSONL, "apply_custom_data_template"),
+        (TWITTER_COMPLAINTS_DATA_PARQUET, "apply_custom_data_template"),
     ],
 )
 def test_load_dataset_with_dataconfig_and_datafile(datafile, datasetconfigname):
@@ -339,8 +395,10 @@ def test_process_data_args_throws_error_where_needed(data_args, packing):
     [
         (APPLY_CUSTOM_TEMPLATE_YAML, TWITTER_COMPLAINTS_DATA_JSON),
         (APPLY_CUSTOM_TEMPLATE_YAML, TWITTER_COMPLAINTS_DATA_JSONL),
+        (APPLY_CUSTOM_TEMPLATE_YAML, TWITTER_COMPLAINTS_DATA_PARQUET),
         (PRETOKENIZE_JSON_DATA_YAML, TWITTER_COMPLAINTS_TOKENIZED_JSON),
         (PRETOKENIZE_JSON_DATA_YAML, TWITTER_COMPLAINTS_TOKENIZED_JSONL),
+        (PRETOKENIZE_JSON_DATA_YAML, TWITTER_COMPLAINTS_TOKENIZED_PARQUET),
         (
             TOKENIZE_AND_APPLY_INPUT_MASKING_YAML,
             TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON,
@@ -349,6 +407,10 @@ def test_process_data_args_throws_error_where_needed(data_args, packing):
             TOKENIZE_AND_APPLY_INPUT_MASKING_YAML,
             TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL,
         ),
+        (
+            TOKENIZE_AND_APPLY_INPUT_MASKING_YAML,
+            TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET,
+        ),
     ],
 )
 def test_process_dataconfig_file(data_config_path, data_path):
@@ -414,6 +476,15 @@ def test_process_dataconfig_file(data_config_path, data_path):
                 response_template="\n### Label:",
             )
         ),
+        # single sequence PARQUET and response template
+        (
+            configs.DataArguments(
+                training_data_path=TWITTER_COMPLAINTS_DATA_PARQUET,
+                validation_data_path=TWITTER_COMPLAINTS_DATA_PARQUET,
+                dataset_text_field="output",
+                response_template="\n### Label:",
+            )
+        ),
         # data formatter template with input/output JSON
         (
             configs.DataArguments(
@@ -432,6 +503,15 @@ def test_process_dataconfig_file(data_config_path, data_path):
                 response_template="\n### Label:",
             )
         ),
+        # data formatter template with input/output PARQUET
+        (
+            configs.DataArguments(
+                training_data_path=TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET,
+                validation_data_path=TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET,
+                data_formatter_template="### Text:{{input}} \n\n### Label: {{output}}",
+                response_template="\n### Label:",
+            )
+        ),
         # input/output JSON with masking on input
         (
             configs.DataArguments(
@@ -446,6 +526,13 @@ def test_process_dataconfig_file(data_config_path, data_path):
                 validation_data_path=TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL,
             )
         ),
+        # input/output PARQUET with masking on input
+        (
+            configs.DataArguments(
+                training_data_path=TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET,
+                validation_data_path=TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET,
+            )
+        ),
     ],
 )
 def test_process_dataargs(data_args):
@@ -487,6 +574,13 @@ def test_process_dataargs(data_args):
                 validation_data_path=TWITTER_COMPLAINTS_TOKENIZED_JSONL,
             )
         ),
+        # PARQUET pretokenized train and validation datasets
+        (
+            configs.DataArguments(
+                training_data_path=TWITTER_COMPLAINTS_TOKENIZED_PARQUET,
+                validation_data_path=TWITTER_COMPLAINTS_TOKENIZED_PARQUET,
+            )
+        ),
         # JSON pretokenized train datasets
         (
             configs.DataArguments(
@@ -499,6 +593,12 @@ def test_process_dataargs(data_args):
                 training_data_path=TWITTER_COMPLAINTS_TOKENIZED_JSONL,
             )
         ),
+        # PARQUET pretokenized train datasets
+        (
+            configs.DataArguments(
+                training_data_path=TWITTER_COMPLAINTS_TOKENIZED_PARQUET,
+            )
+        ),
     ],
 )
 def test_process_dataargs_pretokenized(data_args):
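Every Parquet case above reuses an existing parametrized test, so no new test functions were needed. Because pytest derives the parametrize IDs from the fixture paths (which contain "parquet"), the new cases should be selectable with a keyword filter along these lines:

    python -m pytest tests/data/test_data_preprocessing_utils.py -k parquet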

tuning/data/data_processors.py

Lines changed: 1 addition & 1 deletion

@@ -105,7 +105,7 @@ def _process_dataset_configs(
         # In future the streaming etc go as kwargs of this function
         raw_dataset = self.load_dataset(d, splitName)
 
-        logging.info("Loaded raw dataset : {raw_datasets}")
+        logging.info("Loaded raw dataset : %s", str(raw_dataset))
 
         raw_datasets = DatasetDict()
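The replaced log line had two problems: the string was missing its f-prefix, so the literal text "{raw_datasets}" was logged, and it referenced raw_datasets, which is only assigned two lines later. The fix switches to %-style lazy formatting, which also defers stringification until the record is actually emitted. A small self-contained illustration:

    import logging

    logging.basicConfig(level=logging.INFO)
    raw_dataset = {"train": ["..."]}  # stand-in for the loaded dataset

    # Old behaviour: no f-prefix, so the braces are printed verbatim.
    logging.info("Loaded raw dataset : {raw_dataset}")

    # Fixed: %s is filled in lazily, only if INFO-level logging is enabled.
    logging.info("Loaded raw dataset : %s", raw_dataset)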

tuning/utils/utils.py

Lines changed: 2 additions & 0 deletions

@@ -31,6 +31,8 @@ def get_loader_for_filepath(file_path: str) -> str:
         return "text"
     if ext in (".json", ".jsonl"):
         return "json"
+    if ext in (".parquet",):
+        return "parquet"
     return ext
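get_loader_for_filepath maps a local file's extension to the builder name that Hugging Face datasets.load_dataset expects ("text", "json", and now "parquet"). Note the one-element tuple (".parquet",): without the trailing comma, ext in (".parquet") would be a substring test against a plain string (so ".par" would also match) rather than a tuple membership check. A short sketch of the new branch in use, assuming the loader string is consumed by load_dataset as elsewhere in the repo (the fixture path matches the new test constants):

    from datasets import load_dataset

    from tuning.utils.utils import get_loader_for_filepath

    file_path = "tests/artifacts/testdata/parquet/twitter_complaints_small.parquet"
    loader = get_loader_for_filepath(file_path)  # -> "parquet"

    # The builder name plus data_files loads the local Parquet file.
    dataset = load_dataset(loader, data_files=file_path, split="train")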
