This repository was archived by the owner on Jan 15, 2024. It is now read-only.

Commit 033214e

[Numpy] Fix SQuAD + Fix GLUE downloading (#1280)
* Update run_squad.py
* Update run_squad.py
* Update prepare_glue.py
1 parent 3c87457 commit 033214e

File tree: 2 files changed (+64, -36 lines)


scripts/datasets/general_nlp_benchmark/prepare_glue.py

Lines changed: 62 additions & 34 deletions
@@ -68,7 +68,23 @@ def read_tsv_glue(tsv_file, num_skip=1, keep_column_names=False):
                 nrows = len(elements)
             else:
                 assert nrows == len(elements)
-    return pd.DataFrame(out, columns=column_names)
+    df = pd.DataFrame(out, columns=column_names)
+    series_l = []
+    for col_name in df.columns:
+        idx = df[col_name].first_valid_index()
+        val = df[col_name][idx]
+        if isinstance(val, str):
+            try:
+                dat = pd.to_numeric(df[col_name])
+                series_l.append(dat)
+                continue
+            except ValueError:
+                pass
+            finally:
+                pass
+        series_l.append(df[col_name])
+    new_df = pd.DataFrame({name: series for name, series in zip(df.columns, series_l)})
+    return new_df
 
 
 def read_jsonl_superglue(jsonl_file):
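
The new tail of read_tsv_glue tries pd.to_numeric on every column whose first valid value is a string and keeps the original column when parsing fails (the finally: pass is a no-op). A minimal standalone sketch of that coercion idea, with illustrative data not taken from GLUE:

import pandas as pd

def coerce_numeric_columns(df):
    """Return a copy of df where string columns that parse as numbers become numeric."""
    series_l = []
    for col_name in df.columns:
        col = df[col_name]
        try:
            # pd.to_numeric raises ValueError if any element is not parseable.
            series_l.append(pd.to_numeric(col))
        except (ValueError, TypeError):
            series_l.append(col)  # keep the column unchanged
    return pd.DataFrame({name: s for name, s in zip(df.columns, series_l)})

# Example: a 'score' column read from TSV as strings becomes float64.
df = pd.DataFrame({'sentence': ['a cat sat', 'a dog ran'], 'score': ['1.5', '2.0']})
print(coerce_numeric_columns(df).dtypes)
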
@@ -157,6 +173,13 @@ def read_sts(dir_path):
         else:
             df = df[[7, 8, 1, 9]]
             df.columns = ['sentence1', 'sentence2', 'genre', 'score']
+        genre_l = []
+        for ele in df['genre'].tolist():
+            if ele == 'main-forum':
+                genre_l.append('main-forums')
+            else:
+                genre_l.append(ele)
+        df['genre'] = pd.Series(genre_l)
         df_dict[fold] = df
     return df_dict, None
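
The added loop normalizes the STS-B genre value 'main-forum' to 'main-forums' so the label is consistent across folds. For reference, a hedged equivalent using pandas' Series.replace (illustrative data, not the patch's code):

import pandas as pd

# Equivalent to the genre_l loop: map the single mislabeled genre value.
df = pd.DataFrame({'genre': ['main-forum', 'main-news', 'main-forum']})
df['genre'] = df['genre'].replace('main-forum', 'main-forums')
print(df['genre'].tolist())  # ['main-forums', 'main-news', 'main-forums']
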

@@ -320,8 +343,8 @@ def read_rte_superglue(dir_path):
 def read_wic(dir_path):
     df_dict = dict()
     meta_data = dict()
-    meta_data['entities1'] = {'type': 'entity', 'parent': 'sentence1'}
-    meta_data['entities2'] = {'type': 'entity', 'parent': 'sentence2'}
+    meta_data['entities1'] = {'type': 'entity', 'attrs': {'parent': 'sentence1'}}
+    meta_data['entities2'] = {'type': 'entity', 'attrs': {'parent': 'sentence2'}}
 
     for fold in ['train', 'val', 'test']:
         if fold != 'test':
@@ -340,13 +363,13 @@ def read_wic(dir_path):
             end2 = row['end2']
             if fold == 'test':
                 out.append([sentence1, sentence2,
-                            (start1, end1),
-                            (start2, end2)])
+                            {'start': start1, 'end': end1},
+                            {'start': start2, 'end': end2}])
             else:
                 label = row['label']
                 out.append([sentence1, sentence2,
-                            (start1, end1),
-                            (start2, end2),
+                            {'start': start1, 'end': end1},
+                            {'start': start2, 'end': end2},
                             label])
         df = pd.DataFrame(out, columns=columns)
         df_dict[fold] = df
@@ -357,8 +380,8 @@ def read_wsc(dir_path):
     df_dict = dict()
     tokenizer = WhitespaceTokenizer()
     meta_data = dict()
-    meta_data['noun'] = {'type': 'entity', 'parent': 'text'}
-    meta_data['pronoun'] = {'type': 'entity', 'parent': 'text'}
+    meta_data['noun'] = {'type': 'entity', 'attrs': {'parent': 'text'}}
+    meta_data['pronoun'] = {'type': 'entity', 'attrs': {'parent': 'text'}}
     for fold in ['train', 'val', 'test']:
         jsonl_path = os.path.join(dir_path, '{}.jsonl'.format(fold))
         df = read_jsonl_superglue(jsonl_path)
@@ -374,20 +397,20 @@
             span2_text = target['span2_text']
             # Build entity
             # list of entities
-            # 'entity': {'start': 0, 'end': 100}
+            # 'entities': {'start': 0, 'end': 100}
             tokens, offsets = tokenizer.encode_with_offsets(text, str)
             pos_start1 = offsets[span1_index][0]
             pos_end1 = pos_start1 + len(span1_text)
             pos_start2 = offsets[span2_index][0]
             pos_end2 = pos_start2 + len(span2_text)
             if fold == 'test':
                 samples.append({'text': text,
-                                'noun': (pos_start1, pos_end1),
-                                'pronoun': (pos_start2, pos_end2)})
+                                'noun': {'start': pos_start1, 'end': pos_end1},
+                                'pronoun': {'start': pos_start2, 'end': pos_end2}})
             else:
                 samples.append({'text': text,
-                                'noun': (pos_start1, pos_end1),
-                                'pronoun': (pos_start2, pos_end2),
+                                'noun': {'start': pos_start1, 'end': pos_end1},
+                                'pronoun': {'start': pos_start2, 'end': pos_end2},
                                 'label': label})
         df = pd.DataFrame(samples)
         df_dict[fold] = df
@@ -406,8 +429,8 @@ def read_boolq(dir_path):
 def read_record(dir_path):
     df_dict = dict()
     meta_data = dict()
-    meta_data['entities'] = {'type': 'entity', 'parent': 'text'}
-    meta_data['answers'] = {'type': 'entity', 'parent': 'text'}
+    meta_data['entities'] = {'type': 'entity', 'attrs': {'parent': 'text'}}
+    meta_data['answers'] = {'type': 'entity', 'attrs': {'parent': 'text'}}
     for fold in ['train', 'val', 'test']:
         if fold != 'test':
             columns = ['source', 'text', 'entities', 'query', 'answers']
@@ -422,15 +445,11 @@ def read_record(dir_path):
             passage = row['passage']
             text = passage['text']
             entities = passage['entities']
-            entities = [(ele['start'], ele['end']) for ele in entities]
+            entities = [{'start': ele['start'], 'end': ele['end']} for ele in entities]
             for qas in row['qas']:
                 query = qas['query']
                 if fold != 'test':
-                    answer_entities = []
-                    for answer in qas['answers']:
-                        start = answer['start']
-                        end = answer['end']
-                        answer_entities.append((start, end))
+                    answer_entities = qas['answers']
                     out.append((source, text, entities, query, answer_entities))
                 else:
                     out.append((source, text, entities, query))
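
This is the last of several hunks (read_wic, read_wsc, read_record) that switch span annotations from (start, end) tuples to {'start': ..., 'end': ...} dicts and nest the parent column under an 'attrs' key in the metadata. A sketch of what one row and its metadata might look like under the new convention; the example text and offsets are made up for illustration:

import json

# Hypothetical metadata and row following the dict-based span convention.
meta_data = {
    'entities': {'type': 'entity', 'attrs': {'parent': 'text'}},
    'answers':  {'type': 'entity', 'attrs': {'parent': 'text'}},
}
row = {
    'text': 'Barack Obama was born in Hawaii.',
    'entities': [{'start': 0, 'end': 12}, {'start': 25, 'end': 31}],
    'answers':  [{'start': 25, 'end': 31}],
}

# Dicts keep the field names explicit and round-trip through JSON with their
# structure intact, whereas tuples come back as plain lists.
print(json.dumps({'meta_data': meta_data, 'row': row}, indent=2))
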
@@ -518,11 +537,15 @@ def format_mrpc(data_dir):
     os.makedirs(mrpc_dir, exist_ok=True)
     mrpc_train_file = os.path.join(mrpc_dir, "msr_paraphrase_train.txt")
     mrpc_test_file = os.path.join(mrpc_dir, "msr_paraphrase_test.txt")
-    download(GLUE_TASK2PATH["mrpc"]['train'], mrpc_train_file)
-    download(GLUE_TASK2PATH["mrpc"]['test'], mrpc_test_file)
+    download(GLUE_TASK2PATH["mrpc"]['train'], mrpc_train_file,
+             sha1_hash=_URL_FILE_STATS[GLUE_TASK2PATH["mrpc"]['train']])
+    download(GLUE_TASK2PATH["mrpc"]['test'], mrpc_test_file,
+             sha1_hash=_URL_FILE_STATS[GLUE_TASK2PATH["mrpc"]['test']])
     assert os.path.isfile(mrpc_train_file), "Train data not found at %s" % mrpc_train_file
     assert os.path.isfile(mrpc_test_file), "Test data not found at %s" % mrpc_test_file
-    download(GLUE_TASK2PATH["mrpc"]['dev'], os.path.join(mrpc_dir, "dev_ids.tsv"))
+    download(GLUE_TASK2PATH["mrpc"]['dev'],
+             os.path.join(mrpc_dir, "dev_ids.tsv"),
+             sha1_hash=_URL_FILE_STATS[GLUE_TASK2PATH["mrpc"]['dev']])
 
     dev_ids = []
     with open(os.path.join(mrpc_dir, "dev_ids.tsv"), encoding="utf8") as ids_fh:
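
format_mrpc (and the other download sites below) now passes sha1_hash=_URL_FILE_STATS[url] so fetched files are verified against known checksums. The helper below is not the project's download utility, just a standard-library sketch of what a SHA-1-verified download typically does:

import hashlib
import urllib.request

def download_with_sha1(url, path, sha1_hash=None, chunk_size=1 << 20):
    """Download url to path and, if sha1_hash is given, verify the file's SHA-1 digest."""
    with urllib.request.urlopen(url) as resp, open(path, 'wb') as f:
        while True:
            chunk = resp.read(chunk_size)
            if not chunk:
                break
            f.write(chunk)
    if sha1_hash is not None:
        digest = hashlib.sha1()
        with open(path, 'rb') as f:
            for chunk in iter(lambda: f.read(chunk_size), b''):
                digest.update(chunk)
        if digest.hexdigest() != sha1_hash:
            raise ValueError('SHA-1 mismatch for {}: expected {}, got {}'.format(
                path, sha1_hash, digest.hexdigest()))
    return path
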
@@ -575,7 +598,7 @@ def get_tasks(benchmark, task_names):
 @DATA_PARSER_REGISTRY.register('prepare_glue')
 def get_parser():
     parser = argparse.ArgumentParser()
-    parser.add_argument("--benchmark", choices=['glue', 'superglue', 'sts'],
+    parser.add_argument("--benchmark", choices=['glue', 'superglue'],
                         default='glue', type=str)
     parser.add_argument("-d", "--data_dir", help="directory to save data to", type=str,
                         default=None)
@@ -618,39 +641,44 @@ def main(args):
                 base_dir = os.path.join(args.data_dir, 'rte_diagnostic')
                 os.makedirs(base_dir, exist_ok=True)
                 download(TASK2PATH['diagnostic'][0],
-                         path=os.path.join(base_dir, 'diagnostic.tsv'))
+                         path=os.path.join(base_dir, 'diagnostic.tsv'),
+                         sha1_hash=_URL_FILE_STATS[TASK2PATH['diagnostic'][0]])
                 download(TASK2PATH['diagnostic'][1],
-                         path=os.path.join(base_dir, 'diagnostic-full.tsv'))
+                         path=os.path.join(base_dir, 'diagnostic-full.tsv'),
+                         sha1_hash=_URL_FILE_STATS[TASK2PATH['diagnostic'][1]])
                 df = reader(base_dir)
-                df.to_pickle(os.path.join(base_dir, 'diagnostic-full.pd.pkl'))
+                df.to_parquet(os.path.join(base_dir, 'diagnostic-full.parquet'))
             else:
                 for key, name in [('broadcoverage-diagnostic', 'AX-b'),
                                   ('winogender-diagnostic', 'AX-g')]:
                     data_file = os.path.join(args.cache_path, "{}.zip".format(key))
                     url = TASK2PATH[key]
                     reader = TASK2READER[key]
-                    download(url, data_file)
+                    download(url, data_file, sha1_hash=_URL_FILE_STATS[url])
                     with zipfile.ZipFile(data_file) as zipdata:
                         zipdata.extractall(args.data_dir)
                     df = reader(os.path.join(args.data_dir, name))
-                    df.to_pickle(os.path.join(args.data_dir, name, '{}.pd.pkl'.format(name)))
+                    df.to_parquet(os.path.join(args.data_dir, name, '{}.parquet'.format(name)))
         elif task == 'mrpc':
             reader = TASK2READER[task]
             format_mrpc(args.data_dir)
             df_dict, meta_data = reader(os.path.join(args.data_dir, 'mrpc'))
             for key, df in df_dict.items():
                 if key == 'val':
                     key = 'dev'
-                df.to_pickle(os.path.join(args.data_dir, 'mrpc', '{}.pd.pkl'.format(key)))
+                df.to_parquet(os.path.join(args.data_dir, 'mrpc', '{}.parquet'.format(key)))
             with open(os.path.join(args.data_dir, 'mrpc', 'metadata.json'), 'w') as f:
                 json.dump(meta_data, f)
         else:
             # Download data
             data_file = os.path.join(args.cache_path, "{}.zip".format(task))
             url = TASK2PATH[task]
             reader = TASK2READER[task]
-            download(url, data_file)
+            download(url, data_file, sha1_hash=_URL_FILE_STATS[url])
             base_dir = os.path.join(args.data_dir, task)
+            if os.path.exists(base_dir):
+                print('Found!')
+                continue
             zip_dir_name = None
             with zipfile.ZipFile(data_file) as zipdata:
                 if zip_dir_name is None:
@@ -662,7 +690,7 @@ def main(args):
             for key, df in df_dict.items():
                 if key == 'val':
                     key = 'dev'
-                df.to_pickle(os.path.join(base_dir, '{}.pd.pkl'.format(key)))
+                df.to_parquet(os.path.join(base_dir, '{}.parquet'.format(key)))
             if meta_data is not None:
                 with open(os.path.join(base_dir, 'metadata.json'), 'w') as f:
                     json.dump(meta_data, f)
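
With these hunks, all processed splits in prepare_glue.py are written via DataFrame.to_parquet instead of to_pickle, so consumers would read them back with pd.read_parquet and need a Parquet engine (pyarrow or fastparquet) installed. A small round-trip sketch under that assumption:

import pandas as pd

# Assumes a Parquet engine (pyarrow or fastparquet) is available.
df = pd.DataFrame({'sentence1': ['a cat sat'],
                   'sentence2': ['a cat was sitting'],
                   'label': [1]})
df.to_parquet('train.parquet')            # the format prepare_glue.py now writes
restored = pd.read_parquet('train.parquet')
assert restored.equals(df)
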

scripts/question_answering/run_squad.py

Lines changed: 2 additions & 2 deletions
@@ -563,8 +563,8 @@ def train(args):
             segment_ids = sample.segment_ids.as_in_ctx(ctx) if use_segmentation else None
             valid_length = sample.valid_length.as_in_ctx(ctx)
             p_mask = sample.masks.as_in_ctx(ctx)
-            gt_start = sample.gt_start.as_in_ctx(ctx)
-            gt_end = sample.gt_end.as_in_ctx(ctx)
+            gt_start = sample.gt_start.as_in_ctx(ctx).astype(np.int32)
+            gt_end = sample.gt_end.as_in_ctx(ctx).astype(np.int32)
             is_impossible = sample.is_impossible.as_in_ctx(ctx).astype(np.int32)
             batch_idx = mx.np.arange(tokens.shape[0], dtype=np.int32, ctx=ctx)
             p_mask = 1 - p_mask  # In the network, we use 1 --> no_mask, 0 --> mask
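
A likely reason for casting gt_start and gt_end to int32 is that they are used as index arrays (alongside batch_idx) when picking out the gold span positions, and NumPy-style indexing rejects floating-point indices. A plain NumPy illustration of the failure mode the cast avoids; this is not the training code itself:

import numpy as np

logits = np.random.rand(4, 128)           # (batch, sequence) scores
gt_start = np.array([3., 17., 0., 42.])   # labels arriving with a float dtype

try:
    logits[np.arange(4), gt_start]        # float index arrays are rejected
except IndexError as e:
    print('float indexing fails:', e)

gt_start = gt_start.astype(np.int32)      # analogous to the astype(np.int32) added above
print(logits[np.arange(4), gt_start])     # per-sample scores at the gold start positions
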
