Skip to content

Commit 9bc2079

Browse files
authored
Merge pull request #3385 from snbianco/roman-exp
2 parents 6fccfe9 + 6da9716 commit 9bc2079

File tree

2 files changed

+45
-13
lines changed

2 files changed

+45
-13
lines changed

astroquery/mast/missions.py

Lines changed: 38 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ def __init__(self, *, mission='hst', mast_token=None):
5454
self.dataset_kwds = { # column keywords corresponding to dataset ID
5555
'hst': 'sci_data_set_name',
5656
'jwst': 'fileSetName',
57+
'roman': 'fileSetName',
5758
'classy': 'Target',
5859
'ullyses': 'observation_id'
5960
}
@@ -80,6 +81,37 @@ def mission(self, value):
8081
self._mission = value.lower() # case-insensitive
8182
self._service_api_connection.set_service_params(self.service_dict, f'search/{self.mission}')
8283

84+
def _extract_products(self, response):
85+
"""
86+
Extract products from the response of a `~requests.Response` object.
87+
88+
Parameters
89+
----------
90+
response : `~requests.Response`
91+
The response object containing the products data.
92+
93+
Returns
94+
-------
95+
list
96+
A list of products extracted from the response.
97+
"""
98+
def normalize_products(products):
99+
"""
100+
Normalize the products list to ensure it is flat and not nested.
101+
"""
102+
if products and isinstance(products[0], list):
103+
return products[0]
104+
return products
105+
106+
if isinstance(response, list): # multiple async responses from batching
107+
combined = []
108+
for resp in response:
109+
products = normalize_products(resp.json().get('products', []))
110+
combined.extend(products)
111+
return combined
112+
else: # single response
113+
return normalize_products(response.json().get('products', []))
114+
83115
def _parse_result(self, response, *, verbose=False): # Used by the async_to_sync decorator functionality
84116
"""
85117
Parse the results of a `~requests.Response` objects and return an `~astropy.table.Table` of results.
@@ -105,17 +137,11 @@ def _parse_result(self, response, *, verbose=False): # Used by the async_to_syn
105137
if len(results) >= self.limit:
106138
warnings.warn("Maximum results returned, may not include all sources within radius.",
107139
MaxResultsWarning)
108-
elif self.service == self._list_products:
109-
# Results from post_list_products endpoint need to be handled differently
110-
if isinstance(response, list): # multiple async responses from batching
111-
combined_products = []
112-
for resp in response:
113-
combined_products.extend(resp.json().get('products', []))
114-
return Table(combined_products)
115-
116-
results = Table(response.json()['products']) # single async response
140+
return results
117141

118-
return results
142+
elif self.service == self._list_products:
143+
products = self._extract_products(response)
144+
return Table(products)
119145

120146
def _validate_criteria(self, **criteria):
121147
"""
@@ -537,8 +563,8 @@ def download_file(self, uri, *, local_path=None, cache=True, verbose=True):
537563
"""
538564

539565
# Construct the full data URL based on mission
540-
if self.mission in ['hst', 'jwst']:
541-
# HST and JWST have a dedicated endpoint for retrieving products
566+
if self.mission in ['hst', 'jwst', 'roman']:
567+
# HST, JWST, and RST have a dedicated endpoint for retrieving products
542568
base_url = self._service_api_connection.MISSIONS_DOWNLOAD_URL + self.mission + '/api/v0.1/retrieve_product'
543569
keyword = 'product_name'
544570
else:

astroquery/mast/tests/test_mast_remote.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -360,12 +360,18 @@ def check_result(result, path):
360360
@pytest.mark.parametrize("mission, query_params", [
361361
('jwst', {'fileSetName': 'jw01189001001_02101_00001'}),
362362
('classy', {'Target': 'J0021+0052'}),
363-
('ullyses', {'host_galaxy_name': 'WLM', 'select_cols': ['observation_id']})
363+
('ullyses', {'host_galaxy_name': 'WLM', 'select_cols': ['observation_id']}),
364+
('roman', {'program': 3}),
364365
])
365366
def test_missions_workflow(self, tmp_path, mission, query_params):
366367
# Test workflow with other missions
367368
m = MastMissions(mission=mission)
368369

370+
# Roman requires extra setup to point towards the test server
371+
if mission == 'roman':
372+
m._service_api_connection.SERVICE_URL = 'https://masttest.stsci.edu'
373+
m._service_api_connection.REQUEST_URL = 'https://masttest.stsci.edu/search/roman/api/v0.1/'
374+
369375
# Criteria query
370376
datasets = m.query_criteria(**query_params)
371377
assert isinstance(datasets, Table)

0 commit comments

Comments
 (0)