Skip to content

Commit 8e450ca

Browse files
authored
Merge pull request #3314 from snbianco/return-uri-map
Enhancements for mast.Observations.get_cloud_uris()
2 parents f1713f8 + 823e49f commit 8e450ca

File tree

7 files changed

+103
-101
lines changed

7 files changed

+103
-101
lines changed

CHANGES.rst

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,14 @@ mast
6767
- Added ``resolve_all`` parameter to ``MastClass.resolve_object`` to resolve object names and return
6868
coordinates for all available resolvers. [#3292]
6969

70+
- Fix bug in ``utils.remove_duplicate_products`` where the order of the products in an input table was not retained. [#3314]
71+
72+
- Added ``return_uri_map`` parameter to ``Observations.get_cloud_uris`` to return a mapping of the input data product URIs
73+
to the returned cloud URIs. [#3314]
74+
75+
- Added ``verbose`` parameter to ``Observations.get_cloud_uris`` to control whether warnings are logged when a product cannot
76+
be found in the cloud. [#3314]
77+
7078

7179
Infrastructure, Utility and Other Changes and Additions
7280
-------------------------------------------------------

astroquery/mast/cloud.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ def get_cloud_uri(self, data_product, include_bucket=True, full_url=False):
117117
# Output from ``get_cloud_uri_list`` is always a list even when it's only 1 URI
118118
return uri_list[0]
119119

120-
def get_cloud_uri_list(self, data_products, include_bucket=True, full_url=False):
120+
def get_cloud_uri_list(self, data_products, *, include_bucket=True, full_url=False, verbose=True):
121121
"""
122122
Takes an `~astropy.table.Table` of data products and returns the associated cloud data uris.
123123
@@ -132,6 +132,8 @@ def get_cloud_uri_list(self, data_products, include_bucket=True, full_url=False)
132132
full_url : bool
133133
Default False. Return an HTTP fetchable url instead of a cloud uri.
134134
Must set include_bucket to False to use this option.
135+
verbose : bool
136+
Default True. Whether to issue warnings if a product cannot be found in the cloud.
135137
136138
Returns
137139
-------
@@ -141,7 +143,7 @@ def get_cloud_uri_list(self, data_products, include_bucket=True, full_url=False)
141143
"""
142144
s3_client = self.boto3.client('s3', config=self.config)
143145
data_uris = data_products if isinstance(data_products, list) else data_products['dataURI']
144-
paths = utils.mast_relative_path(data_uris)
146+
paths = utils.mast_relative_path(data_uris, verbose=verbose)
145147
if isinstance(paths, str): # Handle the case where only one product was requested
146148
paths = [paths]
147149

@@ -164,7 +166,8 @@ def get_cloud_uri_list(self, data_products, include_bucket=True, full_url=False)
164166
except self.botocore.exceptions.ClientError as e:
165167
if e.response['Error']['Code'] != "404":
166168
raise
167-
warnings.warn("Unable to locate file {}.".format(path), NoResultsWarning)
169+
if verbose:
170+
warnings.warn("Unable to locate file {}.".format(path), NoResultsWarning)
168171
uri_list.append(None)
169172

170173
return uri_list

astroquery/mast/observations.py

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -874,7 +874,8 @@ def download_products(self, products, *, download_dir=None, flat=False,
874874
return manifest
875875

876876
def get_cloud_uris(self, data_products=None, *, include_bucket=True, full_url=False, pagesize=None, page=None,
877-
mrp_only=False, extension=None, filter_products={}, **criteria):
877+
mrp_only=False, extension=None, filter_products={}, return_uri_map=False, verbose=True,
878+
**criteria):
878879
"""
879880
Given an `~astropy.table.Table` of data products or query criteria and filter parameters,
880881
returns the associated cloud data URIs.
@@ -908,6 +909,12 @@ def get_cloud_uris(self, data_products=None, *, include_bucket=True, full_url=Fa
908909
or more acceptable values for that parameter.
909910
Filter behavior is AND between the filters and OR within a filter set.
910911
For example: {"productType": "SCIENCE", "extension": ["fits", "jpg"]}
912+
return_uri_map : bool, optional
913+
Default False. If set to True, returns a dictionary mapping the original data product
914+
URIs to their corresponding cloud URIs. This is useful for tracking which products were
915+
successfully converted to cloud URIs.
916+
verbose : bool, optional
917+
Default True. Whether to issue warnings if a product cannot be found in the cloud.
911918
**criteria
912919
Criteria to apply. At least one non-positional criterion must be supplied.
913920
Valid criteria are coordinates, objectname, radius (as in `query_region` and `query_object`),
@@ -951,20 +958,37 @@ def get_cloud_uris(self, data_products=None, *, include_bucket=True, full_url=Fa
951958
# Filter product list
952959
data_products = self.filter_products(data_products, mrp_only=mrp_only, extension=extension,
953960
**filter_products)
961+
data_uris = data_products['dataURI']
954962
else: # data_products is a list of URIs
955963
# Warn if trying to supply filters
956964
if filter_products or extension or mrp_only:
957965
warnings.warn('Filtering is not supported when providing a list of MAST URIs. '
958966
'To apply filters, please provide query criteria or a table of data products '
959967
'as returned by `Observations.get_product_list`', InputWarning)
968+
data_uris = data_products
960969

961-
if not len(data_products):
970+
if not len(data_uris):
962971
warnings.warn('No matching products to fetch associated cloud URIs.', NoResultsWarning)
963972
return
964973

965974
# Remove duplicate products
966-
data_products = utils.remove_duplicate_products(data_products, 'dataURI')
967-
return self._cloud_connection.get_cloud_uri_list(data_products, include_bucket, full_url)
975+
data_uris = utils.remove_duplicate_products(data_uris, 'dataURI')
976+
977+
# Get cloud URIs
978+
cloud_uris = self._cloud_connection.get_cloud_uri_list(data_uris,
979+
include_bucket=include_bucket,
980+
full_url=full_url,
981+
verbose=verbose)
982+
983+
# If return_uri_map is True, create a mapping of dataURIs to cloud URIs
984+
if return_uri_map:
985+
uri_map = dict(zip(data_uris, cloud_uris))
986+
return uri_map
987+
988+
# Remove None values from the list
989+
cloud_uris = [uri for uri in cloud_uris if uri is not None]
990+
991+
return cloud_uris
968992

969993
def get_cloud_uri(self, data_product, *, include_bucket=True, full_url=False):
970994
"""

astroquery/mast/tests/test_mast.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -353,7 +353,7 @@ def test_missions_get_product_list(patch_post):
353353
def test_missions_get_unique_product_list(patch_post, caplog):
354354
unique_products = mast.MastMissions.get_unique_product_list('Z14Z0104T')
355355
assert isinstance(unique_products, Table)
356-
assert (unique_products == unique(unique_products, keys='filename')).all()
356+
assert (len(unique_products) == len(unique(unique_products, keys='filename')))
357357
# No INFO messages should be logged
358358
with caplog.at_level('INFO', logger='astroquery'):
359359
assert caplog.text == ''
@@ -770,6 +770,12 @@ def test_observations_get_cloud_uris(mock_client, patch_post):
770770
assert len(uris) == 1
771771
assert uris[0] == expected
772772

773+
# Return a map of URIs
774+
uri_map = mast.Observations.get_cloud_uris([mast_uri], return_uri_map=True)
775+
assert isinstance(uri_map, dict)
776+
assert len(uri_map) == 1
777+
assert uri_map[mast_uri] == expected
778+
773779
# Warn if attempting to filter with list input
774780
with pytest.warns(InputWarning, match='Filtering is not supported'):
775781
mast.Observations.get_cloud_uris([mast_uri],

astroquery/mast/tests/test_mast_remote.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -269,7 +269,7 @@ def test_missions_get_unique_product_list(self, caplog):
269269
# Unique product list should have fewer rows
270270
assert len(products) > len(unique_products)
271271
# Rows should be unique based on filename
272-
assert (unique_products == unique(unique_products, keys='filename')).all()
272+
assert (len(unique_products) == len(unique(unique_products, keys='filename')))
273273
# Check that INFO messages were logged
274274
with caplog.at_level('INFO', logger='astroquery'):
275275
assert 'products were duplicates' in caplog.text
@@ -570,15 +570,15 @@ def test_observations_get_product_list_async(self):
570570
responses = Observations.get_product_list_async(test_obs[2:3])
571571
assert isinstance(responses, list)
572572

573-
observations = Observations.query_object("M8", radius=".02 deg")
573+
observations = Observations.query_criteria(objectname="M8", obs_collection=["K2", "IUE"])
574574
responses = Observations.get_product_list_async(observations[0])
575575
assert isinstance(responses, list)
576576

577577
responses = Observations.get_product_list_async(observations[0:4])
578578
assert isinstance(responses, list)
579579

580580
def test_observations_get_product_list(self):
581-
observations = Observations.query_object("M8", radius=".04 deg")
581+
observations = Observations.query_criteria(objectname='M8', obs_collection=['K2', 'IUE'])
582582
test_obs_id = str(observations[0]['obsid'])
583583
mult_obs_ids = str(observations[0]['obsid']) + ',' + str(observations[1]['obsid'])
584584

@@ -598,7 +598,7 @@ def test_observations_get_product_list(self):
598598
assert len(result1) == len(result2)
599599
assert set(filenames1) == set(filenames2)
600600

601-
obsLoc = np.where(observations["obs_id"] == 'ktwo200071160-c92_lc')
601+
obsLoc = np.where(observations['obs_id'] == 'ktwo200071160-c92_lc')
602602
result = Observations.get_product_list(observations[obsLoc])
603603
assert isinstance(result, Table)
604604
assert len(result) == 1
@@ -644,7 +644,7 @@ def test_observations_get_unique_product_list(self, caplog):
644644
# Unique product list should have fewer rows
645645
assert len(products) > len(unique_products)
646646
# Rows should be unique based on dataURI
647-
assert (unique_products == unique(unique_products, keys='dataURI')).all()
647+
assert (len(unique_products) == len(unique(unique_products, keys='dataURI')))
648648
# Check that INFO messages were logged
649649
with caplog.at_level('INFO', logger='astroquery'):
650650
assert 'products were duplicates' in caplog.text
@@ -878,6 +878,13 @@ def test_observations_get_cloud_uris_list_input(self):
878878
assert len(uris) > 0, f'Products for URI list {uri_list} were not found in the cloud.'
879879
assert uris == expected
880880

881+
# return map of dataURI to cloud URI
882+
uri_map = Observations.get_cloud_uris(uri_list, return_uri_map=True)
883+
assert isinstance(uri_map, dict)
884+
assert len(uri_map) == 2
885+
for i, uri in enumerate(uri_list):
886+
assert uri_map[uri] == expected[i]
887+
881888
# check for warning if filters are provided with list input
882889
with pytest.warns(InputWarning, match='Filtering is not supported'):
883890
Observations.get_cloud_uris(uri_list,

astroquery/mast/utils.py

Lines changed: 21 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
import platform
1414

1515
from astropy.coordinates import SkyCoord
16-
from astropy.table import unique, Table
16+
from astropy.table import Table
1717
from astropy import units as u
1818

1919
from .. import log
@@ -258,24 +258,26 @@ def split_list_into_chunks(input_list, chunk_size):
258258
yield input_list[idx:idx + chunk_size]
259259

260260

261-
def mast_relative_path(mast_uri):
261+
def mast_relative_path(mast_uri, *, verbose=True):
262262
"""
263263
Given one or more MAST dataURI(s), return the associated relative path(s).
264264
265265
Parameters
266266
----------
267267
mast_uri : str, list of str
268268
The MAST uri(s).
269+
verbose : bool, optional
270+
Default True. Whether to issue warnings if the MAST relative path cannot be found for a product.
269271
270272
Returns
271273
-------
272274
response : str, list of str
273275
The associated relative path(s).
274276
"""
275277
if isinstance(mast_uri, str):
276-
uri_list = [("uri", mast_uri)]
277-
else: # mast_uri parameter is a list
278-
uri_list = [("uri", uri) for uri in mast_uri]
278+
uri_list = [mast_uri]
279+
else:
280+
uri_list = list(mast_uri)
279281

280282
# Split the list into chunks of 50 URIs; this is necessary
281283
# to avoid "414 Client Error: Request-URI Too Large".
@@ -284,19 +286,19 @@ def mast_relative_path(mast_uri):
284286
result = []
285287
for chunk in uri_list_chunks:
286288
response = _simple_request("https://mast.stsci.edu/api/v0.1/path_lookup/",
287-
{"uri": [mast_uri[1] for mast_uri in chunk]})
289+
{"uri": [mast_uri for mast_uri in chunk]})
288290

289291
json_response = response.json()
290292

291293
for uri in chunk:
292294
# Chunk is a list of MAST URI strings, so each
293295
# URI is used directly as the key into the JSON
294296
# response (no tuple indexing is needed)
295-
path = json_response.get(uri[1])["path"]
297+
path = json_response.get(uri)["path"]
296298
if path is None:
297-
warnings.warn(f"Failed to retrieve MAST relative path for {uri[1]}. Skipping...", NoResultsWarning)
298-
continue
299-
if 'galex' in path:
299+
if verbose:
300+
warnings.warn(f"Failed to retrieve MAST relative path for {uri}. Skipping...", NoResultsWarning)
301+
elif 'galex' in path:
300302
path = path.lstrip("/mast/")
301303
elif '/ps1/' in path:
302304
path = path.replace("/ps1/", "panstarrs/ps1/public/")
@@ -331,19 +333,15 @@ def remove_duplicate_products(data_products, uri_key):
331333
"""
332334
# Get unique products based on input type
333335
if isinstance(data_products, Table):
334-
unique_products = unique(data_products, keys=uri_key)
335-
else: # data_products is a list
336+
_, unique_indices = np.unique(data_products[uri_key], return_index=True)
337+
unique_products = data_products[np.sort(unique_indices)]
338+
else: # list of URIs
336339
seen = set()
337-
unique_products = []
338-
for uri in data_products:
339-
if uri not in seen:
340-
seen.add(uri)
341-
unique_products.append(uri)
342-
343-
number = len(data_products)
344-
number_unique = len(unique_products)
345-
if number_unique < number:
346-
log.info(f"{number - number_unique} of {number} products were duplicates. "
347-
f"Only returning {number_unique} unique product(s).")
340+
unique_products = [uri for uri in data_products if not (uri in seen or seen.add(uri))]
341+
342+
duplicates_removed = len(data_products) - len(unique_products)
343+
if duplicates_removed > 0:
344+
log.info(f"{duplicates_removed} of {len(data_products)} products were duplicates. "
345+
f"Only returning {len(unique_products)} unique product(s).")
348346

349347
return unique_products

0 commit comments

Comments
 (0)