Skip to content

Commit 9a7ba7e

Browse files
committed
Accept MAST URIs as input to get_cloud_uris, add jwst to supported cloud missions
1 parent 1e8c4dd commit 9a7ba7e

File tree

5 files changed

+93
-46
lines changed

5 files changed

+93
-46
lines changed

astroquery/mast/cloud.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,8 @@ def __init__(self, provider="AWS", profile=None, verbose=False):
5252
import boto3
5353
import botocore
5454

55-
self.supported_missions = ["mast:hst/product", "mast:tess/product", "mast:kepler", "mast:galex", "mast:ps1"]
55+
self.supported_missions = ["mast:hst/product", "mast:tess/product", "mast:kepler", "mast:galex", "mast:ps1",
56+
"mast:jwst/product"]
5657

5758
self.boto3 = boto3
5859
self.botocore = botocore
@@ -77,11 +78,7 @@ def is_supported(self, data_product):
7778
response : bool
7879
Is the product from a supported mission.
7980
"""
80-
81-
for mission in self.supported_missions:
82-
if data_product['dataURI'].lower().startswith(mission):
83-
return True
84-
return False
81+
return any(data_product['dataURI'].lower().startswith(mission) for mission in self.supported_missions)
8582

8683
def get_cloud_uri(self, data_product, include_bucket=True, full_url=False):
8784
"""
@@ -92,7 +89,7 @@ def get_cloud_uri(self, data_product, include_bucket=True, full_url=False):
9289
9390
Parameters
9491
----------
95-
data_product : `~astropy.table.Row`
92+
data_product : `~astropy.table.Row`, str
9693
Product, or MAST URI string, to be converted into cloud data uri.
9794
include_bucket : bool
9895
Default True. When false returns the path of the file relative to the
@@ -108,6 +105,8 @@ def get_cloud_uri(self, data_product, include_bucket=True, full_url=False):
108105
Cloud URI generated from the data product. If the product cannot be
109106
found in the cloud, None is returned.
110107
"""
108+
# If data_product is a string, convert to a list
109+
data_product = [data_product] if isinstance(data_product, str) else data_product
111110

112111
uri_list = self.get_cloud_uri_list(data_product, include_bucket=include_bucket, full_url=full_url)
113112

@@ -124,8 +123,8 @@ def get_cloud_uri_list(self, data_products, include_bucket=True, full_url=False)
124123
125124
Parameters
126125
----------
127-
data_products : `~astropy.table.Table`
128-
Table containing products to be converted into cloud data uris.
126+
data_products : `~astropy.table.Table`, list
127+
Table containing products or list of MAST uris to be converted into cloud data uris.
129128
include_bucket : bool
130129
Default True. When false returns the path of the file relative to the
131130
top level cloud storage location.
@@ -141,8 +140,8 @@ def get_cloud_uri_list(self, data_products, include_bucket=True, full_url=False)
141140
if data_products includes products not found in the cloud.
142141
"""
143142
s3_client = self.boto3.client('s3', config=self.config)
144-
145-
paths = utils.mast_relative_path(data_products["dataURI"])
143+
data_uris = data_products if isinstance(data_products, list) else data_products['dataURI']
144+
paths = utils.mast_relative_path(data_uris)
146145
if isinstance(paths, str): # Handle the case where only one product was requested
147146
paths = [paths]
148147

astroquery/mast/missions.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -99,14 +99,15 @@ def _parse_result(self, response, *, verbose=False): # Used by the async_to_syn
9999

100100
if self.service == self._search:
101101
results = self._service_api_connection._parse_result(response, verbose, data_key='results')
102+
103+
# Warn if maximum results are returned
104+
if len(results) >= self.limit:
105+
warnings.warn("Maximum results returned, may not include all sources within radius.",
106+
MaxResultsWarning)
102107
elif self.service == self._list_products:
103108
# Results from post_list_products endpoint need to be handled differently
104109
results = Table(response.json()['products'])
105110

106-
if len(results) >= self.limit:
107-
warnings.warn("Maximum results returned, may not include all sources within radius.",
108-
MaxResultsWarning)
109-
110111
return results
111112

112113
def _validate_criteria(self, **criteria):

