Skip to content

Commit 6fccfe9

Browse files
authored
Merge pull request #3365 from snbianco/ASB-31576-filter-integer
2 parents 3a48e0f + dd38fbd commit 6fccfe9

File tree

9 files changed

+256
-99
lines changed

9 files changed

+256
-99
lines changed

CHANGES.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,9 @@ mast
118118

119119
- Improved ``MastMissions`` queries to accept lists for query critieria values, in addition to comma-delimited strings. [#3319]
120120

121+
- Enhanced ``filter_products`` methods in ``MastMissions`` and ``Observations`` to support advanced filtering expressions
122+
for numeric columns. [#3365]
123+
121124

122125
Infrastructure, Utility and Other Changes and Additions
123126
-------------------------------------------------------

astroquery/mast/missions.py

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from astroquery.utils import commons, async_to_sync
2323
from astroquery.utils.class_or_instance import class_or_instance
2424
from astropy.utils.console import ProgressBarOrSpinner
25-
from astroquery.exceptions import InvalidQueryError, MaxResultsWarning, InputWarning, NoResultsWarning
25+
from astroquery.exceptions import InvalidQueryError, MaxResultsWarning, NoResultsWarning
2626

2727
from astroquery.mast import utils
2828
from astroquery.mast.core import MastQueryWithLogin
@@ -472,11 +472,19 @@ def filter_products(self, products, *, extension=None, **filters):
472472
extension : string or array, optional
473473
Default is None. Filters by file extension(s), matching any specified extensions.
474474
**filters :
475-
Column-based filters to be applied.
475+
Column-based filters to apply to the products table.
476+
476477
Each keyword corresponds to a column name in the table, with the argument being one or more
477478
acceptable values for that column. AND logic is applied between filters, OR logic within
478-
each filter set.
479-
For example: type="science", extension=["fits","jpg"]
479+
each filter set. For example: type="science", extension=["fits", "jpg"]
480+
481+
For columns with numeric data types (int or float), filter values can be expressed
482+
in several ways:
483+
484+
- A single number: ``size=100``
485+
- A range in the form "start..end": ``size="100..1000"``
486+
- A comparison operator followed by a number: ``size=">=1000"``
487+
- A list of expressions (OR logic): ``size=[100, "500..1000", ">=1500"]``
480488
481489
Returns
482490
-------
@@ -497,17 +505,10 @@ def filter_products(self, products, *, extension=None, **filters):
497505
)
498506
filter_mask &= ext_mask
499507

500-
# Applying column-based filters
501-
for colname, vals in filters.items():
502-
if colname not in products.colnames:
503-
warnings.warn(f"Column '{colname}' not found in product table.", InputWarning)
504-
continue
505-
506-
vals = [vals] if isinstance(vals, str) else vals
507-
col_mask = np.isin(products[colname], vals)
508-
filter_mask &= col_mask
508+
# Apply column-based filters
509+
col_mask = utils.apply_column_filters(products, filters)
510+
filter_mask &= col_mask
509511

510-
# Return filtered products
511512
return products[filter_mask]
512513

513514
def download_file(self, uri, *, local_path=None, cache=True, verbose=True):

astroquery/mast/observations.py

Lines changed: 32 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -545,7 +545,7 @@ def get_product_list_async(self, observations):
545545

546546
def filter_products(self, products, *, mrp_only=False, extension=None, **filters):
547547
"""
548-
Takes an `~astropy.table.Table` of MAST observation data products and filters it based on given filters.
548+
Filters an `~astropy.table.Table` of data products based on given filters.
549549
550550
Parameters
551551
----------
@@ -556,47 +556,50 @@ def filter_products(self, products, *, mrp_only=False, extension=None, **filters
556556
extension : string or array, optional
557557
Default None. Option to filter by file extension.
558558
**filters :
559-
Filters to be applied. Valid filters are all products fields listed
559+
Column-based filters to apply to the products table. Valid filters are all products fields listed
560560
`here <https://masttest.stsci.edu/api/v0/_productsfields.html>`__.
561-
The column name is the keyword, with the argument being one or more acceptable values
562-
for that parameter.
563-
Filter behavior is AND between the filters and OR within a filter set.
564-
For example: productType="SCIENCE",extension=["fits","jpg"]
561+
562+
Each keyword corresponds to a column name in the table, with the argument being one or more
563+
acceptable values for that column. AND logic is applied between filters, OR logic within
564+
each filter set.
565+
566+
For example: type="science", extension=["fits", "jpg"]
567+
568+
For columns with numeric data types (int or float), filter values can be expressed
569+
in several ways:
570+
571+
- A single number: ``size=100``
572+
- A range in the form "start..end": ``size="100..1000"``
573+
- A comparison operator followed by a number: ``size=">=1000"``
574+
- A list of expressions (OR logic): ``size=[100, "500..1000", ">=1500"]``
565575
566576
Returns
567577
-------
568578
response : `~astropy.table.Table`
579+
Filtered table of data products.
569580
"""
570581

571582
filter_mask = np.full(len(products), True, dtype=bool)
572583

573-
# Applying the special filters (mrp_only and extension)
584+
# Filter by minimum recommended products (MRP) if specified
574585
if mrp_only:
575586
filter_mask &= (products['productGroupDescription'] == "Minimum Recommended Products")
576587

588+
# Filter by file extension, if provided
577589
if extension:
578-
if isinstance(extension, str):
579-
extension = [extension]
580-
581-
mask = np.full(len(products), False, dtype=bool)
582-
for elt in extension:
583-
mask |= [False if isinstance(x, np.ma.core.MaskedConstant) else x.endswith(elt)
584-
for x in products["productFilename"]]
585-
filter_mask &= mask
586-
587-
# Applying the rest of the filters
588-
for colname, vals in filters.items():
589-
590-
if isinstance(vals, str):
591-
vals = [vals]
592-
593-
mask = np.full(len(products), False, dtype=bool)
594-
for elt in vals:
595-
mask |= (products[colname] == elt)
596-
597-
filter_mask &= mask
598-
599-
return products[np.where(filter_mask)]
590+
extensions = [extension] if isinstance(extension, str) else extension
591+
ext_mask = np.array(
592+
[not isinstance(x, np.ma.core.MaskedConstant) and any(x.endswith(ext) for ext in extensions)
593+
for x in products["productFilename"]],
594+
dtype=bool
595+
)
596+
filter_mask &= ext_mask
597+
598+
# Apply column-based filters
599+
col_mask = utils.apply_column_filters(products, filters)
600+
filter_mask &= col_mask
601+
602+
return products[filter_mask]
600603

601604
def download_file(self, uri, *, local_path=None, base_url=None, cache=True, cloud_only=False, verbose=True):
602605
"""

astroquery/mast/tests/test_mast.py

Lines changed: 68 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -367,15 +367,51 @@ def test_missions_get_unique_product_list(patch_post, caplog):
367367
def test_missions_filter_products(patch_post):
368368
# Filter products list by column
369369
products = mast.MastMissions.get_product_list('Z14Z0104T')
370-
filtered = mast.MastMissions.filter_products(products,
371-
category='CALIBRATED')
370+
filtered = mast.MastMissions.filter_products(products, category='CALIBRATED')
372371
assert isinstance(filtered, Table)
373372
assert all(filtered['category'] == 'CALIBRATED')
374373

375-
# Filter by non-existing column
376-
with pytest.warns(InputWarning):
377-
mast.MastMissions.filter_products(products,
378-
invalid=True)
374+
# Filter by extension
375+
filtered = mast.MastMissions.filter_products(products, extension='fits')
376+
assert len(filtered) > 0
377+
378+
# Numeric filtering
379+
# Single integer value
380+
filtered = mast.MastMissions.filter_products(products, size=11520)
381+
assert all(filtered['size'] == 11520)
382+
383+
# Single string value
384+
filtered = mast.MastMissions.filter_products(products, size='11520')
385+
assert all(filtered['size'] == 11520)
386+
387+
# Comparison operators
388+
filtered = mast.MastMissions.filter_products(products, size='<15000')
389+
assert all(filtered['size'] < 15000)
390+
391+
filtered = mast.MastMissions.filter_products(products, size='>15000')
392+
assert all(filtered['size'] > 15000)
393+
394+
filtered = mast.MastMissions.filter_products(products, size='>=14400')
395+
assert all(filtered['size'] >= 14400)
396+
397+
filtered = mast.MastMissions.filter_products(products, size='<=14400')
398+
assert all(filtered['size'] <= 14400)
399+
400+
# Range operator
401+
filtered = mast.MastMissions.filter_products(products, size='14400..17280')
402+
assert all((filtered['size'] >= 14400) & (filtered['size'] <= 17280))
403+
404+
# List of expressions
405+
filtered = mast.MastMissions.filter_products(products, size=[14400, '>20000'])
406+
assert all((filtered['size'] == 14400) | (filtered['size'] > 20000))
407+
408+
with pytest.raises(InvalidQueryError, match="Could not parse numeric filter 'invalid' for column 'size'"):
409+
# Invalid filter value
410+
mast.MastMissions.filter_products(products, size='invalid')
411+
412+
# Error when filtering by non-existing column
413+
with pytest.raises(InvalidQueryError, match="Column 'non_existing' not found in product table."):
414+
mast.MastMissions.filter_products(products, non_existing='value')
379415

380416

381417
def test_missions_download_products(patch_post, tmp_path):
@@ -671,11 +707,31 @@ def test_observations_get_product_list(patch_post):
671707

672708
def test_observations_filter_products(patch_post):
673709
products = mast.Observations.get_product_list('2003738726')
674-
result = mast.Observations.filter_products(products,
675-
productType=["SCIENCE"],
676-
mrp_only=False)
677-
assert isinstance(result, Table)
678-
assert len(result) == 7
710+
filtered = mast.Observations.filter_products(products,
711+
productType=["SCIENCE"],
712+
mrp_only=False)
713+
assert isinstance(filtered, Table)
714+
assert len(filtered) == 7
715+
716+
# Filter for minimum recommended products
717+
filtered = mast.Observations.filter_products(products, mrp_only=True)
718+
assert all(filtered['productGroupDescription'] == 'Minimum Recommended Products')
719+
720+
# Filter by extension
721+
filtered = mast.Observations.filter_products(products, extension='fits')
722+
assert len(filtered) > 0
723+
724+
# Numeric filtering
725+
filtered = mast.Observations.filter_products(products, size='<50000')
726+
assert all(filtered['size'] < 50000)
727+
728+
# Numeric filter that cannot be parsed
729+
with pytest.raises(InvalidQueryError, match="Could not parse numeric filter 'invalid' for column 'size'"):
730+
filtered = mast.Observations.filter_products(products, size='invalid')
731+
732+
# Filter by non-existing column
733+
with pytest.raises(InvalidQueryError, match="Column 'invalid' not found in product table."):
734+
mast.Observations.filter_products(products, invalid=True)
679735

680736

681737
def test_observations_download_products(patch_post, tmpdir):
@@ -703,8 +759,7 @@ def test_observations_download_products(patch_post, tmpdir):
703759

704760
# passing row product
705761
products = mast.Observations.get_product_list('2003738726')
706-
result1 = mast.Observations.download_products(products[0],
707-
download_dir=str(tmpdir))
762+
result1 = mast.Observations.download_products(products[0], download_dir=str(tmpdir))
708763
assert isinstance(result1, Table)
709764

710765

astroquery/mast/tests/test_mast_remote.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -297,11 +297,6 @@ def test_missions_filter_products(self):
297297
assert isinstance(filtered, Table)
298298
assert all(filtered['category'] == 'CALIBRATED')
299299

300-
# Filter by non-existing column
301-
with pytest.warns(InputWarning):
302-
filtered = MastMissions.filter_products(products,
303-
invalid=True)
304-
305300
def test_missions_download_products(self, tmp_path):
306301
def check_filepath(path):
307302
assert path.is_file()

astroquery/mast/utils.py

Lines changed: 91 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,12 @@
66
Miscellaneous functions used throughout the MAST module.
77
"""
88

9+
import re
910
import warnings
10-
import numpy as np
1111

12+
import numpy as np
1213
import requests
1314
import platform
14-
1515
from astropy.coordinates import SkyCoord
1616
from astropy.table import Table
1717
from astropy import units as u
@@ -345,3 +345,92 @@ def remove_duplicate_products(data_products, uri_key):
345345
f"Only returning {len(unique_products)} unique product(s).")
346346

