diff --git a/astroquery/mast/missions.py b/astroquery/mast/missions.py index 184305fa37..bfc0de051d 100644 --- a/astroquery/mast/missions.py +++ b/astroquery/mast/missions.py @@ -54,6 +54,7 @@ def __init__(self, *, mission='hst', mast_token=None): self.dataset_kwds = { # column keywords corresponding to dataset ID 'hst': 'sci_data_set_name', 'jwst': 'fileSetName', + 'roman': 'fileSetName', 'classy': 'Target', 'ullyses': 'observation_id' } @@ -80,6 +81,37 @@ def mission(self, value): self._mission = value.lower() # case-insensitive self._service_api_connection.set_service_params(self.service_dict, f'search/{self.mission}') + def _extract_products(self, response): + """ + Extract products from the response of a `~requests.Response` object. + + Parameters + ---------- + response : `~requests.Response` + The response object containing the products data. + + Returns + ------- + list + A list of products extracted from the response. + """ + def normalize_products(products): + """ + Normalize the products list to ensure it is flat and not nested. + """ + if products and isinstance(products[0], list): + return products[0] + return products + + if isinstance(response, list): # multiple async responses from batching + combined = [] + for resp in response: + products = normalize_products(resp.json().get('products', [])) + combined.extend(products) + return combined + else: # single response + return normalize_products(response.json().get('products', [])) + def _parse_result(self, response, *, verbose=False): # Used by the async_to_sync decorator functionality """ Parse the results of a `~requests.Response` objects and return an `~astropy.table.Table` of results. @@ -105,17 +137,11 @@ def _parse_result(self, response, *, verbose=False): # Used by the async_to_syn if len(results) >= self.limit: warnings.warn("Maximum results returned, may not include all sources within radius.", MaxResultsWarning) - elif self.service == self._list_products: - # Results from post_list_products endpoint need to be handled differently - if isinstance(response, list): # multiple async responses from batching - combined_products = [] - for resp in response: - combined_products.extend(resp.json().get('products', [])) - return Table(combined_products) - - results = Table(response.json()['products']) # single async response + return results - return results + elif self.service == self._list_products: + products = self._extract_products(response) + return Table(products) def _validate_criteria(self, **criteria): """ @@ -536,8 +562,8 @@ def download_file(self, uri, *, local_path=None, cache=True, verbose=True): """ # Construct the full data URL based on mission - if self.mission in ['hst', 'jwst']: - # HST and JWST have a dedicated endpoint for retrieving products + if self.mission in ['hst', 'jwst', 'roman']: + # HST, JWST, and RST have a dedicated endpoint for retrieving products base_url = self._service_api_connection.MISSIONS_DOWNLOAD_URL + self.mission + '/api/v0.1/retrieve_product' keyword = 'product_name' else: diff --git a/astroquery/mast/tests/test_mast_remote.py b/astroquery/mast/tests/test_mast_remote.py index f377f39d07..06ab5cffd3 100644 --- a/astroquery/mast/tests/test_mast_remote.py +++ b/astroquery/mast/tests/test_mast_remote.py @@ -365,12 +365,18 @@ def check_result(result, path): @pytest.mark.parametrize("mission, query_params", [ ('jwst', {'fileSetName': 'jw01189001001_02101_00001'}), ('classy', {'Target': 'J0021+0052'}), - ('ullyses', {'host_galaxy_name': 'WLM', 'select_cols': ['observation_id']}) + ('ullyses', {'host_galaxy_name': 'WLM', 'select_cols': ['observation_id']}), + ('roman', {'program': 3}), ]) def test_missions_workflow(self, tmp_path, mission, query_params): # Test workflow with other missions m = MastMissions(mission=mission) + # Roman requires extra setup to point towards the test server + if mission == 'roman': + m._service_api_connection.SERVICE_URL = 'https://masttest.stsci.edu' + m._service_api_connection.REQUEST_URL = 'https://masttest.stsci.edu/search/roman/api/v0.1/' + # Criteria query datasets = m.query_criteria(**query_params) assert isinstance(datasets, Table)