3232)
3333from tests .artifacts .testdata import (
3434 MODEL_NAME ,
35+ TWITTER_COMPLAINTS_DATA_ARROW ,
36+ TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_ARROW ,
3537 TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON ,
3638 TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL ,
3739 TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET ,
3840 TWITTER_COMPLAINTS_DATA_JSON ,
3941 TWITTER_COMPLAINTS_DATA_JSONL ,
4042 TWITTER_COMPLAINTS_DATA_PARQUET ,
43+ TWITTER_COMPLAINTS_TOKENIZED_ARROW ,
4144 TWITTER_COMPLAINTS_TOKENIZED_JSON ,
4245 TWITTER_COMPLAINTS_TOKENIZED_JSONL ,
4346 TWITTER_COMPLAINTS_TOKENIZED_PARQUET ,
6265 TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL ,
6366 set (["ID" , "Label" , "input" , "output" ]),
6467 ),
68+ (
69+ TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_ARROW ,
70+ set (["ID" , "Label" , "input" , "output" , "sequence" ]),
71+ ),
6572 (
6673 TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET ,
6774 set (["ID" , "Label" , "input" , "output" ]),
8087 ]
8188 ),
8289 ),
90+ (
91+ TWITTER_COMPLAINTS_TOKENIZED_ARROW ,
92+ set (
93+ [
94+ "Tweet text" ,
95+ "ID" ,
96+ "Label" ,
97+ "text_label" ,
98+ "output" ,
99+ "input_ids" ,
100+ "labels" ,
101+ ]
102+ ),
103+ ),
83104 (
84105 TWITTER_COMPLAINTS_TOKENIZED_PARQUET ,
85106 set (
98119 TWITTER_COMPLAINTS_DATA_JSONL ,
99120 set (["Tweet text" , "ID" , "Label" , "text_label" , "output" ]),
100121 ),
122+ (
123+ TWITTER_COMPLAINTS_DATA_ARROW ,
124+ set (["Tweet text" , "ID" , "Label" , "text_label" , "output" ]),
125+ ),
101126 (
102127 TWITTER_COMPLAINTS_DATA_PARQUET ,
103128 set (["Tweet text" , "ID" , "Label" , "text_label" , "output" ]),
@@ -123,6 +148,11 @@ def test_load_dataset_with_datafile(datafile, column_names):
123148 set (["ID" , "Label" , "input" , "output" ]),
124149 "text_dataset_input_output_masking" ,
125150 ),
151+ (
152+ TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_ARROW ,
153+ set (["ID" , "Label" , "input" , "output" , "sequence" ]),
154+ "text_dataset_input_output_masking" ,
155+ ),
126156 (
127157 TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET ,
128158 set (["ID" , "Label" , "input" , "output" ]),
@@ -163,6 +193,11 @@ def test_load_dataset_with_datafile(datafile, column_names):
163193 set (["Tweet text" , "ID" , "Label" , "text_label" , "output" ]),
164194 "apply_custom_data_template" ,
165195 ),
196+ (
197+ TWITTER_COMPLAINTS_DATA_ARROW ,
198+ set (["Tweet text" , "ID" , "Label" , "text_label" , "output" ]),
199+ "apply_custom_data_template" ,
200+ ),
166201 (
167202 TWITTER_COMPLAINTS_DATA_PARQUET ,
168203 set (["Tweet text" , "ID" , "Label" , "text_label" , "output" ]),
@@ -593,6 +628,12 @@ def test_process_dataargs(data_args):
593628 training_data_path = TWITTER_COMPLAINTS_TOKENIZED_JSONL ,
594629 )
595630 ),
631+ # ARROW pretokenized train datasets
632+ (
633+ configs .DataArguments (
634+ training_data_path = TWITTER_COMPLAINTS_TOKENIZED_ARROW ,
635+ )
636+ ),
596637 # PARQUET pretokenized train datasets
597638 (
598639 configs .DataArguments (
0 commit comments