347347
return unique_products
348+
349+
350+
def parse_numeric_product_filter(val):
351+
"""
352+
Parses a numeric product filter value and returns a function that can be used to filter
353+
a column of a product table.
354+
355+
Parameters
356+
----------
357+
val : str or list of str
358+
The filter value(s). Each entry can be:
359+
- A single number (e.g., "100")
360+
- A range in the form "start..end" (e.g., "100..200")
361+
- A comparison operator followed by a number (e.g., ">=10", "<5", ">100.5")
362+
363+
Returns
364+
-------
365+
response : function
366+
A function that takes a column of a product table and returns a boolean mask indicating
367+
which rows satisfy the filter condition.
368+
"""
369+
# Regular expression to match range patterns
370+
range_pattern = re.compile(r'[+-]?(\d+(\.\d*)?|\.\d+)\.\.[+-]?(\d+(\.\d*)?|\.\d+)')
371+
372+
def single_condition(cond):
373+
"""Helper function to create a condition function for a single value."""
374+
if isinstance(cond, (int, float)):
375+
return lambda col: col == float(cond)
376+
if cond.startswith('>='):
377+
return lambda col: col >= float(cond[2:])
378+
elif cond.startswith('<='):
379+
return lambda col: col <= float(cond[2:])
380+
elif cond.startswith('>'):
381+
return lambda col: col > float(cond[1:])
382+
elif cond.startswith('<'):
383+
return lambda col: col < float(cond[1:])
384+
elif range_pattern.fullmatch(cond):
385+
start, end = map(float, cond.split('..'))
386+
return lambda col: (col >= start) & (col <= end)
387+
else:
388+
return lambda col: col == float(cond)
389+
390+
if isinstance(val, list):
391+
# If val is a list, create a condition for each value and combine them with logical OR
392+
conditions = [single_condition(v) for v in val]
393+
return lambda col: np.logical_or.reduce([cond(col) for cond in conditions])
394+
else:
395+
return single_condition(val)
396+
397+
398+
def apply_column_filters(products, filters):
399+
"""
400+
Applies column-based filters to a product table.
401+
402+
Parameters
403+
----------
404+
products : `~astropy.table.Table`
405+
The product table to filter.
406+
filters : dict
407+
A dictionary where keys are column names and values are the filter values.
408+
409+
Returns
410+
-------
411+
col_mask : `numpy.ndarray`
412+
A boolean mask indicating which rows of the product table satisfy the filter conditions.
413+
"""
414+
col_mask = np.ones(len(products), dtype=bool) # Start with all True mask
415+
416+
# Applying column-based filters
417+
for colname, vals in filters.items():
418+
if colname not in products.colnames:
419+
raise InvalidQueryError(f"Column '{colname}' not found in product table.")
420+
421+
col_data = products[colname]
422+
# If the column is an integer or float, accept numeric filters
423+
if col_data.dtype.kind in ['i', 'f']: # 'i' for integer, 'f' for float
424+
try:
425+
this_mask = parse_numeric_product_filter(vals)(col_data)
426+
except ValueError:
427+
raise InvalidQueryError(f"Could not parse numeric filter '{vals}' for column '{colname}'.")
428+
else: # Assume string or list filter
429+
if isinstance(vals, str):
430+
vals = [vals]
431+
this_mask = np.isin(col_data, vals)
432+
433+
# Combine the current column mask with the overall mask
434+
col_mask &= this_mask
435+
436+
return col_mask

docs/mast/mast_cut.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,8 @@ To access sector information for a particular coordinate, object, or moving targ
140140
-------------- ------ ------ ---
141141
tess-s0008-1-1 8 1 1
142142
tess-s0034-1-2 34 1 2
143+
tess-s0061-1-2 61 1 2
144+
tess-s0088-1-2 88 1 2
143145

144146
The following example will request SPOC cutouts using the objectname argument, rather
145147
than a set of coordinates.
@@ -167,6 +169,7 @@ The following example requests SPOC cutouts for a moving target.
167169
tess-s0029-1-4 29 1 4
168170
tess-s0043-3-3 43 3 3
169171
tess-s0044-2-4 44 2 4
172+
tess-s0092-4-3 92 4 3
170173

171174

172175
Zcut

0 commit comments

Comments
 (0)