Skip to content

Commit 1069a73

Browse files
author
Swetha Mandava
committed
converge to pyt
1 parent efd6384 commit 1069a73

File tree

5 files changed

+45
-108
lines changed

5 files changed

+45
-108
lines changed

TensorFlow/LanguageModeling/BERT/Dockerfile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ FROM ${FROM_IMAGE_NAME}
55
RUN apt-get update && apt-get install -y pbzip2 pv bzip2 libcurl4 curl libb64-dev
66
RUN pip install --upgrade pip
77
RUN pip install toposort networkx pytest nltk tqdm html2text progressbar
8-
RUN pip --no-cache-dir --no-cache install git+https://github.com/NVIDIA/dllogger
8+
RUN pip --no-cache-dir --no-cache install git+https://github.com/NVIDIA/dllogger wget
99

1010
WORKDIR /workspace
1111
RUN git clone https://github.com/openai/gradient-checkpointing.git

TensorFlow/LanguageModeling/BERT/data/Downloader.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -53,15 +53,15 @@ def download(self):
5353
elif self.dataset_name == 'nvidia_pretrained_weights':
5454
self.download_nvidia_pretrained_weights()
5555

56-
elif self.dataset_name == 'MRPC':
56+
elif self.dataset_name == 'mrpc':
5757
self.download_glue(self.dataset_name)
5858

59-
elif self.dataset_name == 'MNLI':
59+
elif self.dataset_name == 'mnli':
6060
self.download_glue(self.dataset_name)
6161

62-
elif self.dataset_name == 'CoLA':
62+
elif self.dataset_name == 'cola':
6363
self.download_glue(self.dataset_name)
64-
elif self.dataset_name == 'SST':
64+
elif self.dataset_name == 'sst-2':
6565
self.download_glue(self.dataset_name)
6666

6767
elif self.dataset_name == 'squad':
@@ -77,10 +77,10 @@ def download(self):
7777
self.download_pubmed('open_access')
7878
self.download_google_pretrained_weights()
7979
self.download_nvidia_pretrained_weights()
80-
self.download_glue("CoLA")
81-
self.download_glue("MNLI")
82-
self.download_glue("MRPC")
83-
self.download_glue("SST")
80+
self.download_glue("cola")
81+
self.download_glue("mnli")
82+
self.download_glue("mrpc")
83+
self.download_glue("sst-2")
8484
self.download_squad()
8585

8686
else:
@@ -114,8 +114,8 @@ def download_nvidia_pretrained_weights(self):
114114

115115

116116
def download_glue(self, glue_task_name):
117-
downloader = GLUEDownloader(glue_task_name, self.save_path)
118-
downloader.download()
117+
downloader = GLUEDownloader(self.save_path)
118+
downloader.download(glue_task_name)
119119

120120

121121
def download_squad(self):

TensorFlow/LanguageModeling/BERT/data/GLUEDownloader.py

Lines changed: 26 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -11,99 +11,36 @@
1111
# See the License for the specific language governing permissions and
1212
# limitations under the License.
1313

14-
import bz2
15-
import os
16-
import urllib
1714
import sys
18-
import zipfile
19-
import io
15+
import wget
2016

21-
URLLIB=urllib
22-
if sys.version_info >= (3, 0):
23-
URLLIB=urllib.request
17+
from pathlib import Path
2418