astroquery/mast/observations.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -854,9 +854,9 @@ def get_cloud_uris(self, data_products=None, *, include_bucket=True, full_url=Fa
854854
855855
Parameters
856856
----------
857-
data_products : `~astropy.table.Table`
858-
Table containing products to be converted into cloud data uris. If provided, this will supercede
859-
page_size, page, or any keyword arguments passed in as criteria.
857+
data_products : `~astropy.table.Table`, list
858+
Table containing products or list of MAST uris to be converted into cloud data uris.
859+
If provided, this will supersede page_size, page, or any keyword arguments passed in as criteria.
860860
include_bucket : bool
861861
Default True. When False, returns the path of the file relative to the
862862
top level cloud storage location.
@@ -920,16 +920,23 @@ def get_cloud_uris(self, data_products=None, *, include_bucket=True, full_url=Fa
920920
# Return list of associated data products
921921
data_products = self.get_product_list(obs)
922922

923-
# Filter product list
924-
data_products = self.filter_products(data_products, mrp_only=mrp_only, extension=extension, **filter_products)
923+
if isinstance(data_products, Table):
924+
# Filter product list
925+
data_products = self.filter_products(data_products, mrp_only=mrp_only, extension=extension,
926+
**filter_products)
927+
else: # data_products is a list of URIs
928+
# Warn if trying to supply filters
929+
if filter_products or extension or mrp_only:
930+
warnings.warn('Filtering is not supported when providing a list of MAST URIs. '
931+
'To apply filters, please provide query criteria or a table of data products '
932+
'as returned by `Observations.get_product_list`', InputWarning)
925933

926934
if not len(data_products):
927-
warnings.warn("No matching products to fetch associated cloud URIs.", NoResultsWarning)
935+
warnings.warn('No matching products to fetch associated cloud URIs.', NoResultsWarning)
928936
return
929937

930938
# Remove duplicate products
931939
data_products = utils.remove_duplicate_products(data_products, 'dataURI')
932-
933940
return self._cloud_connection.get_cloud_uri_list(data_products, include_bucket, full_url)
934941

935942
def get_cloud_uri(self, data_product, *, include_bucket=True, full_url=False):
@@ -941,7 +948,7 @@ def get_cloud_uri(self, data_product, *, include_bucket=True, full_url=False):
941948
942949
Parameters
943950
----------
944-
data_product : `~astropy.table.Row`
951+
data_product : `~astropy.table.Row`, str
945952
Product, or MAST URI string, to be converted into cloud data uri.
946953
include_bucket : bool
947954
Default True. When false returns the path of the file relative to the

astroquery/mast/tests/test_mast_remote.py

Lines changed: 42 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -692,20 +692,6 @@ def test_observations_download_products_no_duplicates(self, tmp_path, caplog, ms
692692
with caplog.at_level("INFO", logger="astroquery"):
693693
assert "products were duplicates" in caplog.text
694694

695-
def test_observations_get_cloud_uris_no_duplicates(self, msa_product_table):
696-
697-
# Get a product list with 6 duplicate JWST MSA config files
698-
products = msa_product_table
699-
700-
assert len(products) == 6
701-
702-
# enable access to public AWS S3 bucket
703-
Observations.enable_cloud_dataset(provider='AWS')
704-
705-
# Check that only one URI is returned
706-
uris = Observations.get_cloud_uris(products)
707-
assert len(uris) == 1
708-
709695
def test_observations_download_file(self, tmp_path):
710696

711697
def check_result(result, path):
@@ -776,7 +762,7 @@ def test_observations_download_file_escaped(self, tmp_path):
776762
"s3://stpubdata/panstarrs/ps1/public/rings.v3.skycell/1334/061/"
777763
"rings.v3.skycell.1334.061.stk.r.unconv.exp.fits")
778764
])
779-
def test_get_cloud_uri(self, test_data_uri, expected_cloud_uri):
765+
def test_observations_get_cloud_uri(self, test_data_uri, expected_cloud_uri):
780766
pytest.importorskip("boto3")
781767
# get a product list
782768
product = Table()
@@ -790,13 +776,17 @@ def test_get_cloud_uri(self, test_data_uri, expected_cloud_uri):
790776
assert len(uri) > 0, f'Product for dataURI {test_data_uri} was not found in the cloud.'
791777
assert uri == expected_cloud_uri, f'Cloud URI does not match expected. ({uri} != {expected_cloud_uri})'
792778

779+
# pass the URI as a string
780+
uri = Observations.get_cloud_uri(test_data_uri)
781+
assert uri == expected_cloud_uri, f'Cloud URI does not match expected. ({uri} != {expected_cloud_uri})'
782+
793783
@pytest.mark.parametrize("test_obs_id", ["25568122", "31411", "107604081"])
794-
def test_get_cloud_uris(self, test_obs_id):
784+
def test_observations_get_cloud_uris(self, test_obs_id):
795785
pytest.importorskip("boto3")
796786

797787
# get a product list
798788
index = 24 if test_obs_id == '25568122' else 0
799-
products = Observations.get_product_list(test_obs_id)[index:]
789+
products = Observations.get_product_list(test_obs_id)[index:index + 2]
800790

801791
assert len(products) > 0, (f'No products found for OBSID {test_obs_id}. '
802792
'Unable to move forward with getting URIs from the cloud.')
@@ -814,7 +804,28 @@ def test_get_cloud_uris(self, test_obs_id):
814804
Observations.get_cloud_uris(products,
815805
extension='png')
816806

