@@ -87,13 +87,19 @@ def download(self, id_list, file_format=None, ids=None, table_name='', download_dir='./hepdata-downloads'):
         :param ids: accepts one of ('inspire', 'hepdata'). It specifies what type of ids have been passed.
         :param table_name: restricts download to specific tables.
         :param download_dir: defaults to ./hepdata-downloads. Specifies where to download the files.
+
+        :return: dictionary mapping id to list of downloaded files.
+        :rtype: dict[str, list[str]]
         """
 
-        urls = self._build_urls(id_list, file_format, ids, table_name)
-        for url in urls:
+        url_map = self._build_urls(id_list, file_format, ids, table_name)
+        file_map = {}
+        for record_id, url in url_map.items():
             if self.verbose is True:
                 print("Downloading: " + url)
-            download_url(url, download_dir)
+            files_downloaded = download_url(url, download_dir)
+            file_map[record_id] = files_downloaded
+        return file_map
 
     def fetch_names(self, id_list, ids=None):
         """
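With this change, download reports which files each record produced instead of returning None. A minimal sketch of how a caller might consume the new mapping; the import path, constructor arguments, and record id are assumptions for illustration, not taken from this diff:

    from hepdata_cli.api import Client  # assumed module path

    client = Client(verbose=True)
    # '73322' is a hypothetical record id
    file_map = client.download(['73322'], file_format='csv', ids='hepdata')
    for record_id, files in file_map.items():
        print(record_id, '->', str(len(files)) + ' file(s)')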
@@ -102,9 +108,9 @@ def fetch_names(self, id_list, ids=None):
         :param id_list: list of id of records of which to return table names.
         :param ids: accepts one of ('inspire', 'hepdata'). It specifies what type of ids have been passed.
         """
-        urls = self._build_urls(id_list, 'json', ids, '')
+        url_map = self._build_urls(id_list, 'json', ids, '')
         table_names = []
-        for url in urls:
+        for url in url_map.values():
             response = resilient_requests('get', url)
             json_dict = response.json()
             table_names += [[data_table['name'] for data_table in json_dict['data_tables']]]
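A note on ordering: fetch_names now iterates url_map.values(), and because Python dicts preserve insertion order (3.7+), the returned list of per-record table-name lists still lines up with the input id_list. A hypothetical call, reusing the client from the sketch above, with made-up ids and table names:

    names = client.fetch_names(['73322', '73323'], ids='hepdata')
    # e.g. [['Table 1', 'Table 2'], ['Table 1']]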
@@ -136,7 +142,16 @@ def upload(self, path_to_file, email, recid=None, invitation_cookie=None, sandbo
             print('Uploaded ' + path_to_file + ' to ' + SITE_URL + '/record/' + str(recid))
 
     def _build_urls(self, id_list, file_format, ids, table_name):
-        """Builds urls for download and fetch_names, given the specified parameters."""
+        """
+        Builds urls for download and fetch_names, given the specified parameters.
+
+        :param id_list: list of ids to download.
+        :param file_format: accepts one of ('csv', 'root', 'yaml', 'yoda', 'yoda1', 'yoda.h5', 'json').
+        :param ids: accepts one of ('inspire', 'hepdata').
+        :param table_name: restricts download to specific tables.
+
+        :return: dictionary mapping id to url.
+        """
         if type(id_list) not in (tuple, list):
             id_list = id_list.split()
         assert len(id_list) > 0, 'Ids are required.'
@@ -146,9 +161,12 @@ def _build_urls(self, id_list, file_format, ids, table_name):
             params = {'format': file_format}
         else:
             params = {'format': file_format, 'table': table_name}
-        urls = [resilient_requests('get', SITE_URL + '/record/' + ('ins' if ids == 'inspire' else '') + id_entry, params=params).url.replace('%2525', '%25') for id_entry in id_list]
+        url_mapping = {}
+        for id_entry in id_list:
+            url = resilient_requests('get', SITE_URL + '/record/' + ('ins' if ids == 'inspire' else '') + id_entry, params=params).url.replace('%2525', '%25')
+            url_mapping[id_entry] = url
         # TODO: Investigate root cause of double URL encoding (https://github.com/HEPData/hepdata-cli/issues/8).
-        return urls
+        return url_mapping
 
     def _query(self, query, page, size):
         """Builds the search query passed to hepdata.net."""
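On the TODO above: the .replace('%2525', '%25') compensates for a double URL encoding, where an already-encoded percent sign ('%' -> '%25') gets encoded a second time ('%25' -> '%2525'). A quick standard-library illustration of the effect (not code from this diff):

    from urllib.parse import quote

    once = quote('100%')                   # '100%25'
    twice = quote(once)                    # '100%2525'
    fixed = twice.replace('%2525', '%25')  # back to '100%25'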
@@ -170,6 +188,7 @@ def mkdir(directory):
 
 def download_url(url, download_dir):
     """Download file and if necessary extract it."""
+    files_downloaded = []
     assert is_downloadable(url), "Given url is not downloadable: {}".format(url)
     response = resilient_requests('get', url, allow_redirects=True)
     if url[-4:] == 'json':
@@ -182,10 +201,34 @@ def download_url(url, download_dir):
         mkdir(os.path.dirname(filepath))
         open(filepath, 'wb').write(response.content)
         if filepath.endswith("tar.gz") or filepath.endswith("tar"):
-            tar = tarfile.open(filepath, "r:gz" if filepath.endswith("tar.gz") else "r:")
-            tar.extractall(path=os.path.dirname(filepath))
-            tar.close()
-            os.remove(filepath)
+            tar = None
+            try:
+                tar = tarfile.open(filepath, "r:gz" if filepath.endswith("tar.gz") else "r:")
+                extract_dir = os.path.abspath(os.path.dirname(filepath))
+                # Validate member paths first, so nothing is written outside extract_dir.
+                for member in tar.getmembers():
+                    abs_extracted_path = os.path.abspath(os.path.join(extract_dir, member.name))
+                    if not abs_extracted_path.startswith(extract_dir + os.sep):
+                        raise ValueError(f"Attempted path traversal for file {member.name}")
+                tar.extractall(path=extract_dir)
+                # Record each extracted file, checking that it actually landed on disk.
+                for member in tar.getmembers():
+                    if member.isfile():
+                        abs_extracted_path = os.path.abspath(os.path.join(extract_dir, member.name))
+                        if os.path.exists(abs_extracted_path):
+                            files_downloaded.append(abs_extracted_path)
+                        else:
+                            raise FileNotFoundError(f"Extracted file {member.name} not found")
+            except Exception as e:
+                raise Exception(f"Failed to extract {filepath}: {e}") from e
+            finally:
+                if tar:
+                    tar.close()
+                if os.path.exists(filepath):
+                    os.remove(filepath)
+        else:
+            files_downloaded.append(filepath)
+    return files_downloaded
 
 
 def getFilename_fromCd(cd):
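For context on the extraction guard: resolving each member path with os.path.abspath and requiring it to start with the extraction directory rejects archive entries that climb out via '..'. A small standalone check of the idea, with illustrative paths:

    import os

    extract_dir = os.path.abspath('./hepdata-downloads')
    inside = os.path.abspath(os.path.join(extract_dir, 'data/table1.yaml'))
    outside = os.path.abspath(os.path.join(extract_dir, '../../etc/passwd'))
    print(inside.startswith(extract_dir + os.sep))   # True
    print(outside.startswith(extract_dir + os.sep))  # False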