Skip to content

Commit ff37c91

Browse files
committed
Return results in order of input, return as map
1 parent f1713f8 commit ff37c91

File tree

4 files changed

+54
-30
lines changed

4 files changed

+54
-30
lines changed

astroquery/mast/observations.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -874,7 +874,7 @@ def download_products(self, products, *, download_dir=None, flat=False,
874874
return manifest
875875

876876
def get_cloud_uris(self, data_products=None, *, include_bucket=True, full_url=False, pagesize=None, page=None,
877-
mrp_only=False, extension=None, filter_products={}, **criteria):
877+
mrp_only=False, extension=None, filter_products={}, return_uri_map=False, **criteria):
878878
"""
879879
Given an `~astropy.table.Table` of data products or query criteria and filter parameters,
880880
returns the associated cloud data URIs.
@@ -908,6 +908,10 @@ def get_cloud_uris(self, data_products=None, *, include_bucket=True, full_url=Fa
908908
or more acceptable values for that parameter.
909909
Filter behavior is AND between the filters and OR within a filter set.
910910
For example: {"productType": "SCIENCE", "extension": ["fits", "jpg"]}
911+
return_uri_map : bool, optional
912+
Default False. If set to True, returns a dictionary mapping the original data product
913+
URIs to their corresponding cloud URIs. This is useful for tracking which products were
914+
successfully converted to cloud URIs.
911915
**criteria
912916
Criteria to apply. At least one non-positional criterion must be supplied.
913917
Valid criteria are coordinates, objectname, radius (as in `query_region` and `query_object`),
@@ -951,20 +955,31 @@ def get_cloud_uris(self, data_products=None, *, include_bucket=True, full_url=Fa
951955
# Filter product list
952956
data_products = self.filter_products(data_products, mrp_only=mrp_only, extension=extension,
953957
**filter_products)
958+
data_uris = data_products['dataURI']
954959
else: # data_products is a list of URIs
955960
# Warn if trying to supply filters
956961
if filter_products or extension or mrp_only:
957962
warnings.warn('Filtering is not supported when providing a list of MAST URIs. '
958963
'To apply filters, please provide query criteria or a table of data products '
959964
'as returned by `Observations.get_product_list`', InputWarning)
965+
data_uris = data_products
960966

961-
if not len(data_products):
967+
if not len(data_uris):
962968
warnings.warn('No matching products to fetch associated cloud URIs.', NoResultsWarning)
963969
return
964970

965971
# Remove duplicate products
966-
data_products = utils.remove_duplicate_products(data_products, 'dataURI')
967-
return self._cloud_connection.get_cloud_uri_list(data_products, include_bucket, full_url)
972+
data_uris = utils.remove_duplicate_products(data_uris, 'dataURI')
973+
974+
# Get cloud URIs
975+
cloud_uris = self._cloud_connection.get_cloud_uri_list(data_uris, include_bucket, full_url)
976+
977+
# If return_uri_map is True, create a mapping of dataURIs to cloud URIs
978+
if return_uri_map:
979+
uri_map = dict(zip(data_uris, cloud_uris))
980+
return uri_map
981+
982+
return cloud_uris
968983

969984
def get_cloud_uri(self, data_product, *, include_bucket=True, full_url=False):
970985
"""

astroquery/mast/tests/test_mast.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -353,7 +353,7 @@ def test_missions_get_product_list(patch_post):
353353
def test_missions_get_unique_product_list(patch_post, caplog):
354354
unique_products = mast.MastMissions.get_unique_product_list('Z14Z0104T')
355355
assert isinstance(unique_products, Table)
356-
assert (unique_products == unique(unique_products, keys='filename')).all()
356+
assert (len(unique_products) == len(unique(unique_products, keys='filename')))
357357
# No INFO messages should be logged
358358
with caplog.at_level('INFO', logger='astroquery'):
359359
assert caplog.text == ''
@@ -770,6 +770,12 @@ def test_observations_get_cloud_uris(mock_client, patch_post):
770770
assert len(uris) == 1
771771
assert uris[0] == expected
772772