817-
def test_get_cloud_uris_query(self):
807+
def test_observations_get_cloud_uris_list_input(self):
808+
uri_list = ['mast:HST/product/u24r0102t_c1f.fits',
809+
'mast:PS1/product/rings.v3.skycell.1334.061.stk.r.unconv.exp.fits']
810+
expected = ['s3://stpubdata/hst/public/u24r/u24r0102t/u24r0102t_c1f.fits',
811+
's3://stpubdata/panstarrs/ps1/public/rings.v3.skycell/1334/061/rings.v3.skycell.1334.'
812+
'061.stk.r.unconv.exp.fits']
813+
814+
# list of URI strings as input
815+
uris = Observations.get_cloud_uris(uri_list)
816+
assert len(uris) > 0, f'Products for URI list {uri_list} were not found in the cloud.'
817+
assert uris == expected
818+
819+
# check for warning if filters are provided with list input
820+
with pytest.warns(InputWarning, match='Filtering is not supported'):
821+
Observations.get_cloud_uris(uri_list,
822+
extension='png')
823+
824+
# check for warning if one of the URIs is not found
825+
with pytest.warns(NoResultsWarning, match='Failed to retrieve MAST relative path'):
826+
Observations.get_cloud_uris(['mast:HST/product/does_not_exist.fits'])
827+
828+
def test_observations_get_cloud_uris_query(self):
818829
pytest.importorskip("boto3")
819830

820831
# enable access to public AWS S3 bucket
@@ -839,6 +850,19 @@ def test_get_cloud_uris_query(self):
839850
with pytest.warns(NoResultsWarning):
840851
Observations.get_cloud_uris(target_name=234295611)
841852

853+
def test_observations_get_cloud_uris_no_duplicates(self, msa_product_table):
854+
# Get a product list with 6 duplicate JWST MSA config files
855+
products = msa_product_table
856+
857+
assert len(products) == 6
858+
859+
# enable access to public AWS S3 bucket
860+
Observations.enable_cloud_dataset(provider='AWS')
861+
862+
# Check that only one URI is returned
863+
uris = Observations.get_cloud_uris(products)
864+
assert len(uris) == 1
865+
842866
######################
843867
# CatalogClass tests #
844868
######################

astroquery/mast/utils.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
Miscellaneous functions used throughout the MAST module.
77
"""
88

9+
import warnings
910
import numpy as np
1011

1112
import requests
@@ -14,11 +15,11 @@
1415
from urllib import parse
1516

1617
import astropy.coordinates as coord
17-
from astropy.table import unique
18+
from astropy.table import unique, Table
1819

1920
from .. import log
2021
from ..version import version
21-
from ..exceptions import ResolverError, InvalidQueryError
22+
from ..exceptions import NoResultsWarning, ResolverError, InvalidQueryError
2223
from ..utils import commons
2324

2425
from . import conf
@@ -192,6 +193,9 @@ def mast_relative_path(mast_uri):
192193
# ("uri", "/path/to/product")
193194
# so we index for path (index=1)
194195
path = json_response.get(uri[1])["path"]
196+
if path is None:
197+
warnings.warn(f"Failed to retrieve MAST relative path for {uri[1]}. Skipping...", NoResultsWarning)
198+
continue
195199
if 'galex' in path:
196200
path = path.lstrip("/mast/")
197201
elif '/ps1/' in path:
@@ -218,19 +222,31 @@ def _split_list_into_chunks(input_list, chunk_size):
218222
def remove_duplicate_products(data_products, uri_key):
219223
"""
220224
Removes duplicate data products that have the same data URI.
225+
221226
Parameters
222227
----------
223-
data_products : `~astropy.table.Table`
224-
Table containing products to be checked for duplicates.
228+
data_products : `~astropy.table.Table`, list
229+
Table containing products or list of URIs to be checked for duplicates.
225230
uri_key : str
226231
Column name representing the URI of a product.
232+
227233
Returns
228234
-------
229235
unique_products : `~astropy.table.Table`, list
230236
Table or list containing products with unique dataURIs, matching the type of the input.
231237
"""
238+
# Get unique products based on input type
239+
if isinstance(data_products, Table):
240+
unique_products = unique(data_products, keys=uri_key)
241+
else: # data_products is a list
242+
seen = set()
243+
unique_products = []
244+
for uri in data_products:
245+
if uri not in seen:
246+
seen.add(uri)
247+
unique_products.append(uri)
248+
232249
number = len(data_products)
233-
unique_products = unique(data_products, keys=uri_key)
234250
number_unique = len(unique_products)
235251
if number_unique < number:
236252
log.info(f"{number - number_unique} of {number} products were duplicates. "

0 commit comments

Comments
 (0)