25-
class GLUEDownloader:
26-
def __init__(self, task, save_path):
27-
28-
# Documentation - Download link obtained from here: https://github.com/nyu-mll/GLUE-baselines/blob/master/download_glue_data.py
29-
30-
self.TASK2PATH = {"CoLA":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FCoLA.zip?alt=media&token=46d5e637-3411-4188-bc44-5809b5bfb5f4',
31-
"SST":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSST-2.zip?alt=media&token=aabc5f6b-e466-44a2-b9b4-cf6337f84ac8',
32-
"MRPC":{"mrpc_dev": 'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2Fmrpc_dev_ids.tsv?alt=media&token=ec5c0836-31d5-48f4-b431-7480817f1adc',
33-
"mrpc_train": 'https://dl.fbaipublicfiles.com/senteval/senteval_data/msr_paraphrase_train.txt',
34-
"mrpc_test": 'https://dl.fbaipublicfiles.com/senteval/senteval_data/msr_paraphrase_test.txt'},
35-
"QQP":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FQQP.zip?alt=media&token=700c6acf-160d-4d89-81d1-de4191d02cb5',
36-
"STS":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSTS-B.zip?alt=media&token=bddb94a7-8706-4e0d-a694-1109e12273b5',
37-
"MNLI":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FMNLI.zip?alt=media&token=50329ea1-e339-40e2-809c-10c40afff3ce',
38-
"SNLI":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSNLI.zip?alt=media&token=4afcfbb2-ff0c-4b2d-a09a-dbf07926f4df',
39-
"QNLI":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FQNLI.zip?alt=media&token=c24cad61-f2df-4f04-9ab6-aa576fa829d0',
40-
"RTE":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FRTE.zip?alt=media&token=5efa7e85-a0bb-4f19-8ea2-9e1840f077fb',
41-
"WNLI":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FWNLI.zip?alt=media&token=068ad0a0-ded7-4bd7-99a5-5e00222e0faf',
42-
"diagnostic":'https://storage.googleapis.com/mtl-sentence-representations.appspot.com/tsvsWithoutLabels%2FAX.tsv?GoogleAccessId=firebase-adminsdk-0khhl@mtl-sentence-representations.iam.gserviceaccount.com&Expires=2498860800&Signature=DuQ2CSPt2Yfre0C%2BiISrVYrIFaZH1Lc7hBVZDD4ZyR7fZYOMNOUGpi8QxBmTNOrNPjR3z1cggo7WXFfrgECP6FBJSsURv8Ybrue8Ypt%2FTPxbuJ0Xc2FhDi%2BarnecCBFO77RSbfuz%2Bs95hRrYhTnByqu3U%2FYZPaj3tZt5QdfpH2IUROY8LiBXoXS46LE%2FgOQc%2FKN%2BA9SoscRDYsnxHfG0IjXGwHN%2Bf88q6hOmAxeNPx6moDulUF6XMUAaXCSFU%2BnRO2RDL9CapWxj%2BDl7syNyHhB7987hZ80B%2FwFkQ3MEs8auvt5XW1%2Bd4aCU7ytgM69r8JDCwibfhZxpaa4gd50QXQ%3D%3D'}
43-
44-
45-
self.save_path = save_path
46-
if not os.path.exists(self.save_path):
47-
os.makedirs(self.save_path)
48-
49-
self.task = task
5019

51-
def download(self):
20+
def mkdir(path):
21+
Path(path).mkdir(parents=True, exist_ok=True)
5222

53-
if self.task == 'MRPC':
54-
self.download_mrpc()
55-
elif self.task == 'diagnostic':
56-
self.download_diagnostic()
57-
else:
58-
self.download_and_extract(self.task)
5923

60-
def download_and_extract(self, task):
61-
print("Downloading and extracting %s..." % task)
62-
data_file = "%s.zip" % task
63-
URLLIB.urlretrieve(self.TASK2PATH[task], data_file)
64-
print(data_file,"\n\n\n")
65-
with zipfile.ZipFile(data_file) as zip_ref:
66-
zip_ref.extractall(self.save_path)
67-
os.remove(data_file)
68-
print("\tCompleted!")
69-
70-
def download_mrpc(self):
71-
print("Processing MRPC...")
72-
mrpc_dir = os.path.join(self.save_path, "MRPC")
73-
if not os.path.isdir(mrpc_dir):
74-
os.mkdir(mrpc_dir)
75-
76-
mrpc_train_file = os.path.join(mrpc_dir, "msr_paraphrase_train.txt")
77-
mrpc_dev_file = os.path.join(mrpc_dir, "dev_ids.tsv")
78-
mrpc_test_file = os.path.join(mrpc_dir, "msr_paraphrase_test.txt")
79-
80-
URLLIB.urlretrieve(self.TASK2PATH["MRPC"]["mrpc_train"], mrpc_train_file)
81-
URLLIB.urlretrieve(self.TASK2PATH["MRPC"]["mrpc_test"], mrpc_test_file)
82-
URLLIB.urlretrieve(self.TASK2PATH["MRPC"]["mrpc_dev"], mrpc_dev_file)
83-
84-
dev_ids = []
85-
with io.open(os.path.join(mrpc_dir, "dev_ids.tsv"), encoding='utf-8') as ids_fh:
86-
for row in ids_fh:
87-
dev_ids.append(row.strip().split('\t'))
88-
89-
with io.open(mrpc_train_file, encoding='utf-8') as data_fh, \
90-
io.open(os.path.join(mrpc_dir, "train.tsv"), 'w', encoding='utf-8') as train_fh, \
91-
io.open(os.path.join(mrpc_dir, "dev.tsv"), 'w', encoding='utf-8') as dev_fh:
92-
header = data_fh.readline()
93-
train_fh.write(header)
94-
dev_fh.write(header)
95-
for row in data_fh:
96-
label, id1, id2, s1, s2 = row.strip().split('\t')
97-
if [id1, id2] in dev_ids:
98-
dev_fh.write("%s\t%s\t%s\t%s\t%s\n" % (label, id1, id2, s1, s2))
99-
else:
100-
train_fh.write("%s\t%s\t%s\t%s\t%s\n" % (label, id1, id2, s1, s2))
24+
class GLUEDownloader:
10125

