|
1 | 1 | # Licensed under a 3-clause BSD style license - see LICENSE.rst |
2 | 2 |
|
3 | 3 | import os |
4 | | -import shutil |
5 | 4 | import pytest |
| 5 | +import tempfile |
6 | 6 | from unittest.mock import patch, PropertyMock |
7 | 7 | from astropy.coordinates import SkyCoord |
8 | 8 | from astropy.table import Table |
@@ -304,6 +304,24 @@ def test_download_data__missingcolumn(host): |
304 | 304 | Heasarc.download_data(Table({"id": [1]}), host=host) |
305 | 305 |
|
306 | 306 |
|
| 307 | +def test_download_data__sciserver(): |
| 308 | + with tempfile.TemporaryDirectory() as tmpdir: |
| 309 | + datadir = f'{tmpdir}/data' |
| 310 | + downloaddir = f'{tmpdir}/download' |
| 311 | + os.makedirs(datadir, exist_ok=True) |
| 312 | + with open(f'{datadir}/file.txt', 'w') as fp: |
| 313 | + fp.write('data') |
| 314 | + # include both a file and a directory |
| 315 | + tab = Table({'sciserver': [f'{tmpdir}/data/file.txt', f'{tmpdir}/data']}) |
| 316 | + # The patch is to avoid the test that we are on sciserver |
| 317 | + with patch('os.path.exists') as exists: |
| 318 | + exists.return_value = True |
| 319 | + Heasarc.download_data(tab, host="sciserver", location=downloaddir) |
| 320 | + assert os.path.exists(f'{downloaddir}/file.txt') |
| 321 | + assert os.path.exists(f'{downloaddir}/data') |
| 322 | + assert os.path.exists(f'{downloaddir}/data/file.txt') |
| 323 | + |
| 324 | + |
307 | 325 | def test_download_data__outside_sciserver(): |
308 | 326 | with pytest.raises( |
309 | 327 | FileNotFoundError, |
@@ -350,20 +368,20 @@ def test_s3_mock_basic(s3_mock): |
350 | 368 | def test_s3_mock_file(s3_mock): |
351 | 369 | links = Table({"aws": [f"s3://{s3_bucket}/{s3_key1}"]}) |
352 | 370 | Heasarc.enable_cloud(profile=False) |
353 | | - Heasarc.download_data(links, host="aws", location=".") |
354 | | - file = s3_key1.split("/")[-1] |
355 | | - assert os.path.exists(file) |
356 | | - os.remove(file) |
| 371 | + with tempfile.TemporaryDirectory() as tmpdir: |
| 372 | + Heasarc.download_data(links, host="aws", location=tmpdir) |
| 373 | + file = s3_key1.split("/")[-1] |
| 374 | + assert os.path.exists(f'{tmpdir}/{file}') |
357 | 375 |
|
358 | 376 |
|
359 | 377 | @pytest.mark.filterwarnings("ignore::DeprecationWarning") |
360 | 378 | @pytest.mark.skipif("not DO_AWS_S3") |
361 | 379 | def test_s3_mock_directory(s3_mock): |
362 | 380 | links = Table({"aws": [f"s3://{s3_bucket}/{s3_dir}"]}) |
363 | 381 | Heasarc.enable_cloud(profile=False) |
364 | | - Heasarc.download_data(links, host="aws", location=".") |
365 | | - assert os.path.exists("location") |
366 | | - assert os.path.exists("location/file1.txt") |
367 | | - assert os.path.exists("location/sub/file2.txt") |
368 | | - assert os.path.exists("location/sub/sub2/file3.txt") |
369 | | - shutil.rmtree("location") |
| 382 | + with tempfile.TemporaryDirectory() as tmpdir: |
| 383 | + Heasarc.download_data(links, host="aws", location=tmpdir) |
| 384 | + assert os.path.exists(f"{tmpdir}/location") |
| 385 | + assert os.path.exists(f"{tmpdir}/location/file1.txt") |
| 386 | + assert os.path.exists(f"{tmpdir}/location/sub/file2.txt") |
| 387 | + assert os.path.exists(f"{tmpdir}/location/sub/sub2/file3.txt") |
0 commit comments