1111# See the License for the specific language governing permissions and
1212# limitations under the License.
1313
14- import bz2
15- import os
16- import urllib
1714import sys
18- import zipfile
19- import io
15+ import wget
2016
21- URLLIB = urllib
22- if sys .version_info >= (3 , 0 ):
23- URLLIB = urllib .request
17+ from pathlib import Path
2418
25- class GLUEDownloader :
26- def __init__ (self , task , save_path ):
27-
28- # Documentation - Download link obtained from here: https://github.com/nyu-mll/GLUE-baselines/blob/master/download_glue_data.py
29-
30- self .TASK2PATH = {"CoLA" :'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FCoLA.zip?alt=media&token=46d5e637-3411-4188-bc44-5809b5bfb5f4' ,
31- "SST" :'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSST-2.zip?alt=media&token=aabc5f6b-e466-44a2-b9b4-cf6337f84ac8' ,
32- "MRPC" :{"mrpc_dev" : 'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2Fmrpc_dev_ids.tsv?alt=media&token=ec5c0836-31d5-48f4-b431-7480817f1adc' ,
33- "mrpc_train" : 'https://dl.fbaipublicfiles.com/senteval/senteval_data/msr_paraphrase_train.txt' ,
34- "mrpc_test" : 'https://dl.fbaipublicfiles.com/senteval/senteval_data/msr_paraphrase_test.txt' },
35- "QQP" :'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FQQP.zip?alt=media&token=700c6acf-160d-4d89-81d1-de4191d02cb5' ,
36- "STS" :'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSTS-B.zip?alt=media&token=bddb94a7-8706-4e0d-a694-1109e12273b5' ,
37- "MNLI" :'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FMNLI.zip?alt=media&token=50329ea1-e339-40e2-809c-10c40afff3ce' ,
38- "SNLI" :'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSNLI.zip?alt=media&token=4afcfbb2-ff0c-4b2d-a09a-dbf07926f4df' ,
39- "QNLI" :'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FQNLI.zip?alt=media&token=c24cad61-f2df-4f04-9ab6-aa576fa829d0' ,
40- "RTE" :'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FRTE.zip?alt=media&token=5efa7e85-a0bb-4f19-8ea2-9e1840f077fb' ,
41- "WNLI" :'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FWNLI.zip?alt=media&token=068ad0a0-ded7-4bd7-99a5-5e00222e0faf' ,
42- "diagnostic" :'https://storage.googleapis.com/mtl-sentence-representations.appspot.com/tsvsWithoutLabels%2FAX.tsv?GoogleAccessId=firebase-adminsdk-0khhl@mtl-sentence-representations.iam.gserviceaccount.com&Expires=2498860800&Signature=DuQ2CSPt2Yfre0C%2BiISrVYrIFaZH1Lc7hBVZDD4ZyR7fZYOMNOUGpi8QxBmTNOrNPjR3z1cggo7WXFfrgECP6FBJSsURv8Ybrue8Ypt%2FTPxbuJ0Xc2FhDi%2BarnecCBFO77RSbfuz%2Bs95hRrYhTnByqu3U%2FYZPaj3tZt5QdfpH2IUROY8LiBXoXS46LE%2FgOQc%2FKN%2BA9SoscRDYsnxHfG0IjXGwHN%2Bf88q6hOmAxeNPx6moDulUF6XMUAaXCSFU%2BnRO2RDL9CapWxj%2BDl7syNyHhB7987hZ80B%2FwFkQ3MEs8auvt5XW1%2Bd4aCU7ytgM69r8JDCwibfhZxpaa4gd50QXQ%3D%3D' }
43-
44-
45- self .save_path = save_path
46- if not os .path .exists (self .save_path ):
47- os .makedirs (self .save_path )
48-
49- self .task = task
5019
51- def download (self ):
def mkdir(path):
    """Create the directory ``path`` together with any missing parents.

    Calling this for a directory that already exists is a no-op; no error
    is raised in that case.
    """
    directory = Path(path)
    directory.mkdir(parents=True, exist_ok=True)