102-
with io.open(mrpc_test_file, encoding='utf-8') as data_fh, \
103-
io.open(os.path.join(mrpc_dir, "test.tsv"), 'w', encoding='utf-8') as test_fh:
104-
header = data_fh.readline()
105-
test_fh.write("index\t#1 ID\t#2 ID\t#1 String\t#2 String\n")
106-
for idx, row in enumerate(data_fh):
107-
label, id1, id2, s1, s2 = row.strip().split('\t')
108-
test_fh.write("%d\t%s\t%s\t%s\t%s\n" % (idx, id1, id2, s1, s2))
109-
print("\tCompleted!")
26+
def __init__(self, save_path):
27+
self.save_path = save_path + '/glue'
28+
29+
def download(self, task_name):
30+
mkdir(self.save_path)
31+
if task_name in {'mrpc', 'mnli'}:
32+
task_name = task_name.upper()
33+
elif task_name == 'cola':
34+
task_name = 'CoLA'
35+
else: # SST-2
36+
assert task_name == 'sst-2'
37+
task_name = 'SST'
38+
wget.download(
39+
'https://gist.githubusercontent.com/W4ngatang/60c2bdb54d156a41194446737ce03e2e/raw/17b8dd0d724281ed7c3b2aeeda662b92809aadd5/download_glue_data.py',
40+
out=self.save_path,
41+
)
42+
sys.path.append(self.save_path)
43+
import download_glue_data
44+
download_glue_data.main(
45+
['--data_dir', self.save_path, '--tasks', task_name])
46+
sys.path.pop()

TensorFlow/LanguageModeling/BERT/data/bertPrep.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,8 @@ def main(args):
6262

6363
elif args.action == 'text_formatting':
6464
assert args.dataset != 'google_pretrained_weights' and args.dataset != 'nvidia_pretrained_weights' \
65-
and args.dataset != 'squad' and args.dataset != 'MRPC' and args.dataset != 'CoLA' and \
66-
args.dataset != 'MNLI' and args.dataset != 'SST', 'Cannot perform text_formatting on pretrained weights'
65+
and args.dataset != 'squad' and args.dataset != 'mrpc' and args.dataset != 'cola' and \
66+
args.dataset != 'mnli' and args.dataset != 'sst-2', 'Cannot perform text_formatting on pretrained weights'
6767

6868
if not os.path.exists(directory_structure['extracted']):
6969
os.makedirs(directory_structure['extracted'])
@@ -271,10 +271,10 @@ def create_record_worker(filename_prefix, shard_id, output_format='hdf5'):
271271
'google_pretrained_weights',
272272
'nvidia_pretrained_weights',
273273
'squad',
274-
'MRPC',
275-
'CoLA',
276-
'MNLI',
277-
'SST',
274+
'mrpc',
275+
'sst-2',
276+
'mnli',
277+
'cola',
278278
'all'
279279
}
280280
)

TensorFlow/LanguageModeling/BERT/data/create_datasets_from_start.sh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,8 @@ fi
2525
python3 /workspace/bert/data/bertPrep.py --action download --dataset wikicorpus_en
2626
python3 /workspace/bert/data/bertPrep.py --action download --dataset google_pretrained_weights # Includes vocab
2727
python3 /workspace/bert/data/bertPrep.py --action download --dataset squad
28-
python3 /workspace/bert/data/bertPrep.py --action download --dataset MRPC
29-
python3 /workspace/bert/data/bertPrep.py --action download --dataset SST
28+
python3 /workspace/bert/data/bertPrep.py --action download --dataset mrpc
29+
python3 /workspace/bert/data/bertPrep.py --action download --dataset sst-2
3030

3131
# Properly format the text files
3232
if [ "$to_download" = "wiki_books" ] ; then

0 commit comments

Comments (0)