Support expressions when filtering products by numeric columns #3365

Open
wants to merge 5 commits into
base: main
3 changes: 3 additions & 0 deletions CHANGES.rst
@@ -105,6 +105,9 @@ mast

- Improved ``MastMissions`` queries to accept lists for query criteria values, in addition to comma-delimited strings. [#3319]

- Enhanced ``filter_products`` methods in ``MastMissions`` and ``Observations`` to support advanced filtering expressions
for numeric columns. [#3365]


Infrastructure, Utility and Other Changes and Additions
-------------------------------------------------------
29 changes: 15 additions & 14 deletions astroquery/mast/missions.py
@@ -22,7 +22,7 @@
from astroquery.utils import commons, async_to_sync
from astroquery.utils.class_or_instance import class_or_instance
from astropy.utils.console import ProgressBarOrSpinner
from astroquery.exceptions import InvalidQueryError, MaxResultsWarning, InputWarning, NoResultsWarning
from astroquery.exceptions import InvalidQueryError, MaxResultsWarning, NoResultsWarning

from astroquery.mast import utils
from astroquery.mast.core import MastQueryWithLogin
@@ -472,11 +472,19 @@ def filter_products(self, products, *, extension=None, **filters):
extension : string or array, optional
Default is None. Filters by file extension(s), matching any specified extensions.
**filters :
Column-based filters to be applied.
Column-based filters to apply to the products table.

Each keyword corresponds to a column name in the table, with the argument being one or more
acceptable values for that column. AND logic is applied between filters, OR logic within
each filter set.
For example: type="science", extension=["fits","jpg"]
each filter set. For example: type="science", extension=["fits", "jpg"]

For columns with numeric data types (int or float), filter values can be expressed
in several ways:

- A single number: ``size=100``
- A range in the form "start..end": ``size="100..1000"``
- A comparison operator followed by a number: ``size=">=1000"``
- A list of expressions (OR logic): ``size=[100, "500..1000", ">=1500"]``

Returns
-------
@@ -497,17 +505,10 @@ def filter_products(self, products, *, extension=None, **filters):
)
filter_mask &= ext_mask

# Applying column-based filters
for colname, vals in filters.items():
if colname not in products.colnames:
warnings.warn(f"Column '{colname}' not found in product table.", InputWarning)
continue

vals = [vals] if isinstance(vals, str) else vals
col_mask = np.isin(products[colname], vals)
filter_mask &= col_mask
# Apply column-based filters
col_mask = utils.apply_column_filters(products, filters)
filter_mask &= col_mask

# Return filtered products
return products[filter_mask]

def download_file(self, uri, *, local_path=None, cache=True, verbose=True):
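
The docstring above describes AND logic between filters and OR logic within each filter set. A minimal sketch of that rule in plain Python follows; the `rows` data, the `filters` mapping of predicates, and `row_passes` are hypothetical names for illustration, not part of astroquery, and the real method works on table columns with boolean masks rather than row by row.

```python
# Hypothetical rows standing in for a products table.
rows = [
    {"category": "CALIBRATED", "size": 11520},
    {"category": "RAW", "size": 14400},
    {"category": "CALIBRATED", "size": 20160},
    {"category": "AUX", "size": 17280},
]

# Each filter set is a list of predicates; any one of them may match (OR).
# The size entry mirrors the docstring form size=[14400, ">20000"].
filters = {
    "category": [lambda v: v == "CALIBRATED", lambda v: v == "AUX"],
    "size": [lambda v: v == 14400, lambda v: v > 20000],
}


def row_passes(row):
    # Every column's filter set must be satisfied (AND between filters).
    return all(
        any(pred(row[col]) for pred in preds)
        for col, preds in filters.items()
    )


kept = [row for row in rows if row_passes(row)]
print(kept)
```

Only the third row survives here: it is the one row that satisfies both the category set and the size set.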
61 changes: 32 additions & 29 deletions astroquery/mast/observations.py
@@ -545,7 +545,7 @@ def get_product_list_async(self, observations):

def filter_products(self, products, *, mrp_only=False, extension=None, **filters):
"""
Takes an `~astropy.table.Table` of MAST observation data products and filters it based on given filters.
Filters an `~astropy.table.Table` of data products based on given filters.

Parameters
----------
@@ -556,47 +556,50 @@ def filter_products(self, products, *, mrp_only=False, extension=None, **filters
extension : string or array, optional
Default None. Option to filter by file extension.
**filters :
Filters to be applied. Valid filters are all products fields listed
Column-based filters to apply to the products table. Valid filters are all products fields listed
`here <https://masttest.stsci.edu/api/v0/_productsfields.html>`__.
The column name is the keyword, with the argument being one or more acceptable values
for that parameter.
Filter behavior is AND between the filters and OR within a filter set.
For example: productType="SCIENCE",extension=["fits","jpg"]

Each keyword corresponds to a column name in the table, with the argument being one or more
acceptable values for that column. AND logic is applied between filters, OR logic within
each filter set.

For example: productType="SCIENCE", extension=["fits", "jpg"]

For columns with numeric data types (int or float), filter values can be expressed
in several ways:

- A single number: ``size=100``
- A range in the form "start..end": ``size="100..1000"``
- A comparison operator followed by a number: ``size=">=1000"``
- A list of expressions (OR logic): ``size=[100, "500..1000", ">=1500"]``

Returns
-------
response : `~astropy.table.Table`
Filtered table of data products.
"""

filter_mask = np.full(len(products), True, dtype=bool)

# Applying the special filters (mrp_only and extension)
# Filter by minimum recommended products (MRP) if specified
if mrp_only:
filter_mask &= (products['productGroupDescription'] == "Minimum Recommended Products")

# Filter by file extension, if provided
if extension:
if isinstance(extension, str):
extension = [extension]

mask = np.full(len(products), False, dtype=bool)
for elt in extension:
mask |= [False if isinstance(x, np.ma.core.MaskedConstant) else x.endswith(elt)
for x in products["productFilename"]]
filter_mask &= mask

# Applying the rest of the filters
for colname, vals in filters.items():

if isinstance(vals, str):
vals = [vals]

mask = np.full(len(products), False, dtype=bool)
for elt in vals:
mask |= (products[colname] == elt)

filter_mask &= mask

return products[np.where(filter_mask)]
extensions = [extension] if isinstance(extension, str) else extension
ext_mask = np.array(
[not isinstance(x, np.ma.core.MaskedConstant) and any(x.endswith(ext) for ext in extensions)
for x in products["productFilename"]],
dtype=bool
)
filter_mask &= ext_mask

# Apply column-based filters
col_mask = utils.apply_column_filters(products, filters)
filter_mask &= col_mask

return products[filter_mask]

def download_file(self, uri, *, local_path=None, base_url=None, cache=True, cloud_only=False, verbose=True):
"""
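
The rewritten extension filter in this file skips masked filenames and keeps a row when its filename ends with any requested extension. The same idea can be sketched in plain Python, under the assumption that `None` stands in for a masked table entry; `extension_mask` is a hypothetical helper, not an astroquery function.

```python
def extension_mask(filenames, extension):
    """Return one boolean per filename; None stands in for a masked entry (assumption)."""
    # Accept a single string or any iterable of strings.
    extensions = (extension,) if isinstance(extension, str) else tuple(extension)
    # str.endswith accepts a tuple of suffixes, so no inner loop is needed.
    return [name is not None and name.endswith(extensions) for name in filenames]


mask = extension_mask(["a.fits", None, "b.jpg", "c.txt"], ["fits", "jpg"])
print(mask)
```

Passing a tuple to `str.endswith` is an alternative to the `any(...)` generator used in the diff; both express "matches at least one extension".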
81 changes: 68 additions & 13 deletions astroquery/mast/tests/test_mast.py
@@ -366,15 +366,51 @@ def test_missions_get_unique_product_list(patch_post, caplog):
def test_missions_filter_products(patch_post):
# Filter products list by column
products = mast.MastMissions.get_product_list('Z14Z0104T')
filtered = mast.MastMissions.filter_products(products,
category='CALIBRATED')
filtered = mast.MastMissions.filter_products(products, category='CALIBRATED')
assert isinstance(filtered, Table)
assert all(filtered['category'] == 'CALIBRATED')

# Filter by non-existing column
with pytest.warns(InputWarning):
mast.MastMissions.filter_products(products,
invalid=True)
# Filter by extension
filtered = mast.MastMissions.filter_products(products, extension='fits')
assert len(filtered) > 0

# Numeric filtering
# Single integer value
filtered = mast.MastMissions.filter_products(products, size=11520)
assert all(filtered['size'] == 11520)

# Single string value
filtered = mast.MastMissions.filter_products(products, size='11520')
assert all(filtered['size'] == 11520)

# Comparison operators
filtered = mast.MastMissions.filter_products(products, size='<15000')
assert all(filtered['size'] < 15000)

filtered = mast.MastMissions.filter_products(products, size='>15000')
assert all(filtered['size'] > 15000)

filtered = mast.MastMissions.filter_products(products, size='>=14400')
assert all(filtered['size'] >= 14400)

filtered = mast.MastMissions.filter_products(products, size='<=14400')
assert all(filtered['size'] <= 14400)

# Range operator
filtered = mast.MastMissions.filter_products(products, size='14400..17280')
assert all((filtered['size'] >= 14400) & (filtered['size'] <= 17280))

# List of expressions
filtered = mast.MastMissions.filter_products(products, size=[14400, '>20000'])
assert all((filtered['size'] == 14400) | (filtered['size'] > 20000))

with pytest.raises(InvalidQueryError, match="Could not parse numeric filter 'invalid' for column 'size'"):
# Invalid filter value
mast.MastMissions.filter_products(products, size='invalid')

# Error when filtering by non-existing column
with pytest.raises(InvalidQueryError, match="Column 'non_existing' not found in product table."):
mast.MastMissions.filter_products(products, non_existing='value')
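
One detail about the `match=` arguments in these tests: `pytest.raises` checks the pattern with `re.search`, so the trailing `.` in the expected message acts as a regex wildcard. The sketch below uses only the standard `re` module to show the difference that `re.escape` makes; the `almost` string is a made-up near miss for illustration.

```python
import re

message = "Column 'non_existing' not found in product table."

# Used as a pattern unchanged, the final "." matches any single character.
loose = message
# re.escape turns the message into a pattern that matches it literally.
strict = re.escape(message)

# Hypothetical near miss: same text, different final character.
almost = "Column 'non_existing' not found in product table!"

loose_hits_near_miss = bool(re.search(loose, almost))
strict_hits_near_miss = bool(re.search(strict, almost))
strict_hits_message = bool(re.search(strict, message))
print(loose_hits_near_miss, strict_hits_near_miss, strict_hits_message)
```

The tests pass either way; escaping would only matter if the suite needs to pin the message text exactly.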


def test_missions_download_products(patch_post, tmp_path):
@@ -670,11 +706,31 @@ def test_observations_get_product_list(patch_post):

def test_observations_filter_products(patch_post):
products = mast.Observations.get_product_list('2003738726')
result = mast.Observations.filter_products(products,
productType=["SCIENCE"],
mrp_only=False)
assert isinstance(result, Table)
assert len(result) == 7
filtered = mast.Observations.filter_products(products,
productType=["SCIENCE"],
mrp_only=False)
assert isinstance(filtered, Table)
assert len(filtered) == 7

# Filter for minimum recommended products
filtered = mast.Observations.filter_products(products, mrp_only=True)
assert all(filtered['productGroupDescription'] == 'Minimum Recommended Products')

# Filter by extension
filtered = mast.Observations.filter_products(products, extension='fits')
assert len(filtered) > 0

# Numeric filtering
filtered = mast.Observations.filter_products(products, size='<50000')
assert all(filtered['size'] < 50000)

# Numeric filter that cannot be parsed
with pytest.raises(InvalidQueryError, match="Could not parse numeric filter 'invalid' for column 'size'"):
filtered = mast.Observations.filter_products(products, size='invalid')

# Filter by non-existing column
with pytest.raises(InvalidQueryError, match="Column 'invalid' not found in product table."):
mast.Observations.filter_products(products, invalid=True)


def test_observations_download_products(patch_post, tmpdir):
@@ -702,8 +758,7 @@ def test_observations_download_products(patch_post, tmpdir):

# passing row product
products = mast.Observations.get_product_list('2003738726')
result1 = mast.Observations.download_products(products[0],
download_dir=str(tmpdir))
result1 = mast.Observations.download_products(products[0], download_dir=str(tmpdir))
assert isinstance(result1, Table)


5 changes: 0 additions & 5 deletions astroquery/mast/tests/test_mast_remote.py
@@ -297,11 +297,6 @@ def test_missions_filter_products(self):
assert isinstance(filtered, Table)
assert all(filtered['category'] == 'CALIBRATED')

# Filter by non-existing column
with pytest.warns(InputWarning):
filtered = MastMissions.filter_products(products,
invalid=True)

def test_missions_download_products(self, tmp_path):
def check_filepath(path):
assert path.is_file()
93 changes: 91 additions & 2 deletions astroquery/mast/utils.py
@@ -6,12 +6,12 @@
Miscellaneous functions used throughout the MAST module.
"""

import re
import warnings
import numpy as np

import numpy as np
import requests
import platform

from astropy.coordinates import SkyCoord
from astropy.table import Table
from astropy import units as u
@@ -345,3 +345,92 @@ def remove_duplicate_products(data_products, uri_key):
f"Only returning {len(unique_products)} unique product(s).")

return unique_products


def parse_numeric_product_filter(val):
"""
Parses a numeric product filter value and returns a function that can be used to filter
a column of a product table.

Parameters
----------
val : str, int, float, or list of these
The filter value(s). Each entry can be:
- A single number (e.g., "100")
- A range in the form "start..end" (e.g., "100..200")
- A comparison operator followed by a number (e.g., ">=10", "<5", ">100.5")

Returns
-------
response : function
A function that takes a column of a product table and returns a boolean mask indicating
which rows satisfy the filter condition.
"""
# Regular expression to match range patterns
range_pattern = re.compile(r'[+-]?(\d+(\.\d*)?|\.\d+)\.\.[+-]?(\d+(\.\d*)?|\.\d+)')

def single_condition(cond):
"""Helper function to create a condition function for a single value."""
if isinstance(cond, (int, float)):
return lambda col: col == float(cond)
if cond.startswith('>='):
return lambda col: col >= float(cond[2:])
elif cond.startswith('<='):
return lambda col: col <= float(cond[2:])
elif cond.startswith('>'):
return lambda col: col > float(cond[1:])
elif cond.startswith('<'):
return lambda col: col < float(cond[1:])
elif range_pattern.fullmatch(cond):
start, end = map(float, cond.split('..'))
return lambda col: (col >= start) & (col <= end)
else:
return lambda col: col == float(cond)

if isinstance(val, list):
# If val is a list, create a condition for each value and combine them with logical OR
conditions = [single_condition(v) for v in val]
return lambda col: np.logical_or.reduce([cond(col) for cond in conditions])
else:
return single_condition(val)
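
The prefix checks in `single_condition` test `>=` and `<=` before `>` and `<`; with the reverse order, `">=5"` would be read as `>` followed by the unparsable text `"=5"`. A standalone sketch of the same parsing idea for scalar values is shown below. `parse_expression` and `OPERATORS` are hypothetical names, and unlike the diff this version converts bounds when the expression is parsed rather than when the predicate is called.

```python
import operator

# Two-character operators come first so ">=5" is not read as ">" plus "=5".
OPERATORS = [
    (">=", operator.ge),
    ("<=", operator.le),
    (">", operator.gt),
    ("<", operator.lt),
]


def parse_expression(expr):
    """Return a predicate for one numeric filter expression (hypothetical helper)."""
    if isinstance(expr, (int, float)):
        return lambda x: x == expr
    for symbol, compare in OPERATORS:
        if expr.startswith(symbol):
            bound = float(expr[len(symbol):])
            # Bind compare and bound now so each predicate keeps its own values.
            return lambda x, compare=compare, bound=bound: compare(x, bound)
    if ".." in expr:
        low, high = (float(part) for part in expr.split("..", 1))
        return lambda x: low <= x <= high
    value = float(expr)  # raises ValueError for text such as "invalid"
    return lambda x: x == value


at_least = parse_expression(">=1000")
in_range = parse_expression("100..1000")
print(at_least(1000), in_range(500), in_range(1001))
```

Because conversion happens up front here, a bad expression fails as soon as it is parsed; in the diff the `float(...)` calls sit inside the lambdas, so the `ValueError` surfaces when the mask is computed, which is still inside the `try` block of `apply_column_filters`.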


def apply_column_filters(products, filters):
"""
Applies column-based filters to a product table.

Parameters
----------
products : `~astropy.table.Table`
The product table to filter.
filters : dict
A dictionary where keys are column names and values are the filter values.

Returns
-------
col_mask : `numpy.ndarray`
A boolean mask indicating which rows of the product table satisfy the filter conditions.
"""
col_mask = np.ones(len(products), dtype=bool) # Start with all True mask

# Applying column-based filters
for colname, vals in filters.items():
if colname not in products.colnames:
raise InvalidQueryError(f"Column '{colname}' not found in product table.")

col_data = products[colname]
# If the column is an integer or float, accept numeric filters
if col_data.dtype.kind in ['i', 'f']: # 'i' for integer, 'f' for float
try:
this_mask = parse_numeric_product_filter(vals)(col_data)
except ValueError:
raise InvalidQueryError(f"Could not parse numeric filter '{vals}' for column '{colname}'.")
else: # Assume string or list filter
if isinstance(vals, str):
vals = [vals]
this_mask = np.isin(col_data, vals)

# Combine the current column mask with the overall mask
col_mask &= this_mask

return col_mask
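
`apply_column_filters` chooses the numeric path by testing `dtype.kind` against `'i'` and `'f'`. The short sketch below shows which NumPy kinds that test covers; `column_kind` is a hypothetical helper written for this note. Whether product tables ever contain unsigned or boolean columns is not something the diff shows, so the last two cases are only a possibility to be aware of.

```python
import numpy as np


def column_kind(values):
    """Classify a column the way the dtype check above does (hypothetical helper)."""
    kind = np.asarray(values).dtype.kind
    return "numeric" if kind in ("i", "f") else "other"


signed = column_kind([1, 2, 3])                             # kind 'i'
floating = column_kind([1.5, 2.0])                          # kind 'f'
text = column_kind(["a", "b"])                              # kind 'U'
unsigned = column_kind(np.array([1, 2], dtype=np.uint8))    # kind 'u'
boolean = column_kind([True, False])                        # kind 'b'
print(signed, floating, text, unsigned, boolean)
```

Under this check an unsigned integer column would take the exact-match branch, so an expression such as `">=1000"` would be compared as a literal value instead of being parsed.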
3 changes: 3 additions & 0 deletions docs/mast/mast_cut.rst
@@ -196,6 +196,8 @@ To access sector information for a particular coordinate, object, or moving targ
-------------- ------ ------ ---
tess-s0008-1-1 8 1 1
tess-s0034-1-2 34 1 2
tess-s0061-1-2 61 1 2
tess-s0088-1-2 88 1 2

Note that because of the delivery cadence of the
TICA high level science products, later sectors will be available sooner with TICA than with
@@ -242,6 +244,7 @@ The following example requests SPOC cutouts for a moving target.
tess-s0029-1-4 29 1 4
tess-s0043-3-3 43 3 3
tess-s0044-2-4 44 2 4
tess-s0092-4-3 92 4 3

Note that the moving targets functionality is not currently available for TICA,
so the query will always default to SPOC.