Skip to content

Commit 85d1868

Browse files
authored
Merge pull request #3230 from snbianco/ASB-30644-missions-limit
Add batching to MastMissions.get_product_list
2 parents e0baf0d + 6c3b03f commit 85d1868

File tree

5 files changed

+87
-24
lines changed

5 files changed

+87
-24
lines changed

CHANGES.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,9 @@ mast
7575

7676
- Corrected parameter checking in ``MastMissions`` to ensure case-sensitive comparisons. [#3260]
7777

78+
- Add batching to ``MastMissions.get_product_list`` to avoid server errors and allow for a larger number of input datasets. [#3230]
79+
80+
7881
simbad
7982
^^^^^^
8083

astroquery/mast/missions.py

Lines changed: 38 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from astroquery import log
2222
from astroquery.utils import commons, async_to_sync
2323
from astroquery.utils.class_or_instance import class_or_instance
24+
from astropy.utils.console import ProgressBarOrSpinner
2425
from astroquery.exceptions import InvalidQueryError, MaxResultsWarning, InputWarning, NoResultsWarning
2526

2627
from astroquery.mast import utils
@@ -106,7 +107,13 @@ def _parse_result(self, response, *, verbose=False): # Used by the async_to_syn
106107
MaxResultsWarning)
107108
elif self.service == self._list_products:
108109
# Results from post_list_products endpoint need to be handled differently
109-
results = Table(response.json()['products'])
110+
if isinstance(response, list): # multiple async responses from batching
111+
combined_products = []
112+
for resp in response:
113+
combined_products.extend(resp.json().get('products', []))
114+
return Table(combined_products)
115+
116+
results = Table(response.json()['products']) # single async response
110117

111118
return results
112119

@@ -370,13 +377,38 @@ def get_product_list_async(self, datasets):
370377
'list of strings, Astropy Row, Astropy Column, or Astropy Table.')
371378

372379
# Filter out empty strings from IDs
373-
datasets = [item.strip() for item in datasets if item.strip() != '' and item is not None]
374-
if not len(datasets):
380+
datasets = [item.strip() for item in datasets if item and item.strip()]
381+
if not datasets:
375382
raise InvalidQueryError("Dataset list is empty, no associated products.")
376383

377-
# Send async service request
378-
params = {'dataset_ids': datasets}
379-
return self._service_api_connection.missions_request_async(self.service, params)
384+
# Filter out duplicates
385+
datasets = list(set(datasets))
386+
387+
# Batch API calls if number of datasets exceeds maximum
388+
max_batch = 1000
389+
num_datasets = len(datasets)
390+
if num_datasets > max_batch:
391+
# Split datasets into chunks
392+
dataset_chunks = list(utils.split_list_into_chunks(datasets, max_batch))
393+
394+
results = [] # list to store responses from each batch
395+
with ProgressBarOrSpinner(num_datasets, f'Fetching products for {num_datasets} unique datasets '
396+
f'in {len(dataset_chunks)} batches ...') as pb:
397+
datasets_fetched = 0
398+
pb.update(0)
399+
for chunk in dataset_chunks:
400+
# Send request for each chunk and add response to list
401+
params = {'dataset_ids': chunk}
402+
results.append(self._service_api_connection.missions_request_async(self.service, params))
403+
404+
# Update progress bar with the number of datasets that have had products fetched
405+
datasets_fetched += len(chunk)
406+
pb.update(datasets_fetched)
407+
return results
408+
else:
409+
# Single batch request
410+
params = {'dataset_ids': datasets}
411+
return self._service_api_connection.missions_request_async(self.service, params)
380412