5222
53- if self .task == 'MRPC' :
54- self .download_mrpc ()
55- elif self .task == 'diagnostic' :
56- self .download_diagnostic ()
57- else :
58- self .download_and_extract (self .task )
5923
60- def download_and_extract (self , task ):
61- print ("Downloading and extracting %s..." % task )
62- data_file = "%s.zip" % task
63- URLLIB .urlretrieve (self .TASK2PATH [task ], data_file )
64- print (data_file ,"\n \n \n " )
65- with zipfile .ZipFile (data_file ) as zip_ref :
66- zip_ref .extractall (self .save_path )
67- os .remove (data_file )
68- print ("\t Completed!" )
69-
70- def download_mrpc (self ):
71- print ("Processing MRPC..." )
72- mrpc_dir = os .path .join (self .save_path , "MRPC" )
73- if not os .path .isdir (mrpc_dir ):
74- os .mkdir (mrpc_dir )
75-
76- mrpc_train_file = os .path .join (mrpc_dir , "msr_paraphrase_train.txt" )
77- mrpc_dev_file = os .path .join (mrpc_dir , "dev_ids.tsv" )
78- mrpc_test_file = os .path .join (mrpc_dir , "msr_paraphrase_test.txt" )
79-
80- URLLIB .urlretrieve (self .TASK2PATH ["MRPC" ]["mrpc_train" ], mrpc_train_file )
81- URLLIB .urlretrieve (self .TASK2PATH ["MRPC" ]["mrpc_test" ], mrpc_test_file )
82- URLLIB .urlretrieve (self .TASK2PATH ["MRPC" ]["mrpc_dev" ], mrpc_dev_file )
83-
84- dev_ids = []
85- with io .open (os .path .join (mrpc_dir , "dev_ids.tsv" ), encoding = 'utf-8' ) as ids_fh :
86- for row in ids_fh :
87- dev_ids .append (row .strip ().split ('\t ' ))
88-
89- with io .open (mrpc_train_file , encoding = 'utf-8' ) as data_fh , \
90- io .open (os .path .join (mrpc_dir , "train.tsv" ), 'w' , encoding = 'utf-8' ) as train_fh , \
91- io .open (os .path .join (mrpc_dir , "dev.tsv" ), 'w' , encoding = 'utf-8' ) as dev_fh :
92- header = data_fh .readline ()
93- train_fh .write (header )
94- dev_fh .write (header )
95- for row in data_fh :
96- label , id1 , id2 , s1 , s2 = row .strip ().split ('\t ' )
97- if [id1 , id2 ] in dev_ids :
98- dev_fh .write ("%s\t %s\t %s\t %s\t %s\n " % (label , id1 , id2 , s1 , s2 ))
99- else :
100- train_fh .write ("%s\t %s\t %s\t %s\t %s\n " % (label , id1 , id2 , s1 , s2 ))
class GLUEDownloader:
    """Download GLUE tasks by running the upstream ``download_glue_data`` script.

    All files (the helper script and the task data it produces) are placed
    in a ``glue`` sub-directory of the ``save_path`` given at construction.
    """

    # The helper script that performs the actual task download.  The URL
    # names one specific gist revision, so its content does not change.
    _SCRIPT_URL = (
        'https://gist.githubusercontent.com/W4ngatang/'
        '60c2bdb54d156a41194446737ce03e2e/raw/'
        '17b8dd0d724281ed7c3b2aeeda662b92809aadd5/download_glue_data.py'
    )
    _SCRIPT_NAME = 'download_glue_data.py'

    # Task names accepted by ``download`` -> names passed to the script.
    _TASK_ALIASES = {
        'mrpc': 'MRPC',
        'mnli': 'MNLI',
        'cola': 'CoLA',
        'sst-2': 'SST',
    }

    def __init__(self, save_path):
        """Remember where to store data.

        Args:
            save_path: Base directory (``str`` or path-like); data goes into
                its ``glue`` sub-directory.
        """
        # Kept as ``str`` because it is later handed to ``sys.path`` and to
        # the script's argument list.
        self.save_path = str(Path(save_path) / 'glue')

    def download(self, task_name):
        """Download one GLUE task into ``self.save_path``.

        Args:
            task_name: One of ``'mrpc'``, ``'mnli'``, ``'cola'``, ``'sst-2'``.

        Raises:
            ValueError: If ``task_name`` is not one of the supported names.
        """
        # Validate before touching the file system or the network.  An
        # ``assert`` would be stripped under ``python -O`` and let unknown
        # names silently fall through to SST.
        try:
            script_task = self._TASK_ALIASES[task_name]
        except KeyError:
            raise ValueError(
                'Unsupported GLUE task {!r}; expected one of {}'.format(
                    task_name, sorted(self._TASK_ALIASES))
            ) from None

        mkdir(self.save_path)

        # Reuse a previously fetched copy: the URL is pinned to a single
        # revision, so downloading again cannot bring in new content.
        script_path = Path(self.save_path) / self._SCRIPT_NAME
        if not script_path.exists():
            wget.download(self._SCRIPT_URL, out=self.save_path)

        # NOTE(security): this imports and runs code fetched from the
        # network.  The pinned revision limits the exposure, but the script
        # is not integrity-checked here.
        path_added = self.save_path not in sys.path
        if path_added:
            sys.path.append(self.save_path)
        try:
            import download_glue_data
            download_glue_data.main(
                ['--data_dir', self.save_path, '--tasks', script_task])
        finally:
            # Always undo our own sys.path change, even if the import or the
            # script fails, and never remove an entry we did not add.
            if path_added and self.save_path in sys.path:
                sys.path.remove(self.save_path)
0 commit comments