Commit bd8f332

Jorge Fernandez Hernandez authored and committed
GAIAMNGT-1700 change signature of the method Gaia.load_data
1 parent 26ce1d7 commit bd8f332

3 files changed: +42 -39 lines changed

3 files changed

+42
-39
lines changed

astroquery/gaia/core.py

Lines changed: 32 additions & 29 deletions

@@ -19,7 +19,6 @@
 import zipfile
 from collections.abc import Iterable
 from datetime import datetime, timezone
-from pathlib import Path
 
 from astropy import units
 from astropy import units as u
@@ -170,7 +169,7 @@ def logout(self, *, verbose=False):
 
     def load_data(self, ids, *, data_release=None, data_structure='INDIVIDUAL', retrieval_type="ALL",
                   linking_parameter='SOURCE_ID', valid_data=False, band=None, avoid_datatype_check=False,
-                  format="votable", output_file=None, overwrite_output_file=False, verbose=False):
+                  format="votable", dump_to_file=False, overwrite_output_file=False, verbose=False):
         """Loads the specified table
         TAP+ only
 
@@ -217,10 +216,8 @@ def load_data(self, ids, *, data_release=None, data_structure='INDIVIDUAL', retr
             By default, this value will be set to False. If it is set to 'true'
             the Datalink items tags will not be checked.
         format : str, optional, default 'votable'
-            loading format. Other available formats are 'csv', 'ecsv','votable_plain' and 'fits'
-        output_file : string or pathlib.PosixPath, optional, default None
-            file where the results are saved.
-            If it is not provided, the http response contents are returned.
+            loading format. Other available formats are 'csv', 'ecsv','votable_plain', 'json' and 'fits'
+        dump_to_file: boolean, optional, default False
         overwrite_output_file : boolean, optional, default False
             To overwrite the output_file if it already exists.
         verbose : bool, optional, default 'False'
@@ -235,26 +232,30 @@ def load_data(self, ids, *, data_release=None, data_structure='INDIVIDUAL', retr
         temp_dirname = "temp_" + now_formatted
         downloadname_formated = "download_" + now_formatted
 
+        overwrite_output_file = True
         output_file_specified = False
-        if output_file is None:
+
+        if not dump_to_file:
             output_file = os.path.join(os.getcwd(), temp_dirname, downloadname_formated)
         else:
+            output_file = 'datalink_output.zip'
             output_file_specified = True
-
-        if isinstance(output_file, str):
-            if not output_file.lower().endswith('.zip'):
-                output_file = output_file + '.zip'
-        elif isinstance(output_file, Path):
-            if not output_file.suffix.endswith('.zip'):
-                output_file.with_suffix('.zip')
-
         output_file = os.path.abspath(output_file)
         if not overwrite_output_file and os.path.exists(output_file):
-            raise ValueError(f"{output_file} file already exists. Please use overwrite_output_file='True' to "
-                             f"overwrite output file.")
+            print(f"{output_file} file already exists and will be overwritten")
 
         path = os.path.dirname(output_file)
 
+        log.debug(f"Directory where the data will be saved: {path}")
+
+        if path != '':
+            try:
+                os.mkdir(path)
+            except FileExistsError:
+                log.error("Path %s already exist" % path)
+            except OSError:
+                log.error("Creation of the directory %s failed" % path)
+
         if avoid_datatype_check is False:
             # we need to check params
             rt = str(retrieval_type).upper()
@@ -297,14 +298,7 @@ def load_data(self, ids, *, data_release=None, data_structure='INDIVIDUAL', retr
         if linking_parameter != 'SOURCE_ID':
             params_dict['LINKING_PARAMETER'] = linking_parameter
 
-        if path != '':
-            try:
-                os.mkdir(path)
-            except FileExistsError:
-                log.error("Path %s already exist" % path)
-            except OSError:
-                log.error("Creation of the directory %s failed" % path)
-
+        files = dict()
         try:
             self.__gaiadata.load_data(params_dict=params_dict, output_file=output_file, verbose=verbose)
             files = Gaia.__get_data_files(output_file=output_file, path=path)
@@ -313,6 +307,10 @@ def load_data(self, ids, *, data_release=None, data_structure='INDIVIDUAL', retr
         finally:
             if not output_file_specified:
                 shutil.rmtree(path)
+            else:
+                for file in files.keys():
+                    os.remove(os.path.join(os.getcwd(), path, file)
+                              )
 
         if verbose:
             if output_file_specified:
@@ -328,17 +326,21 @@ def load_data(self, ids, *, data_release=None, data_structure='INDIVIDUAL', retr
     @staticmethod
     def __get_data_files(output_file, path):
         files = {}
-        if zipfile.is_zipfile(output_file):
-            with zipfile.ZipFile(output_file, 'r') as zip_ref:
-                zip_ref.extractall(os.path.dirname(output_file))
+        extracted_files = []
+
+        with zipfile.ZipFile(output_file, "r") as zip_ref:
+            for name in zip_ref.namelist():
+                local_file_path = zip_ref.extract(name, os.path.dirname(output_file))
+                extracted_files.append(local_file_path)
 
         # r=root, d=directories, f = files
         for r, d, f in os.walk(path):
             for file in f:
-                if file.lower().endswith(('.fits', '.xml', '.csv', '.ecsv')):
+                if file in extracted_files:
                     files[file] = os.path.join(r, file)
 
         for key, value in files.items():
+
             if '.fits' in key:
                 tables = []
                 with fits.open(value) as hduList:
@@ -348,6 +350,7 @@ def __get_data_files(output_file, path):
                     Gaia.correct_table_units(table)
                     tables.append(table)
                 files[key] = tables
+
             elif '.xml' in key:
                 tables = []
                 for table in votable.parse(value).iter_tables():

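For reference, a minimal usage sketch of the changed signature (illustration only, not part of the commit: the source id, data release and retrieval type are placeholders, and the calls need network access to the Gaia archive):

from astroquery.gaia import Gaia

# dump_to_file=False (the default): the DataLink products are parsed in memory
# and the temporary download directory is removed once the tables are read.
results = Gaia.load_data(ids=[5937083524435607680],
                         data_release="Gaia DR3",
                         retrieval_type="EPOCH_PHOTOMETRY",
                         format="votable",
                         dump_to_file=False)

# dump_to_file=True: the archive response is kept on disk as
# 'datalink_output.zip' in the current working directory; per this change an
# existing file of that name is overwritten (a message is printed) instead of
# raising ValueError.
Gaia.load_data(ids=[5937083524435607680],
               data_release="Gaia DR3",
               retrieval_type="EPOCH_PHOTOMETRY",
               dump_to_file=True)
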
astroquery/gaia/tests/test_gaiatap.py

Lines changed: 9 additions & 9 deletions

@@ -705,7 +705,7 @@ def load_data_monkeypatched(self, params_dict, output_file, verbose):
                               "RETRIEVAL_TYPE": "epoch_photometry",
                               "DATA_STRUCTURE": "INDIVIDUAL",
                               "USE_ZIP_ALWAYS": "true"}
-        assert output_file == str(tmp_path / "output_file")
+        assert output_file == str(tmp_path / "datalink_output.zip")
         assert verbose is True
 
     monkeypatch.setattr(TapPlus, "load_data", load_data_monkeypatched)
@@ -715,7 +715,7 @@ def load_data_monkeypatched(self, params_dict, output_file, verbose):
                   retrieval_type="epoch_photometry",
                   valid_data=True,
                   verbose=True,
-                  output_file=tmp_path / "output_file")
+                  dump_to_file=True)
 
 
 def test_load_data_ecsv(monkeypatch, tmp_path):
@@ -727,7 +727,7 @@ def load_data_monkeypatched(self, params_dict, output_file, verbose):
                               "RETRIEVAL_TYPE": "epoch_photometry",
                               "DATA_STRUCTURE": "INDIVIDUAL",
                               "USE_ZIP_ALWAYS": "true"}
-        assert output_file == str(tmp_path / "output_file.zip")
+        assert output_file == str(tmp_path / "datalink_output.zip")
         assert verbose is True
 
     monkeypatch.setattr(TapPlus, "load_data", load_data_monkeypatched)
@@ -738,7 +738,7 @@ def load_data_monkeypatched(self, params_dict, output_file, verbose):
                   valid_data=True,
                   verbose=True,
                   format='ecsv',
-                  output_file=str(tmp_path / "output_file"))
+                  dump_to_file=True)
 
 
 def test_load_data_linking_parameter(monkeypatch, tmp_path):
@@ -750,7 +750,7 @@ def load_data_monkeypatched(self, params_dict, output_file, verbose):
                               "RETRIEVAL_TYPE": "epoch_photometry",
                               "DATA_STRUCTURE": "INDIVIDUAL",
                               "USE_ZIP_ALWAYS": "true"}
-        assert output_file == str(tmp_path / "output_file")
+        assert output_file == str(tmp_path / "datalink_output.zip")
         assert verbose is True
 
     monkeypatch.setattr(TapPlus, "load_data", load_data_monkeypatched)
@@ -761,7 +761,7 @@ def load_data_monkeypatched(self, params_dict, output_file, verbose):
                   linking_parameter="SOURCE_ID",
                   valid_data=True,
                   verbose=True,
-                  output_file=tmp_path / "output_file")
+                  dump_to_file=True)
 
 
 @pytest.mark.parametrize("linking_param", ['TRANSIT_ID', 'IMAGE_ID'])
@@ -774,8 +774,8 @@ def load_data_monkeypatched(self, params_dict, output_file, verbose):
                               "RETRIEVAL_TYPE": "epoch_photometry",
                               "DATA_STRUCTURE": "INDIVIDUAL",
                               "LINKING_PARAMETER": linking_param,
-                              "USE_ZIP_ALWAYS": "true"}
-        assert output_file == str(tmp_path / "output_file")
+                              "USE_ZIP_ALWAYS": "true", }
+        assert output_file == str(tmp_path / "datalink_output.zip")
         assert verbose is True
 
     monkeypatch.setattr(TapPlus, "load_data", load_data_monkeypatched)
@@ -786,7 +786,7 @@ def load_data_monkeypatched(self, params_dict, output_file, verbose):
                   linking_parameter=linking_param,
                   valid_data=True,
                   verbose=True,
-                  output_file=tmp_path / "output_file")
+                  dump_to_file=True)
 
 
 def test_get_datalinks(monkeypatch):

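The updated tests above all follow the same monkeypatching pattern. Below is a condensed, self-contained sketch of that pattern against the new signature; the test name, the example ids value, the monkeypatch.chdir call and the empty-zip shortcut are assumptions for illustration and do not come from this commit:

import zipfile

from astroquery.gaia import Gaia
from astroquery.utils.tap.core import TapPlus


def test_load_data_dump_to_file(monkeypatch, tmp_path):
    # Run inside tmp_path so the default 'datalink_output.zip' is created there.
    monkeypatch.chdir(tmp_path)

    def load_data_monkeypatched(self, params_dict, output_file, verbose):
        # The output file name is now fixed by Gaia.load_data itself.
        assert output_file == str(tmp_path / "datalink_output.zip")
        assert verbose is True
        # Write an empty but valid zip so the extraction step has something to open.
        zipfile.ZipFile(output_file, "w").close()

    monkeypatch.setattr(TapPlus, "load_data", load_data_monkeypatched)

    Gaia.load_data(ids=[1234567890],
                   retrieval_type="epoch_photometry",
                   valid_data=True,
                   verbose=True,
                   dump_to_file=True)
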
astroquery/utils/tap/xmlparser/utils.py

Lines changed: 1 addition & 1 deletion

@@ -76,7 +76,7 @@ def read_http_response(response, output_format, *, correct_units=True, use_names
     else:
         with warnings.catch_warnings():
             # Capturing the warning and converting the objid column to int64 is necessary for consistency as
-            # it was convereted to string on systems with defaul integer int32 due to an overflow.
+            # it was converted to string on systems with default integer int32 due to an overflow.
             if sys.platform.startswith('win'):
                 warnings.filterwarnings("ignore", category=AstropyWarning,
                                         message=r'OverflowError converting to IntType in column.*')
