Skip to content

Commit 80275c2

Browse files
Jorge Fernandez Hernandez
authored and committed
GAIAMNGT-1700 New tests
1 parent 040b1dc commit 80275c2

8 files changed

+179
-30
lines changed

astroquery/gaia/core.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,8 @@ def load_data(self, ids, *, data_release=None, data_structure='INDIVIDUAL', retr
218218
format : str, optional, default 'votable'
219219
loading format. Other available formats are 'csv', 'ecsv','votable_plain', 'json' and 'fits'
220220
dump_to_file: boolean, optional, default False.
221-
If it is true, a compressed directory named "datalink_output.zip" with all the DataLink files is made
221+
If it is true, a compressed directory named "datalink_output.zip" with all the DataLink files is made in the
222+
current working directory
222223
overwrite_output_file : boolean, optional, default False
223224
To overwrite the output_file if it already exists.
224225
verbose : bool, optional, default 'False'
@@ -242,22 +243,23 @@ def load_data(self, ids, *, data_release=None, data_structure='INDIVIDUAL', retr
242243
output_file = 'datalink_output.zip'
243244
output_file_specified = True
244245
output_file = os.path.abspath(output_file)
245-
print(f"DataLink products are stored inside the {output_file} file")
246+
log.info(f"DataLink products will be stored in the {output_file} file")
246247

247248
if not overwrite_output_file and os.path.exists(output_file):
248-
print(f"{output_file} file already exists and will be overwritten")
249+
log.warn(f"{output_file} file already exists and will be overwritten")
249250

250251
path = os.path.dirname(output_file)
251252

252253
log.debug(f"Directory where the data will be saved: {path}")
253254

254255
if path != '':
255-
try:
256-
os.mkdir(path)
257-
except FileExistsError:
258-
log.error("Path %s already exist" % path)
259-
except OSError:
260-
log.error("Creation of the directory %s failed" % path)
256+
if not os.path.isdir(path):
257+
try:
258+
os.mkdir(path)
259+
except FileExistsError:
260+
log.warn("Path %s already exist" % path)
261+
except OSError:
262+
log.error("Creation of the directory %s failed" % path)
261263

262264
if avoid_datatype_check is False:
263265
# we need to check params
@@ -312,8 +314,7 @@ def load_data(self, ids, *, data_release=None, data_structure='INDIVIDUAL', retr
312314
shutil.rmtree(path)
313315
else:
314316
for file in files.keys():
315-
os.remove(os.path.join(os.getcwd(), path, file)
316-
)
317+
os.remove(os.path.join(os.getcwd(), path, file))
317318

318319
if verbose:
319320
if output_file_specified:
@@ -332,9 +333,8 @@ def __get_data_files(output_file, path):
332333
extracted_files = []
333334

334335
with zipfile.ZipFile(output_file, "r") as zip_ref:
335-
for name in zip_ref.namelist():
336-
local_file_path = zip_ref.extract(name, os.path.dirname(output_file))
337-
extracted_files.append(local_file_path)
336+
extracted_files.extend(zip_ref.namelist())
337+
zip_ref.extractall(os.path.dirname(output_file))
338338

339339
# r=root, d=directories, f = files
340340
for r, d, f in os.walk(path):
@@ -344,7 +344,7 @@ def __get_data_files(output_file, path):
344344

345345
for key, value in files.items():
346346

347-
if '.fits' in key:
347+
if key.endswith('.fits'):
348348
tables = []
349349
with fits.open(value) as hduList:
350350
num_hdus = len(hduList)
@@ -354,19 +354,19 @@ def __get_data_files(output_file, path):
354354
tables.append(table)
355355
files[key] = tables
356356

357-
elif '.xml' in key:
357+
elif key.endswith('.xml'):
358358
tables = []
359359
for table in votable.parse(value).iter_tables():
360360
tables.append(table)
361361
files[key] = tables
362362

363-
elif '.csv' in key:
363+
elif key.endswith('.csv'):
364364
tables = []
365365
table = Table.read(value, format='ascii.csv', fast_reader=False)
366366
tables.append(table)
367367
files[key] = tables
368368

369-
elif '.json' in key:
369+
elif key.endswith('.json'):
370370
tables = []
371371
with open(value) as f:
372372
data = json.load(f)

astroquery/gaia/tests/setup_package.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@ def get_package_data():
1010
paths = [os.path.join('data', '*.vot'),
1111
os.path.join('data', '*.vot.gz'),
1212
os.path.join('data', '*.json'),
13-
os.path.join('data', '*.ecsv')
13+
os.path.join('data', '*.ecsv'),
14+
os.path.join('data', '*.zip')
1415
] # etc, add other extensions
1516
# you can also enlist files individually by names
1617
# finally construct and return a dict for the sub module

astroquery/gaia/tests/test_gaiatap.py

Lines changed: 134 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
1616
"""
1717
import os
18+
import zipfile
1819
from pathlib import Path
1920
from unittest.mock import patch
2021

@@ -45,6 +46,22 @@
4546
JOB_DATA_QUERIER_ECSV_FILE_NAME = get_pkg_data_filename(os.path.join("data", '1712337806100O-result.ecsv'),
4647
package=package)
4748

49+
DL_PRODUCTS_VOT = get_pkg_data_filename(
50+
os.path.join("data", 'gaia_dr3_source_id_5937083312263887616_dl_products_vot.zip'),
51+
package=package)
52+
53+
DL_PRODUCTS_ECSV = get_pkg_data_filename(
54+
os.path.join("data", 'gaia_dr3_source_id_5937083312263887616_dl_products_ecsv.zip'),
55+
package=package)
56+
57+
DL_PRODUCTS_CSV = get_pkg_data_filename(
58+
os.path.join("data", 'gaia_dr3_source_id_5937083312263887616_dl_products_csv.zip'),
59+
package=package)
60+
61+
DL_PRODUCTS_FITS = get_pkg_data_filename(
62+
os.path.join("data", 'gaia_dr3_source_id_5937083312263887616_dl_products_fits.zip'),
63+
package=package)
64+
4865
JOB_DATA = Path(JOB_DATA_FILE_NAME).read_text()
4966
JOB_DATA_NEW = Path(JOB_DATA_FILE_NAME_NEW).read_text()
5067

@@ -152,6 +169,23 @@ def mock_querier():
152169
return GaiaClass(tap_plus_conn_handler=conn_handler, datalink_handler=tapplus, show_server_messages=False)
153170

154171

172+
@pytest.fixture(scope="module")
173+
def mock_datalink_querier():
174+
conn_handler = DummyConnHandler()
175+
tapplus = TapPlus(url="http://test:1111/tap", connhandler=conn_handler)
176+
177+
launch_response = DummyResponse(200)
178+
launch_response.set_data(method="POST", body=DL_PRODUCTS_VOT)
179+
# The query contains decimals: default response is more robust.
180+
conn_handler.set_default_response(launch_response)
181+
conn_handler.set_response(
182+
'?DATA_STRUCTURE=INDIVIDUAL&FORMAT=votable&ID=5937083312263887616&RELEASE=Gaia+DR3&RETRIEVAL_TYPE=ALL'
183+
'&USE_ZIP_ALWAYS=true&VALID_DATA=false',
184+
launch_response)
185+
186+
return GaiaClass(tap_plus_conn_handler=conn_handler, datalink_handler=tapplus, show_server_messages=False)
187+
188+
155189
@pytest.fixture(scope="module")
156190
def mock_querier_ecsv():
157191
conn_handler = DummyConnHandler()
@@ -696,29 +730,100 @@ def test_cone_search_and_changing_MAIN_GAIA_TABLE(mock_querier_async):
696730
assert "name_from_class" in job.parameters["query"]
697731

698732

699-
def test_load_data(monkeypatch, tmp_path):
733+
def test_load_data(mock_datalink_querier):
734+
mock_datalink_querier.load_data(ids=[5937083312263887616], data_release='Gaia DR3', data_structure='INDIVIDUAL',
735+
retrieval_type="ALL",
736+
linking_parameter='SOURCE_ID', valid_data=False, band=None,
737+
avoid_datatype_check=False,
738+
format="votable", dump_to_file=True, overwrite_output_file=True, verbose=False)
739+
740+
assert os.path.exists('datalink_output.zip')
741+
742+
extracted_files = []
743+
744+
with zipfile.ZipFile('datalink_output.zip', "r") as zip_ref:
745+
extracted_files.extend(zip_ref.namelist())
746+
747+
assert len(extracted_files) == 3
748+
749+
os.remove(os.path.join(os.getcwd(), 'datalink_output.zip'))
750+
751+
assert not os.path.exists('datalink_output.zip')
752+
753+
754+
@pytest.mark.skip(reason="Thes fits files generate an error relatate to the unit 'log(cm.s**-2)")
755+
def test_load_data_fits(monkeypatch, tmp_path, tmp_path_factory):
756+
path = Path(os.getcwd() + '/' + 'datalink_output.zip')
757+
758+
with open(DL_PRODUCTS_FITS, 'rb') as file:
759+
zip_bytes = file.read()
760+
761+
path.write_bytes(zip_bytes)
762+
700763
def load_data_monkeypatched(self, params_dict, output_file, verbose):
701764
assert params_dict == {
702765
"VALID_DATA": "true",
703766
"ID": "1,2,3,4",
704-
"FORMAT": "votable",
767+
"FORMAT": "fits",
705768
"RETRIEVAL_TYPE": "epoch_photometry",
706769
"DATA_STRUCTURE": "INDIVIDUAL",
707770
"USE_ZIP_ALWAYS": "true"}
708-
assert output_file == str(tmp_path / "datalink_output.zip")
771+
assert output_file == os.getcwd() + '/' + 'datalink_output.zip'
709772
assert verbose is True
710773

711774
monkeypatch.setattr(TapPlus, "load_data", load_data_monkeypatched)
712775

713776
GAIA_QUERIER.load_data(
777+
valid_data=True,
714778
ids="1,2,3,4",
779+
format='fits',
715780
retrieval_type="epoch_photometry",
781+
verbose=True,
782+
dump_to_file=True)
783+
784+
path.unlink()
785+
786+
787+
def test_load_data_csv(monkeypatch, tmp_path, tmp_path_factory):
788+
path = Path(os.getcwd() + '/' + 'datalink_output.zip')
789+
790+
with open(DL_PRODUCTS_CSV, 'rb') as file:
791+
zip_bytes = file.read()
792+
793+
path.write_bytes(zip_bytes)
794+
795+
def load_data_monkeypatched(self, params_dict, output_file, verbose):
796+
assert params_dict == {
797+
"VALID_DATA": "true",
798+
"ID": "1,2,3,4",
799+
"FORMAT": "csv",
800+
"RETRIEVAL_TYPE": "epoch_photometry",
801+
"DATA_STRUCTURE": "INDIVIDUAL",
802+
"USE_ZIP_ALWAYS": "true"}
803+
assert output_file == os.getcwd() + '/' + 'datalink_output.zip'
804+
assert verbose is True
805+
806+
monkeypatch.setattr(TapPlus, "load_data", load_data_monkeypatched)
807+
808+
GAIA_QUERIER.load_data(
716809
valid_data=True,
810+
ids="1,2,3,4",
811+
format='csv',
812+
retrieval_type="epoch_photometry",
717813
verbose=True,
718814
dump_to_file=True)
719815

816+
path.unlink()
817+
818+
819+
def test_load_data_ecsv(monkeypatch, tmp_path, tmp_path_factory):
820+
path = Path(os.getcwd() + '/' + 'datalink_output.zip')
821+
822+
with open(DL_PRODUCTS_ECSV, 'rb') as file:
823+
zip_bytes = file.read()
824+
825+
path.write_bytes(zip_bytes)
720826

721-
def test_load_data_ecsv(monkeypatch, tmp_path):
722827
def load_data_monkeypatched(self, params_dict, output_file, verbose):
723828
assert params_dict == {
724829
"VALID_DATA": "true",
@@ -727,21 +832,30 @@ def load_data_monkeypatched(self, params_dict, output_file, verbose):
727832
"RETRIEVAL_TYPE": "epoch_photometry",
728833
"DATA_STRUCTURE": "INDIVIDUAL",
729834
"USE_ZIP_ALWAYS": "true"}
730-
assert output_file == str(tmp_path / "datalink_output.zip")
835+
assert output_file == os.getcwd() + '/' + 'datalink_output.zip'
731836
assert verbose is True
732837

733838
monkeypatch.setattr(TapPlus, "load_data", load_data_monkeypatched)
734839

735840
GAIA_QUERIER.load_data(
841+
valid_data=True,
736842
ids="1,2,3,4",
843+
format='ecsv',
737844
retrieval_type="epoch_photometry",
738-
valid_data=True,
739845
verbose=True,
740-
format='ecsv',
741846
dump_to_file=True)
742847

848+
path.unlink()
849+
743850

744851
def test_load_data_linking_parameter(monkeypatch, tmp_path):
852+
path = Path(os.getcwd() + '/' + 'datalink_output.zip')
853+
854+
with open(DL_PRODUCTS_VOT, 'rb') as file:
855+
zip_bytes = file.read()
856+
857+
path.write_bytes(zip_bytes)
858+
745859
def load_data_monkeypatched(self, params_dict, output_file, verbose):
746860
assert params_dict == {
747861
"VALID_DATA": "true",
@@ -750,7 +864,7 @@ def load_data_monkeypatched(self, params_dict, output_file, verbose):
750864
"RETRIEVAL_TYPE": "epoch_photometry",
751865
"DATA_STRUCTURE": "INDIVIDUAL",
752866
"USE_ZIP_ALWAYS": "true"}
753-
assert output_file == str(tmp_path / "datalink_output.zip")
867+
assert output_file == os.getcwd() + '/' + 'datalink_output.zip'
754868
assert verbose is True
755869

756870
monkeypatch.setattr(TapPlus, "load_data", load_data_monkeypatched)
@@ -763,9 +877,18 @@ def load_data_monkeypatched(self, params_dict, output_file, verbose):
763877
verbose=True,
764878
dump_to_file=True)
765879

880+
path.unlink()
881+
766882

767883
@pytest.mark.parametrize("linking_param", ['TRANSIT_ID', 'IMAGE_ID'])
768884
def test_load_data_linking_parameter_with_values(monkeypatch, tmp_path, linking_param):
885+
path = Path(os.getcwd() + '/' + 'datalink_output.zip')
886+
887+
with open(DL_PRODUCTS_VOT, 'rb') as file:
888+
zip_bytes = file.read()
889+
890+
path.write_bytes(zip_bytes)
891+
769892
def load_data_monkeypatched(self, params_dict, output_file, verbose):
770893
assert params_dict == {
771894
"VALID_DATA": "true",
@@ -775,7 +898,7 @@ def load_data_monkeypatched(self, params_dict, output_file, verbose):
775898
"DATA_STRUCTURE": "INDIVIDUAL",
776899
"LINKING_PARAMETER": linking_param,
777900
"USE_ZIP_ALWAYS": "true", }
778-
assert output_file == str(tmp_path / "datalink_output.zip")
901+
assert output_file == os.getcwd() + '/' + 'datalink_output.zip'
779902
assert verbose is True
780903

781904
monkeypatch.setattr(TapPlus, "load_data", load_data_monkeypatched)
@@ -788,6 +911,8 @@ def load_data_monkeypatched(self, params_dict, output_file, verbose):
788911
verbose=True,
789912
dump_to_file=True)
790913

914+
path.unlink()
915+
791916

792917
def test_get_datalinks(monkeypatch):
793918
def get_datalinks_monkeypatched(self, ids, linking_parameter, verbose):

astroquery/utils/tap/conn/tests/DummyResponse.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ class DummyResponse:
2424
STATUS_MESSAGES = {200: "OK", 303: "OK", 500: "ERROR"}
2525

2626
def __init__(self, status_code=None):
27+
self.zip_bytes = None
2728
self.reason = ""
2829
self.set_status_code(status_code)
2930
self.index = 0
@@ -57,11 +58,29 @@ def read(self, size=None):
5758
if v is None:
5859
return None
5960
else:
61+
62+
if v.endswith('zip'):
63+
if self.zip_bytes is None:
64+
with open(v, 'rb') as file:
65+
self.zip_bytes = file.read()
66+
6067
if size is None or size < 0:
68+
69+
if v.endswith('zip'):
70+
return self.zip_bytes
71+
6172
# read all
6273
return v.encode(encoding='utf_8', errors='strict')
6374
else:
64-
bodyLength = len(v)
75+
is_zip = False
76+
77+
if v.endswith('zip'):
78+
is_zip = True
79+
bodyLength = len(self.zip_bytes)
80+
v = self.zip_bytes
81+
else:
82+
bodyLength = len(v)
83+
6584
if self.index < 0:
6685
return ""
6786
if size >= bodyLength:
@@ -73,7 +92,11 @@ def read(self, size=None):
7392
self.index = endPos
7493
if endPos >= (bodyLength - 1):
7594
self.index = -1
76-
return tmp.encode(encoding='utf_8', errors='strict')
95+
96+
if is_zip:
97+
return tmp
98+
else:
99+
return tmp.encode(encoding='utf_8', errors='strict')
77100

78101
def close(self):
79102
self.index = 0

0 commit comments

Comments (0)