Skip to content

Commit e6f7a22

Browse files
authored
test: add arrow datasets and arrow unit tests (#403)
Signed-off-by: Will Johnson <[email protected]>
1 parent fbe6064 commit e6f7a22

File tree

6 files changed

+50
-0
lines changed

6 files changed

+50
-0
lines changed

tests/artifacts/testdata/__init__.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
PARQUET_DATA_DIR = os.path.join(os.path.dirname(__file__), "parquet")
2323
TWITTER_COMPLAINTS_DATA_JSON = os.path.join(DATA_DIR, "twitter_complaints_small.json")
2424
TWITTER_COMPLAINTS_DATA_JSONL = os.path.join(DATA_DIR, "twitter_complaints_small.jsonl")
25+
TWITTER_COMPLAINTS_DATA_ARROW = os.path.join(DATA_DIR, "twitter_complaints_small.arrow")
2526
TWITTER_COMPLAINTS_DATA_PARQUET = os.path.join(
2627
PARQUET_DATA_DIR, "twitter_complaints_small.parquet"
2728
)
@@ -31,6 +32,9 @@
3132
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL = os.path.join(
3233
DATA_DIR, "twitter_complaints_input_output.jsonl"
3334
)
35+
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_ARROW = os.path.join(
36+
DATA_DIR, "twitter_complaints_input_output.arrow"
37+
)
3438
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET = os.path.join(
3539
PARQUET_DATA_DIR, "twitter_complaints_input_output.parquet"
3640
)
@@ -40,6 +44,9 @@
4044
TWITTER_COMPLAINTS_TOKENIZED_JSONL = os.path.join(
4145
DATA_DIR, "twitter_complaints_tokenized_with_maykeye_tinyllama_v0.jsonl"
4246
)
47+
TWITTER_COMPLAINTS_TOKENIZED_ARROW = os.path.join(
48+
DATA_DIR, "twitter_complaints_tokenized_with_maykeye_tinyllama_v0.arrow"
49+
)
4350
TWITTER_COMPLAINTS_TOKENIZED_PARQUET = os.path.join(
4451
PARQUET_DATA_DIR, "twitter_complaints_tokenized_with_maykeye_tinyllama_v0.parquet"
4552
)
13.5 KB
Binary file not shown.
3.84 KB
Binary file not shown.
Binary file not shown.

tests/data/test_data_preprocessing_utils.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,12 +32,15 @@
3232
)
3333
from tests.artifacts.testdata import (
3434
MODEL_NAME,
35+
TWITTER_COMPLAINTS_DATA_ARROW,
36+
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_ARROW,
3537
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON,
3638
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL,
3739
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET,
3840
TWITTER_COMPLAINTS_DATA_JSON,
3941
TWITTER_COMPLAINTS_DATA_JSONL,
4042
TWITTER_COMPLAINTS_DATA_PARQUET,
43+
TWITTER_COMPLAINTS_TOKENIZED_ARROW,
4144
TWITTER_COMPLAINTS_TOKENIZED_JSON,
4245
TWITTER_COMPLAINTS_TOKENIZED_JSONL,
4346
TWITTER_COMPLAINTS_TOKENIZED_PARQUET,
@@ -62,6 +65,10 @@
6265
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL,
6366
set(["ID", "Label", "input", "output"]),
6467
),
68+
(
69+
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_ARROW,
70+
set(["ID", "Label", "input", "output", "sequence"]),
71+
),
6572
(
6673
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET,
6774
set(["ID", "Label", "input", "output"]),
@@ -80,6 +87,20 @@
8087
]
8188
),
8289
),
90+
(
91+
TWITTER_COMPLAINTS_TOKENIZED_ARROW,
92+
set(
93+
[
94+
"Tweet text",
95+
"ID",
96+
"Label",
97+
"text_label",
98+
"output",
99+
"input_ids",
100+
"labels",
101+
]
102+
),
103+
),
83104
(
84105
TWITTER_COMPLAINTS_TOKENIZED_PARQUET,
85106
set(
@@ -98,6 +119,10 @@
98119
TWITTER_COMPLAINTS_DATA_JSONL,
99120
set(["Tweet text", "ID", "Label", "text_label", "output"]),
100121
),
122+
(
123+
TWITTER_COMPLAINTS_DATA_ARROW,
124+
set(["Tweet text", "ID", "Label", "text_label", "output"]),
125+
),
101126
(
102127
TWITTER_COMPLAINTS_DATA_PARQUET,
103128
set(["Tweet text", "ID", "Label", "text_label", "output"]),
@@ -123,6 +148,11 @@ def test_load_dataset_with_datafile(datafile, column_names):
123148
set(["ID", "Label", "input", "output"]),
124149
"text_dataset_input_output_masking",
125150
),
151+
(
152+
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_ARROW,
153+
set(["ID", "Label", "input", "output", "sequence"]),
154+
"text_dataset_input_output_masking",
155+
),
126156
(
127157
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET,
128158
set(["ID", "Label", "input", "output"]),
@@ -163,6 +193,11 @@ def test_load_dataset_with_datafile(datafile, column_names):
163193
set(["Tweet text", "ID", "Label", "text_label", "output"]),
164194
"apply_custom_data_template",
165195
),
196+
(
197+
TWITTER_COMPLAINTS_DATA_ARROW,
198+
set(["Tweet text", "ID", "Label", "text_label", "output"]),
199+
"apply_custom_data_template",
200+
),
166201
(
167202
TWITTER_COMPLAINTS_DATA_PARQUET,
168203
set(["Tweet text", "ID", "Label", "text_label", "output"]),
@@ -593,6 +628,12 @@ def test_process_dataargs(data_args):
593628
training_data_path=TWITTER_COMPLAINTS_TOKENIZED_JSONL,
594629
)
595630
),
631+
# ARROW pretokenized train datasets
632+
(
633+
configs.DataArguments(
634+
training_data_path=TWITTER_COMPLAINTS_TOKENIZED_ARROW,
635+
)
636+
),
596637
# PARQUET pretokenized train datasets
597638
(
598639
configs.DataArguments(

tuning/utils/utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@ def get_loader_for_filepath(file_path: str) -> str:
3131
return "text"
3232
if ext in (".json", ".jsonl"):
3333
return "json"
34+
if ext in (".arrow"):
35+
return "arrow"
3436
if ext in (".parquet"):
3537
return "parquet"
3638
return ext

0 commit comments

Comments (0)