diff --git a/CHANGES.rst b/CHANGES.rst index 9540003431..fc87a2f010 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -30,6 +30,10 @@ mast - Switch to use HTTP continuation for partial downloads. [#3448] +- Add ``batch_size`` parameter to ``MastMissions.get_product_list``, ``Observations.get_product_list``, + and ``utils.resolve_object`` to allow controlling the number of items sent in each batch request to the server. + This can help avoid timeouts or connection errors for large requests. [#3454] + Infrastructure, Utility and Other Changes and Additions ------------------------------------------------------- diff --git a/astroquery/mast/missions.py b/astroquery/mast/missions.py index ba87307d30..fdf2f92881 100644 --- a/astroquery/mast/missions.py +++ b/astroquery/mast/missions.py @@ -94,22 +94,14 @@ def _extract_products(self, response): list A list of products extracted from the response. """ - def normalize_products(products): - """ - Normalize the products list to ensure it is flat and not nested. 
- """ + combined = [] + for resp in response: + products = resp.json().get('products', []) + # Flatten if nested if products and isinstance(products[0], list): - return products[0] - return products - - if isinstance(response, list): # multiple async responses from batching - combined = [] - for resp in response: - products = normalize_products(resp.json().get('products', [])) - combined.extend(products) - return combined - else: # single response - return normalize_products(response.json().get('products', [])) + products = products[0] + combined.extend(products) + return combined def _parse_result(self, response, *, verbose=False): # Used by the async_to_sync decorator functionality """ @@ -392,7 +384,7 @@ def query_object_async(self, objectname, *, radius=3*u.arcmin, limit=5000, offse select_cols=select_cols, **criteria) @class_or_instance - def get_product_list_async(self, datasets): + def get_product_list_async(self, datasets, batch_size=1000): """ Given a dataset ID or list of dataset IDs, returns a list of associated data products. @@ -403,6 +395,9 @@ def get_product_list_async(self, datasets): datasets : str, list, `~astropy.table.Row`, `~astropy.table.Column`, `~astropy.table.Table` Row/Table of MastMissions query results (e.g. output from `query_object`) or single/list of dataset ID(s). + batch_size : int, optional + Default 1000. Number of dataset IDs to include in each batch request to the server. + If you experience timeouts or connection errors, consider lowering this value. Returns ------- @@ -414,8 +409,9 @@ def get_product_list_async(self, datasets): if isinstance(datasets, Table) or isinstance(datasets, Row): dataset_kwd = self.get_dataset_kwd() if not dataset_kwd: - log.warning('Please input dataset IDs as a string, list of strings, or `~astropy.table.Column`.') - return None + error_msg = (f'Dataset keyword not found for mission "{self.mission}". 
' + 'Please input dataset IDs as a string, list of strings, or `~astropy.table.Column`.') + raise InvalidQueryError(error_msg) # Extract dataset IDs based on input type and mission if isinstance(datasets, Table): @@ -441,15 +437,15 @@ def get_product_list_async(self, datasets): results = utils._batched_request( datasets, params={}, - max_batch=1000, + max_batch=batch_size, param_key="dataset_ids", request_func=lambda p: self._service_api_connection.missions_request_async(self.service, p), extract_func=lambda r: [r], # missions_request_async already returns one result desc=f"Fetching products for {len(datasets)} unique datasets" ) - # Return a list of responses only if multiple requests were made - return results[0] if len(results) == 1 else results + # Return a list of responses + return results def get_unique_product_list(self, datasets): """ diff --git a/astroquery/mast/observations.py b/astroquery/mast/observations.py index a515960ff7..6317f70fcc 100644 --- a/astroquery/mast/observations.py +++ b/astroquery/mast/observations.py @@ -504,7 +504,7 @@ def _filter_ffi_observations(self, observations): return obs_table[mask] @class_or_instance - def get_product_list_async(self, observations): + def get_product_list_async(self, observations, batch_size=500): """ Given a "Product Group Id" (column name obsid) returns a list of associated data products. Note that obsid is NOT the same as obs_id, and inputting obs_id values will result in @@ -518,31 +518,50 @@ def get_product_list_async(self, observations): Row/Table of MAST query results (e.g. output from `query_object`) or single/list of MAST Product Group Id(s) (obsid). See description `here <https://masttest.stsci.edu/api/v0/_c_a_o_mfields.html>`__. + batch_size : int, optional + Default 500. Number of obsids to include in each batch request to the server. + If you experience timeouts or connection errors, consider lowering this value. Returns ------- response : list of `~requests.Response` + A list of asynchronous response objects for each batch request. 
""" - - # getting the obsid list + # Getting the obsids as a list if np.isscalar(observations): - observations = np.array([observations]) - if isinstance(observations, Table) or isinstance(observations, Row): + observations = [observations] + elif isinstance(observations, (Row, Table)): # Filter out TESS FFIs and TICA FFIs # Can only perform filtering on Row or Table because of access to `target_name` field observations = self._filter_ffi_observations(observations) - observations = observations['obsid'] - if isinstance(observations, list): - observations = np.array(observations) - - observations = observations[observations != ""] - if observations.size == 0: - raise InvalidQueryError("Observation list is empty, no associated products.") - - service = self._caom_products - params = {'obsid': ','.join(observations)} - - return self._portal_api_connection.service_request_async(service, params) + observations = observations['obsid'].tolist() + + # Clean and validate + observations = [str(obs).strip() for obs in observations] + observations = [obs for obs in observations if obs] + if not observations: + raise InvalidQueryError('Observation list is empty, no associated products.') + + # Define a helper to join obsids for each batch request + def _request_joined_obsid(params): + """Join batched obsid list into comma-separated string and send async request.""" + pp = dict(params) + vals = pp.get('obsid', []) + pp['obsid'] = ','.join(map(str, vals)) + return self._portal_api_connection.service_request_async(self._caom_products, pp)[0] + + # Perform batched requests + results = utils._batched_request( + items=observations, + params={}, + max_batch=batch_size, + param_key='obsid', + request_func=_request_joined_obsid, + extract_func=lambda r: [r], + desc=f'Fetching products for {len(observations)} unique observations' + ) + + return results def filter_products(self, products, *, mrp_only=False, extension=None, **filters): """ diff --git a/astroquery/mast/tests/test_mast.py 
b/astroquery/mast/tests/test_mast.py index 9d90da325e..8558f73c55 100644 --- a/astroquery/mast/tests/test_mast.py +++ b/astroquery/mast/tests/test_mast.py @@ -304,21 +304,21 @@ def test_missions_query_criteria(patch_post): def test_missions_get_product_list_async(patch_post): # String input result = mast.MastMissions.get_product_list_async('Z14Z0104T') - assert isinstance(result, MockResponse) + assert isinstance(result, list) # List input in_datasets = ['Z14Z0104T', 'Z14Z0102T'] result = mast.MastMissions.get_product_list_async(in_datasets) - assert isinstance(result, MockResponse) + assert isinstance(result, list) # Row input datasets = mast.MastMissions.query_object("M101", radius=".002 deg") result = mast.MastMissions.get_product_list_async(datasets[:3]) - assert isinstance(result, MockResponse) + assert isinstance(result, list) # Table input result = mast.MastMissions.get_product_list_async(datasets[0]) - assert isinstance(result, MockResponse) + assert isinstance(result, list) # Unsupported data type for datasets with pytest.raises(TypeError) as err_type: @@ -330,6 +330,11 @@ def test_missions_get_product_list_async(patch_post): mast.MastMissions.get_product_list_async([' ']) assert 'Dataset list is empty' in str(err_empty.value) + # No dataset keyword + with pytest.raises(InvalidQueryError, match='Dataset keyword not found for mission "invalid"'): + missions = mast.MastMissions(mission='invalid') + missions.get_product_list_async(Table({'a': [1, 2, 3]})) + def test_missions_get_product_list(patch_post): # String input @@ -798,6 +803,10 @@ def test_observations_get_product_list(patch_post): result = mast.Observations.get_product_list(in_obsids) assert isinstance(result, Table) + # Error if no valid obsids are found + with pytest.raises(InvalidQueryError, match='Observation list is empty'): + mast.Observations.get_product_list([' ']) + def test_observations_filter_products(patch_post): products = mast.Observations.get_product_list('2003738726') diff --git 
a/astroquery/mast/tests/test_mast_remote.py b/astroquery/mast/tests/test_mast_remote.py index a4db087e9c..6aa391e292 100644 --- a/astroquery/mast/tests/test_mast_remote.py +++ b/astroquery/mast/tests/test_mast_remote.py @@ -199,19 +199,25 @@ def test_missions_get_product_list_async(self): # Table as input responses = MastMissions.get_product_list_async(datasets[:3]) - assert isinstance(responses, Response) + assert isinstance(responses, list) # Row as input responses = MastMissions.get_product_list_async(datasets[0]) - assert isinstance(responses, Response) + assert isinstance(responses, list) # String as input responses = MastMissions.get_product_list_async(datasets[0]['sci_data_set_name']) - assert isinstance(responses, Response) + assert isinstance(responses, list) # Column as input responses = MastMissions.get_product_list_async(datasets[:3]['sci_data_set_name']) - assert isinstance(responses, Response) + assert isinstance(responses, list) + + # Batching + responses = MastMissions.get_product_list_async(datasets[:4], batch_size=2) + assert isinstance(responses, list) + assert len(responses) == 2 + assert isinstance(responses[0], Response) # Unsupported data type for datasets with pytest.raises(TypeError) as err_type: @@ -248,14 +254,13 @@ def test_missions_get_product_list(self, capsys): assert isinstance(result, Table) assert (result['dataset'] == 'IBKH03020').all() - # Test batching by creating a list of 1001 different strings - # This won't return any results, but will test the batching - dataset_list = [f'{i}' for i in range(1001)] - result = MastMissions.get_product_list(dataset_list) + # Test batching + result_batch = MastMissions.get_product_list(datasets[:2], batch_size=1) out, _ = capsys.readouterr() - assert isinstance(result, Table) - assert len(result) == 0 - assert 'Fetching products for 1001 unique datasets in 2 batches' in out + assert isinstance(result_batch, Table) + assert len(result_batch) == len(result_table) + assert 
set(result_batch['filename']) == set(result_table['filename']) + assert 'Fetching products for 2 unique datasets in 2 batches' in out def test_missions_get_unique_product_list(self, caplog): # Check that no rows are filtered out when all products are unique @@ -580,7 +585,11 @@ def test_observations_get_product_list_async(self): responses = Observations.get_product_list_async(observations[0:4]) assert isinstance(responses, list) - def test_observations_get_product_list(self): + # Batching + responses = Observations.get_product_list_async(observations[0:4], batch_size=2) + assert isinstance(responses, list) + + def test_observations_get_product_list(self, capsys): observations = Observations.query_criteria(objectname='M8', obs_collection=['K2', 'IUE']) test_obs_id = str(observations[0]['obsid']) mult_obs_ids = str(observations[0]['obsid']) + ',' + str(observations[1]['obsid']) @@ -613,6 +622,14 @@ def test_observations_get_product_list(self): assert len(obs_collection) == 1 assert obs_collection[0] == 'IUE' + # Test batching + result_batch = Observations.get_product_list(observations[:2], batch_size=1) + out, _ = capsys.readouterr() + assert isinstance(result_batch, Table) + assert len(result_batch) == len(result1) + assert set(result_batch['productFilename']) == set(filenames1) + assert 'Fetching products for 2 unique observations in 2 batches' in out + def test_observations_get_product_list_tess_tica(self, caplog): # Get observations and products with both TESS and TICA FFIs obs = Observations.query_criteria(target_name=['TESS FFI', 'TICA FFI', '429031146']) diff --git a/astroquery/mast/utils.py b/astroquery/mast/utils.py index 0b4d666192..aba37af17f 100644 --- a/astroquery/mast/utils.py +++ b/astroquery/mast/utils.py @@ -148,7 +148,7 @@ def _batched_request( return extract_func(resp) -def resolve_object(objectname, *, resolver=None, resolve_all=False): +def resolve_object(objectname, *, resolver=None, resolve_all=False, batch_size=30): """ Resolves one or more 
object names to a position on the sky. @@ -164,6 +164,9 @@ def resolve_object(objectname, *, resolver=None, resolve_all=False): resolve_all : bool, optional If True, will try to resolve the object name using all available resolvers ("NED", "SIMBAD"). Default is False. + batch_size : int, optional + Default 30. Number of object names to include in each batch request to the server. + If you experience timeouts or connection errors, consider lowering this value. Returns ------- @@ -230,7 +233,7 @@ def resolve_object(objectname, *, resolver=None, resolve_all=False): results = _batched_request( object_names, params, - max_batch=30, + max_batch=batch_size, param_key="name", request_func=lambda p: _simple_request("http://mastresolver.stsci.edu/Santa-war/query", p), extract_func=lambda r: r.json().get("resolvedCoordinate") or [],