381413
def get_unique_product_list(self, datasets):
382414
"""

astroquery/mast/tests/test_mast.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -344,6 +344,11 @@ def test_missions_get_product_list(patch_post):
344344
result = mast.MastMissions.get_product_list(datasets[0])
345345
assert isinstance(result, Table)
346346

347+
# Batching
348+
dataset_list = [f'{i}' for i in range(1001)]
349+
result = mast.MastMissions.get_product_list(dataset_list)
350+
assert isinstance(result, Table)
351+
347352

348353
def test_missions_get_unique_product_list(patch_post, caplog):
349354
unique_products = mast.MastMissions.get_unique_product_list('Z14Z0104T')

astroquery/mast/tests/test_mast_remote.py

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -188,31 +188,40 @@ def test_missions_get_product_list_async(self):
188188
MastMissions.get_product_list_async([' '])
189189
assert 'Dataset list is empty' in str(err_empty.value)
190190

191-
def test_missions_get_product_list(self):
191+
def test_missions_get_product_list(self, capsys):
192192
datasets = MastMissions.query_object("M4", radius=0.1)
193193
test_dataset = datasets[0]['sci_data_set_name']
194194
multi_dataset = list(datasets[:2]['sci_data_set_name'])
195195

196196
# Compare Row input and string input
197-
result1 = MastMissions.get_product_list(test_dataset)
198-
result2 = MastMissions.get_product_list(datasets[0])
199-
assert isinstance(result1, Table)
200-
assert len(result1) == len(result2)
201-
assert set(result1['filename']) == set(result2['filename'])
197+
result_str = MastMissions.get_product_list(test_dataset)
198+
result_row = MastMissions.get_product_list(datasets[0])
199+
assert isinstance(result_str, Table)
200+
assert len(result_str) == len(result_row)
201+
assert set(result_str['filename']) == set(result_row['filename'])
202202

203203
# Compare Table input and list input
204-
result1 = MastMissions.get_product_list(multi_dataset)
205-
result2 = MastMissions.get_product_list(datasets[:2])
206-
assert isinstance(result1, Table)
207-
assert len(result1) == len(result2)
208-
assert set(result1['filename']) == set(result2['filename'])
204+
result_list = MastMissions.get_product_list(multi_dataset)
205+
result_table = MastMissions.get_product_list(datasets[:2])
206+
assert isinstance(result_list, Table)
207+
assert len(result_list) == len(result_table)
208+
assert set(result_list['filename']) == set(result_table['filename'])
209209

210210
# Filter datasets based on sci_data_set_name and verify products
211211
filtered = datasets[datasets['sci_data_set_name'] == 'IBKH03020']
212212
result = MastMissions.get_product_list(filtered)
213213
assert isinstance(result, Table)
214214
assert (result['dataset'] == 'IBKH03020').all()
215215

216+
# Test batching by creating a list of 1001 different strings
217+
# This won't return any results, but will test the batching
218+
dataset_list = [f'{i}' for i in range(1001)]
219+
result = MastMissions.get_product_list(dataset_list)
220+
out, _ = capsys.readouterr()
221+
assert isinstance(result, Table)
222+
assert len(result) == 0
223+
assert 'Fetching products for 1001 unique datasets in 2 batches' in out
224+
216225
def test_missions_get_unique_product_list(self, caplog):
217226
# Check that no rows are filtered out when all products are unique
218227
dataset_ids = ['JBTAA8010']

astroquery/mast/utils.py

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,26 @@ def parse_input_location(coordinates=None, objectname=None):
159159
return obj_coord
160160

161161

162+
def split_list_into_chunks(input_list, chunk_size):
163+
"""
164+
Splits a list into chunks of a specified size.
165+
166+
Parameters
167+
----------
168+
input_list : list
169+
List to be split into chunks.
170+
chunk_size : int
171+
Size of each chunk.
172+
173+
Yields
174+
------
175+
chunk : list
176+
A chunk of the input list.
177+
"""
178+
for idx in range(0, len(input_list), chunk_size):
179+
yield input_list[idx:idx + chunk_size]
180+
181+
162182
def mast_relative_path(mast_uri):
163183
"""
164184
Given one or more MAST dataURI(s), return the associated relative path(s).
@@ -180,7 +200,7 @@ def mast_relative_path(mast_uri):
180200

181201
# Split the list into chunks of 50 URIs; this is necessary
182202
# to avoid "414 Client Error: Request-URI Too Large".
183-
uri_list_chunks = list(_split_list_into_chunks(uri_list, chunk_size=50))
203+
uri_list_chunks = list(split_list_into_chunks(uri_list, chunk_size=50))
184204

185205
result = []
186206
for chunk in uri_list_chunks:
@@ -214,12 +234,6 @@ def mast_relative_path(mast_uri):
214234
return result
215235

216236

217-
def _split_list_into_chunks(input_list, chunk_size):
218-
"""Helper function for `mast_relative_path`."""
219-
for idx in range(0, len(input_list), chunk_size):
220-
yield input_list[idx:idx + chunk_size]
221-
222-
223237
def remove_duplicate_products(data_products, uri_key):
224238
"""
225239
Removes duplicate data products that have the same data URI.

0 commit comments

Comments
 (0)