Skip to content

Commit fae02b8

Browse files
authored
Merge pull request #2511 from jdavies-st/mast-no-subdirs
MAST: Add flat option to Observations.download_products
2 parents 3ba1062 + 279c683 commit fae02b8

File tree

3 files changed

+93
-27
lines changed

3 files changed

+93
-27
lines changed

CHANGES.rst

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,6 @@ hsa
1616

1717
- New module to access ESA Herschel mission. [#2122]
1818

19-
mast
20-
^^^
21-
22-
- Fixed ``Observations.get_product_list`` to handle input lists of obsids. [#2504]
23-
2419

2520
Service fixes and enhancements
2621
------------------------------
@@ -95,6 +90,11 @@ mast
9590
- Cull duplicate downloads for the same dataURI in ``Observations.download_products()``
9691
and duplicate URIs in ``Observations.get_cloud_uris``. [#2497]
9792

93+
- Fixed ``Observations.get_product_list`` to handle input lists of obsids. [#2504]
94+
95+
- Add a ``flat`` option to ``Observation.download_products()`` to turn off the
96+
automatic creation and organizing of products into subdirectories. [#2511]
97+
9898
oac
9999
^^^
100100

astroquery/mast/observations.py

Lines changed: 29 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -584,7 +584,7 @@ def download_file(self, uri, *, local_path=None, base_url=None, cache=True, clou
584584

585585
return status, msg, url
586586

587-
def _download_files(self, products, base_dir, *, cache=True, cloud_only=False,):
587+
def _download_files(self, products, base_dir, *, flat=False, cache=True, cloud_only=False,):
588588
"""
589589
Takes an `~astropy.table.Table` of data products and downloads them into the directory given by base_dir.
590590
@@ -594,6 +594,9 @@ def _download_files(self, products, base_dir, *, cache=True, cloud_only=False,):
594594
Table containing products to be downloaded.
595595
base_dir : str
596596
Directory in which files will be downloaded.
597+
flat : bool
598+
Default is False. If set to True, no subdirectories will be made for the
599+
downloaded files.
597600
cache : bool
598601
Default is True. If file is found on disk it will not be downloaded again.
599602
cloud_only : bool, optional
@@ -610,9 +613,12 @@ def _download_files(self, products, base_dir, *, cache=True, cloud_only=False,):
610613
for data_product in products:
611614

612615
# create the local file download path
613-
local_path = os.path.join(base_dir, data_product['obs_collection'], data_product['obs_id'])
614-
if not os.path.exists(local_path):
615-
os.makedirs(local_path)
616+
if not flat:
617+
local_path = os.path.join(base_dir, data_product['obs_collection'], data_product['obs_id'])
618+
if not os.path.exists(local_path):
619+
os.makedirs(local_path)
620+
else:
621+
local_path = base_dir
616622
local_path = os.path.join(local_path, os.path.basename(data_product['productFilename']))
617623

618624
# download the files
@@ -642,8 +648,8 @@ def _download_curl_script(self, products, out_dir):
642648
"""
643649

644650
url_list = [("uri", url) for url in products['dataURI']]
645-
download_file = "mastDownload_" + time.strftime("%Y%m%d%H%M%S")
646-
local_path = os.path.join(out_dir.rstrip('/'), download_file + ".sh")
651+
download_file = "mastDownload_" + time.strftime("%Y%m%d%H%M%S") + ".sh"
652+
local_path = os.path.join(out_dir, download_file)
647653

648654
response = self._download_file(self._portal_api_connection.MAST_BUNDLE_URL + ".sh",
649655
local_path, data=url_list, method="POST")
@@ -660,7 +666,7 @@ def _download_curl_script(self, products, out_dir):
660666
'Message': [msg]})
661667
return manifest
662668

663-
def download_products(self, products, *, download_dir=None,
669+
def download_products(self, products, *, download_dir=None, flat=False,
664670
cache=True, curl_flag=False, mrp_only=False, cloud_only=False, **filters):
665671
"""
666672
Download data products.
@@ -673,6 +679,14 @@ def download_products(self, products, *, download_dir=None,
673679
or a Table of products (as is returned by `get_product_list`)
674680
download_dir : str, optional
675681
Optional. Directory to download files to. Defaults to current directory.
682+
flat : bool, optional
683+
Default is False. If set to True, and download_dir is specified, it will put
684+
all files into download_dir without subdirectories. Or if set to True and
685+
download_dir is not specified, it will put files in the current directory,
686+
again with no subdirs. The default of False puts files into the standard
687+
directory structure of "mastDownload/<obs_collection>/<obs_id>/". If
688+
curl_flag=True, the flat flag has no effect, as astroquery does not control
689+
how MAST generates the curl download script.
676690
cache : bool, optional
677691
Default is True. If file is found on disc it will not be downloaded again.
678692
Note: has no affect when downloading curl script.
@@ -731,13 +745,19 @@ def download_products(self, products, *, download_dir=None,
731745
download_dir = '.'
732746

733747
if curl_flag: # don't want to download the files now, just the curl script
748+
if flat:
749+
# flat=True doesn't work with curl_flag=True, so issue a warning
750+
warnings.warn("flat=True has no effect on curl downloads.", InputWarning)
734751
manifest = self._download_curl_script(products,
735752
download_dir)
736753

737754
else:
738-
base_dir = download_dir.rstrip('/') + "/mastDownload"
755+
if flat:
756+
base_dir = download_dir
757+
else:
758+
base_dir = os.path.join(download_dir, "mastDownload")
739759
manifest = self._download_files(products,
740-
base_dir=base_dir,
760+
base_dir=base_dir, flat=flat,
741761
cache=cache,
742762
cloud_only=cloud_only)
743763

astroquery/mast/tests/test_mast_remote.py

Lines changed: 59 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
# Licensed under a 3-clause BSD style license - see LICENSE.rst
22

3+
from re import sub
4+
from pathlib import Path
35
import numpy as np
46
import os
57
import pytest
@@ -14,13 +16,26 @@
1416
from astroquery import mast
1517

1618
from ..utils import ResolverError
17-
from ...exceptions import (InvalidQueryError, MaxResultsWarning, NoResultsWarning,
19+
from ...exceptions import (InputWarning, InvalidQueryError, MaxResultsWarning, NoResultsWarning,
1820
RemoteServiceError)
1921

2022

2123
OBSID = '1647157'
2224

2325

26+
@pytest.fixture(scope="module")
27+
def msa_product_table():
28+
# Pull products for a JWST NIRSpec MSA observation with 6 known
29+
# duplicates of the MSA configuration file, propID=2736
30+
products = mast.Observations.get_product_list("87602009")
31+
32+
# Filter out everything but the MSA config file
33+
mask = np.char.find(products["dataURI"], "_msa.fits") != -1
34+
products = products[mask]
35+
36+
return products
37+
38+
2439
@pytest.mark.remote_data
2540
class TestMast:
2641

@@ -290,21 +305,42 @@ def test_observations_download_products(self, tmpdir):
290305
assert os.path.isfile(result2['Local Path'][0])
291306
assert len(result2) == 1
292307

293-
def test_observations_download_products_no_duplicates(self, tmpdir, caplog):
308+
def test_observations_download_products_flat(self, tmp_path, msa_product_table):
309+
310+
# Get a product list with 6 duplicate JWST MSA config files
311+
products = msa_product_table
312+
313+
assert len(products) == 6
314+
315+
# Download with flat=True
316+
manifest = mast.Observations.download_products(products, flat=True,
317+
download_dir=tmp_path)
318+
319+
assert Path(manifest["Local Path"][0]).parent == tmp_path
320+
321+
def test_observations_download_products_flat_curl(self, tmp_path, msa_product_table):
322+
323+
# Get a product list with 6 duplicate JWST MSA config files
324+
products = msa_product_table
294325

295-
# Pull products for a JWST NIRSpec MSA observation with 6 known
296-
# duplicates of the MSA configuration file, propID=2736
297-
products = mast.Observations.get_product_list("87602009")
326+
assert len(products) == 6
327+
328+
# Download with flat=True, curl_flag=True, look for warning
329+
with pytest.warns(InputWarning):
330+
mast.Observations.download_products(products, flat=True,
331+
curl_flag=True,
332+
download_dir=tmp_path)
333+
334+
def test_observations_download_products_no_duplicates(self, tmp_path, caplog, msa_product_table):
298335

299-
# Filter out everything but the MSA config file
300-
mask = np.char.find(products["dataURI"], "_msa.fits") != -1
301-
products = products[mask]
336+
# Get a product list with 6 duplicate JWST MSA config files
337+
products = msa_product_table
302338

303339
assert len(products) == 6
304340

305341
# Download the product
306342
manifest = mast.Observations.download_products(products,
307-
download_dir=str(tmpdir))
343+
download_dir=tmp_path)
308344

309345
# Check that it downloads the MSA config file only once
310346
assert len(manifest) == 1
@@ -313,14 +349,24 @@ def test_observations_download_products_no_duplicates(self, tmpdir, caplog):
313349
with caplog.at_level("INFO", logger="astroquery"):
314350
assert "products were duplicates" in caplog.text
315351

352+
def test_observations_get_cloud_uris_no_duplicates(self, msa_product_table):
353+
354+
# Get a product list with 6 duplicate JWST MSA config files
355+
products = msa_product_table
356+
357+
assert len(products) == 6
358+
316359
# enable access to public AWS S3 bucket
317-
mast.Observations.enable_cloud_dataset()
360+
mast.Observations.enable_cloud_dataset(provider='AWS')
318361

319-
# Check duplicate cloud URIs as well
320-
uris = mast.Observations.get_cloud_uris(products)
362+
# Check for cloud URIs. Accept a NoResultsWarning if AWS S3
363+
# doesn't have the file. It doesn't matter as we're only checking
364+
# that the duplicate products have been culled to a single one.
365+
with pytest.warns(NoResultsWarning):
366+
uris = mast.Observations.get_cloud_uris(products)
321367
assert len(uris) == 1
322368

323-
def test_observations_download_file(self, tmpdir):
369+
def test_observations_download_file(self):
324370

325371
# enabling cloud connection
326372
mast.Observations.enable_cloud_dataset(provider='AWS')

0 commit comments

Comments
 (0)