Skip to content

Commit d026f2f

Browse files
zoghbi-absipocz
authored andcommitted
fix sciserver downloads in heasarc
1 parent 013261c commit d026f2f

File tree

2 files changed

+33
-14
lines changed

2 files changed

+33
-14
lines changed

astroquery/heasarc/core.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -692,16 +692,17 @@ def _copy_sciserver(self, links, location='.'):
692692
Users should be using `~self.download_data` instead
693693
694694
"""
695-
if not (os.path.exists('/FTP/') and os.environ['HOME'].split('/')[-1] == 'idies'):
695+
if not os.path.exists('/FTP/'):
696696
raise FileNotFoundError(
697697
'No data archive found. This should be run on Sciserver '
698698
'with the data drive mounted.'
699699
)
700700

701-
if not os.path.exists(location):
702-
os.mkdir(location)
701+
# make sure the output folder exits
702+
os.makedirs(location, exist_ok=True)
703703

704704
for link in links['sciserver']:
705+
link = str(link)
705706
log.info(f'Copying to {link} from the data drive ...')
706707
if not os.path.exists(link):
707708
raise ValueError(
@@ -711,7 +712,8 @@ def _copy_sciserver(self, links, location='.'):
711712
'Heasarc Help desk'
712713
)
713714
if os.path.isdir(link):
714-
shutil.copytree(link, location)
715+
download_dir = os.path.basename(link.strip('/'))
716+
shutil.copytree(link, f'{location}/{download_dir}')
715717
else:
716718
shutil.copy(link, location)
717719

astroquery/heasarc/tests/test_heasarc.py

Lines changed: 27 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import os
44
import shutil
55
import pytest
6+
import tempfile
67
from unittest.mock import patch, PropertyMock
78
from astropy.coordinates import SkyCoord
89
from astropy.table import Table
@@ -303,6 +304,22 @@ def test_download_data__missingcolumn(host):
303304
):
304305
Heasarc.download_data(Table({"id": [1]}), host=host)
305306

307+
def test_download_data__sciserver():
308+
with tempfile.TemporaryDirectory() as tmpdir:
309+
datadir = f'{tmpdir}/data'
310+
downloaddir = f'{tmpdir}/download'
311+
os.makedirs(datadir, exist_ok=True)
312+
with open(f'{datadir}/file.txt', 'w') as fp:
313+
fp.write('data')
314+
# include both a file and a directory
315+
tab = Table({'sciserver': [f'{tmpdir}/data/file.txt', f'{tmpdir}/data']})
316+
# The patch is to avoid the test that we are on sciserver
317+
with patch('os.path.exists') as exists:
318+
exists.return_value = True
319+
Heasarc.download_data(tab, host="sciserver", location=downloaddir)
320+
assert os.path.exists(f'{downloaddir}/file.txt')
321+
assert os.path.exists(f'{downloaddir}/data')
322+
assert os.path.exists(f'{downloaddir}/data/file.txt')
306323

307324
def test_download_data__outside_sciserver():
308325
with pytest.raises(
@@ -350,20 +367,20 @@ def test_s3_mock_basic(s3_mock):
350367
def test_s3_mock_file(s3_mock):
351368
links = Table({"aws": [f"s3://{s3_bucket}/{s3_key1}"]})
352369
Heasarc.enable_cloud(profile=False)
353-
Heasarc.download_data(links, host="aws", location=".")
354-
file = s3_key1.split("/")[-1]
355-
assert os.path.exists(file)
356-
os.remove(file)
370+
with tempfile.TemporaryDirectory() as tmpdir:
371+
Heasarc.download_data(links, host="aws", location=tmpdir)
372+
file = s3_key1.split("/")[-1]
373+
assert os.path.exists(f'{tmpdir}/{file}')
357374

358375

359376
@pytest.mark.filterwarnings("ignore::DeprecationWarning")
360377
@pytest.mark.skipif("not DO_AWS_S3")
361378
def test_s3_mock_directory(s3_mock):
362379
links = Table({"aws": [f"s3://{s3_bucket}/{s3_dir}"]})
363380
Heasarc.enable_cloud(profile=False)
364-
Heasarc.download_data(links, host="aws", location=".")
365-
assert os.path.exists("location")
366-
assert os.path.exists("location/file1.txt")
367-
assert os.path.exists("location/sub/file2.txt")
368-
assert os.path.exists("location/sub/sub2/file3.txt")
369-
shutil.rmtree("location")
381+
with tempfile.TemporaryDirectory() as tmpdir:
382+
Heasarc.download_data(links, host="aws", location=tmpdir)
383+
assert os.path.exists(f"{tmpdir}/location")
384+
assert os.path.exists(f"{tmpdir}/location/file1.txt")
385+
assert os.path.exists(f"{tmpdir}/location/sub/file2.txt")
386+
assert os.path.exists(f"{tmpdir}/location/sub/sub2/file3.txt")

0 commit comments

Comments
 (0)