Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@ mast

- Switch to use HTTP continuation for partial downloads. [#3448]

- Add ``batch_size`` parameter to ``MastMissions.get_product_list``, ``Observations.get_product_list``,
and ``utils.resolve_object`` to allow controlling the number of items sent in each batch request to the server.
This can help avoid timeouts or connection errors for large requests. [#3454]


Infrastructure, Utility and Other Changes and Additions
-------------------------------------------------------
Expand Down
38 changes: 17 additions & 21 deletions astroquery/mast/missions.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,22 +94,14 @@ def _extract_products(self, response):
list
A list of products extracted from the response.
"""
def normalize_products(products):
"""
Normalize the products list to ensure it is flat and not nested.
"""
combined = []
for resp in response:
products = resp.json().get('products', [])
# Flatten if nested
if products and isinstance(products[0], list):
return products[0]
return products

if isinstance(response, list): # multiple async responses from batching
combined = []
for resp in response:
products = normalize_products(resp.json().get('products', []))
combined.extend(products)
return combined
else: # single response
return normalize_products(response.json().get('products', []))
products = products[0]
combined.extend(products)
return combined

def _parse_result(self, response, *, verbose=False): # Used by the async_to_sync decorator functionality
"""
Expand Down Expand Up @@ -392,7 +384,7 @@ def query_object_async(self, objectname, *, radius=3*u.arcmin, limit=5000, offse
select_cols=select_cols, **criteria)

@class_or_instance
def get_product_list_async(self, datasets):
def get_product_list_async(self, datasets, batch_size=1000):
"""
Given a dataset ID or list of dataset IDs, returns a list of associated data products.

Expand All @@ -403,6 +395,9 @@ def get_product_list_async(self, datasets):
datasets : str, list, `~astropy.table.Row`, `~astropy.table.Column`, `~astropy.table.Table`
Row/Table of MastMissions query results (e.g. output from `query_object`)
or single/list of dataset ID(s).
batch_size : int, optional
Default 1000. Number of dataset IDs to include in each batch request to the server.
If you experience timeouts or connection errors, consider lowering this value.

Returns
-------
Expand All @@ -414,8 +409,9 @@ def get_product_list_async(self, datasets):
if isinstance(datasets, Table) or isinstance(datasets, Row):
dataset_kwd = self.get_dataset_kwd()
if not dataset_kwd:
log.warning('Please input dataset IDs as a string, list of strings, or `~astropy.table.Column`.')
return None
error_msg = (f'Dataset keyword not found for mission "{self.mission}". '
'Please input dataset IDs as a string, list of strings, or `~astropy.table.Column`.')
raise InvalidQueryError(error_msg)

# Extract dataset IDs based on input type and mission
if isinstance(datasets, Table):
Expand All @@ -441,15 +437,15 @@ def get_product_list_async(self, datasets):
results = utils._batched_request(
datasets,
params={},
max_batch=1000,
max_batch=batch_size,
param_key="dataset_ids",
request_func=lambda p: self._service_api_connection.missions_request_async(self.service, p),
extract_func=lambda r: [r], # missions_request_async already returns one result
desc=f"Fetching products for {len(datasets)} unique datasets"
)

# Return a list of responses only if multiple requests were made
return results[0] if len(results) == 1 else results
# Return a list of responses
return results

def get_unique_product_list(self, datasets):
"""
Expand Down
53 changes: 36 additions & 17 deletions astroquery/mast/observations.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,7 +504,7 @@ def _filter_ffi_observations(self, observations):
return obs_table[mask]

@class_or_instance
def get_product_list_async(self, observations):
def get_product_list_async(self, observations, batch_size=500):
"""
Given a "Product Group Id" (column name obsid) returns a list of associated data products.
Note that obsid is NOT the same as obs_id, and inputting obs_id values will result in
Expand All @@ -518,31 +518,50 @@ def get_product_list_async(self, observations):
Row/Table of MAST query results (e.g. output from `query_object`)
or single/list of MAST Product Group Id(s) (obsid).
See description `here <https://masttest.stsci.edu/api/v0/_c_a_o_mfields.html>`__.
batch_size : int, optional
Default 500. Number of obsids to include in each batch request to the server.
If you experience timeouts or connection errors, consider lowering this value.

Returns
-------
response : list of `~requests.Response`
A list of asynchronous response objects for each batch request.
"""

# getting the obsid list
# Getting the obsids as a list
if np.isscalar(observations):
observations = np.array([observations])
if isinstance(observations, Table) or isinstance(observations, Row):
observations = [observations]
elif isinstance(observations, (Row, Table)):
# Filter out TESS FFIs and TICA FFIs
# Can only perform filtering on Row or Table because of access to `target_name` field
observations = self._filter_ffi_observations(observations)
observations = observations['obsid']
if isinstance(observations, list):
observations = np.array(observations)

observations = observations[observations != ""]
if observations.size == 0:
raise InvalidQueryError("Observation list is empty, no associated products.")

service = self._caom_products
params = {'obsid': ','.join(observations)}

return self._portal_api_connection.service_request_async(service, params)
observations = observations['obsid'].tolist()

# Clean and validate
observations = [str(obs).strip() for obs in observations]
observations = [obs for obs in observations if obs]
if not observations:
raise InvalidQueryError('Observation list is empty, no associated products.')

# Define a helper to join obsids for each batch request
def _request_joined_obsid(params):
"""Join batched obsid list into comma-separated string and send async request."""
pp = dict(params)
vals = pp.get('obsid', [])
pp['obsid'] = ','.join(map(str, vals))
return self._portal_api_connection.service_request_async(self._caom_products, pp)[0]

# Perform batched requests
results = utils._batched_request(
items=observations,
params={},
max_batch=batch_size,
param_key='obsid',
request_func=_request_joined_obsid,
extract_func=lambda r: [r],
desc=f'Fetching products for {len(observations)} unique observations'
)

return results

def filter_products(self, products, *, mrp_only=False, extension=None, **filters):
"""
Expand Down
17 changes: 13 additions & 4 deletions astroquery/mast/tests/test_mast.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,21 +304,21 @@ def test_missions_query_criteria(patch_post):
def test_missions_get_product_list_async(patch_post):
# String input
result = mast.MastMissions.get_product_list_async('Z14Z0104T')
assert isinstance(result, MockResponse)
assert isinstance(result, list)

# List input
in_datasets = ['Z14Z0104T', 'Z14Z0102T']
result = mast.MastMissions.get_product_list_async(in_datasets)
assert isinstance(result, MockResponse)
assert isinstance(result, list)

# Table input (a slice of the query results is still a Table)
datasets = mast.MastMissions.query_object("M101", radius=".002 deg")
result = mast.MastMissions.get_product_list_async(datasets[:3])
assert isinstance(result, MockResponse)
assert isinstance(result, list)

# Row input (integer indexing of the query results yields a single Row)
result = mast.MastMissions.get_product_list_async(datasets[0])
assert isinstance(result, MockResponse)
assert isinstance(result, list)

# Unsupported data type for datasets
with pytest.raises(TypeError) as err_type:
Expand All @@ -330,6 +330,11 @@ def test_missions_get_product_list_async(patch_post):
mast.MastMissions.get_product_list_async([' '])
assert 'Dataset list is empty' in str(err_empty.value)

# No dataset keyword
with pytest.raises(InvalidQueryError, match='Dataset keyword not found for mission "invalid"'):
missions = mast.MastMissions(mission='invalid')
missions.get_product_list_async(Table({'a': [1, 2, 3]}))


def test_missions_get_product_list(patch_post):
# String input
Expand Down Expand Up @@ -798,6 +803,10 @@ def test_observations_get_product_list(patch_post):
result = mast.Observations.get_product_list(in_obsids)
assert isinstance(result, Table)

# Error if no valid obsids are found
with pytest.raises(InvalidQueryError, match='Observation list is empty'):
mast.Observations.get_product_list([' '])


def test_observations_filter_products(patch_post):
products = mast.Observations.get_product_list('2003738726')
Expand Down
41 changes: 29 additions & 12 deletions astroquery/mast/tests/test_mast_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,19 +199,25 @@ def test_missions_get_product_list_async(self):

# Table as input
responses = MastMissions.get_product_list_async(datasets[:3])
assert isinstance(responses, Response)
assert isinstance(responses, list)

# Row as input
responses = MastMissions.get_product_list_async(datasets[0])
assert isinstance(responses, Response)
assert isinstance(responses, list)

# String as input
responses = MastMissions.get_product_list_async(datasets[0]['sci_data_set_name'])
assert isinstance(responses, Response)
assert isinstance(responses, list)

# Column as input
responses = MastMissions.get_product_list_async(datasets[:3]['sci_data_set_name'])
assert isinstance(responses, Response)
assert isinstance(responses, list)

# Batching
responses = MastMissions.get_product_list_async(datasets[:4], batch_size=2)
assert isinstance(responses, list)
assert len(responses) == 2
assert isinstance(responses[0], Response)

# Unsupported data type for datasets
with pytest.raises(TypeError) as err_type:
Expand Down Expand Up @@ -248,14 +254,13 @@ def test_missions_get_product_list(self, capsys):
assert isinstance(result, Table)
assert (result['dataset'] == 'IBKH03020').all()

# Test batching by creating a list of 1001 different strings
# This won't return any results, but will test the batching
dataset_list = [f'{i}' for i in range(1001)]
result = MastMissions.get_product_list(dataset_list)
# Test batching
result_batch = MastMissions.get_product_list(datasets[:2], batch_size=1)
out, _ = capsys.readouterr()
assert isinstance(result, Table)
assert len(result) == 0
assert 'Fetching products for 1001 unique datasets in 2 batches' in out
assert isinstance(result_batch, Table)
assert len(result_batch) == len(result_table)
assert set(result_batch['filename']) == set(result_table['filename'])
assert 'Fetching products for 2 unique datasets in 2 batches' in out

def test_missions_get_unique_product_list(self, caplog):
# Check that no rows are filtered out when all products are unique
Expand Down Expand Up @@ -580,7 +585,11 @@ def test_observations_get_product_list_async(self):
responses = Observations.get_product_list_async(observations[0:4])
assert isinstance(responses, list)

def test_observations_get_product_list(self):
# Batching
responses = Observations.get_product_list_async(observations[0:4], batch_size=2)
assert isinstance(responses, list)

def test_observations_get_product_list(self, capsys):
observations = Observations.query_criteria(objectname='M8', obs_collection=['K2', 'IUE'])
test_obs_id = str(observations[0]['obsid'])
mult_obs_ids = str(observations[0]['obsid']) + ',' + str(observations[1]['obsid'])
Expand Down Expand Up @@ -613,6 +622,14 @@ def test_observations_get_product_list(self):
assert len(obs_collection) == 1
assert obs_collection[0] == 'IUE'

# Test batching
result_batch = Observations.get_product_list(observations[:2], batch_size=1)
out, _ = capsys.readouterr()
assert isinstance(result_batch, Table)
assert len(result_batch) == len(result1)
assert set(result_batch['productFilename']) == set(filenames1)
assert 'Fetching products for 2 unique observations in 2 batches' in out

def test_observations_get_product_list_tess_tica(self, caplog):
# Get observations and products with both TESS and TICA FFIs
obs = Observations.query_criteria(target_name=['TESS FFI', 'TICA FFI', '429031146'])
Expand Down
7 changes: 5 additions & 2 deletions astroquery/mast/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def _batched_request(
return extract_func(resp)


def resolve_object(objectname, *, resolver=None, resolve_all=False):
def resolve_object(objectname, *, resolver=None, resolve_all=False, batch_size=30):
"""
Resolves one or more object names to a position on the sky.

Expand All @@ -164,6 +164,9 @@ def resolve_object(objectname, *, resolver=None, resolve_all=False):
resolve_all : bool, optional
If True, will try to resolve the object name using all available resolvers ("NED", "SIMBAD").
Default is False.
batch_size : int, optional
Default 30. Number of object names to include in each batch request to the server.
If you experience timeouts or connection errors, consider lowering this value.

Returns
-------
Expand Down Expand Up @@ -230,7 +233,7 @@ def resolve_object(objectname, *, resolver=None, resolve_all=False):
results = _batched_request(
object_names,
params,
max_batch=30,
max_batch=batch_size,
param_key="name",
request_func=lambda p: _simple_request("http://mastresolver.stsci.edu/Santa-war/query", p),
extract_func=lambda r: r.json().get("resolvedCoordinate") or [],
Expand Down
Loading