773+
# Return a map of URIs
774+
uri_map = mast.Observations.get_cloud_uris([mast_uri], return_uri_map=True)
775+
assert isinstance(uri_map, dict)
776+
assert len(uri_map) == 1
777+
assert uri_map[mast_uri] == expected
778+
773779
# Warn if attempting to filter with list input
774780
with pytest.warns(InputWarning, match='Filtering is not supported'):
775781
mast.Observations.get_cloud_uris([mast_uri],

astroquery/mast/tests/test_mast_remote.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -269,7 +269,7 @@ def test_missions_get_unique_product_list(self, caplog):
269269
# Unique product list should have fewer rows
270270
assert len(products) > len(unique_products)
271271
# Rows should be unique based on filename
272-
assert (unique_products == unique(unique_products, keys='filename')).all()
272+
assert (len(unique_products) == len(unique(unique_products, keys='filename')))
273273
# Check that INFO messages were logged
274274
with caplog.at_level('INFO', logger='astroquery'):
275275
assert 'products were duplicates' in caplog.text
@@ -644,7 +644,7 @@ def test_observations_get_unique_product_list(self, caplog):
644644
# Unique product list should have fewer rows
645645
assert len(products) > len(unique_products)
646646
# Rows should be unique based on dataURI
647-
assert (unique_products == unique(unique_products, keys='dataURI')).all()
647+
assert (len(unique_products) == len(unique(unique_products, keys='dataURI')))
648648
# Check that INFO messages were logged
649649
with caplog.at_level('INFO', logger='astroquery'):
650650
assert 'products were duplicates' in caplog.text
@@ -878,6 +878,13 @@ def test_observations_get_cloud_uris_list_input(self):
878878
assert len(uris) > 0, f'Products for URI list {uri_list} were not found in the cloud.'
879879
assert uris == expected
880880

881+
# return map of dataURI to cloud URI
882+
uri_map = Observations.get_cloud_uris(uri_list, return_uri_map=True)
883+
assert isinstance(uri_map, dict)
884+
assert len(uri_map) == 2
885+
for i, uri in enumerate(uri_list):
886+
assert uri_map[uri] == expected[i]
887+
881888
# check for warning if filters are provided with list input
882889
with pytest.warns(InputWarning, match='Filtering is not supported'):
883890
Observations.get_cloud_uris(uri_list,

astroquery/mast/utils.py

Lines changed: 19 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
import platform
1414

1515
from astropy.coordinates import SkyCoord
16-
from astropy.table import unique, Table
16+
from astropy.table import Table
1717
from astropy import units as u
1818

1919
from .. import log
@@ -273,9 +273,9 @@ def mast_relative_path(mast_uri):
273273
The associated relative path(s).
274274
"""
275275
if isinstance(mast_uri, str):
276-
uri_list = [("uri", mast_uri)]
277-
else: # mast_uri parameter is a list
278-
uri_list = [("uri", uri) for uri in mast_uri]
276+
uri_list = [mast_uri]
277+
else:
278+
uri_list = list(mast_uri)
279279

280280
# Split the list into chunks of 50 URIs; this is necessary
281281
# to avoid "414 Client Error: Request-URI Too Large".
@@ -284,19 +284,19 @@ def mast_relative_path(mast_uri):
284284
result = []
285285
for chunk in uri_list_chunks:
286286
response = _simple_request("https://mast.stsci.edu/api/v0.1/path_lookup/",
287-
{"uri": [mast_uri[1] for mast_uri in chunk]})
287+
{"uri": [mast_uri for mast_uri in chunk]})
288288

289289
json_response = response.json()
290290

291291
for uri in chunk:
292292
# Chunk is a list of URI strings, so each URI is used
293293
# directly as the lookup key into the JSON response
294294
# to retrieve its path.
295-
path = json_response.get(uri[1])["path"]
295+
path = json_response.get(uri)["path"]
296296
if path is None:
297-
warnings.warn(f"Failed to retrieve MAST relative path for {uri[1]}. Skipping...", NoResultsWarning)
298-
continue
299-
if 'galex' in path:
297+
warnings.warn(f"Failed to retrieve MAST relative path for {uri}. Skipping...", NoResultsWarning)
298+
path = None
299+
elif 'galex' in path:
300300
path = path.lstrip("/mast/")
301301
elif '/ps1/' in path:
302302
path = path.replace("/ps1/", "panstarrs/ps1/public/")
@@ -330,20 +330,16 @@ def remove_duplicate_products(data_products, uri_key):
330330
Table containing products with unique dataURIs.
331331
"""
332332
# Get unique products based on input type
333+
seen = set()
333334
if isinstance(data_products, Table):
334-
unique_products = unique(data_products, keys=uri_key)
335-
else: # data_products is a list
336-
seen = set()
337-
unique_products = []
338-
for uri in data_products:
339-
if uri not in seen:
340-
seen.add(uri)
341-
unique_products.append(uri)
342-
343-
number = len(data_products)
344-
number_unique = len(unique_products)
345-
if number_unique < number:
346-
log.info(f"{number - number_unique} of {number} products were duplicates. "
347-
f"Only returning {number_unique} unique product(s).")
335+
unique_rows = [row for row in data_products if not (row[uri_key] in seen or seen.add(row[uri_key]))]
336+
unique_products = type(data_products)(rows=unique_rows)
337+
else: # Assume data_products is a list of URIs
338+
unique_products = [uri for uri in data_products if not (uri in seen or seen.add(uri))]
339+
340+
duplicates_removed = len(data_products) - len(unique_products)
341+
if duplicates_removed > 0:
342+
log.info(f"{duplicates_removed} of {len(data_products)} products were duplicates. "
343+
f"Only returning {len(unique_products)} unique product(s).")
348344

349345
return unique_products

0 commit comments

Comments
 (0)