-
-
Notifications
You must be signed in to change notification settings - Fork 422
Support expressions when filtering products by numeric columns #3365
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 3 commits
5240817
5087451
74cee5e
9b13f19
dd38fbd
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -472,11 +472,19 @@ def filter_products(self, products, *, extension=None, **filters): | |
extension : string or array, optional | ||
Default is None. Filters by file extension(s), matching any specified extensions. | ||
**filters : | ||
Column-based filters to be applied. | ||
Column-based filters to apply to the products table. | ||
|
||
Each keyword corresponds to a column name in the table, with the argument being one or more | ||
acceptable values for that column. AND logic is applied between filters, OR logic within | ||
each filter set. | ||
For example: type="science", extension=["fits","jpg"] | ||
each filter set. For example: type="science", extension=["fits", "jpg"] | ||
|
||
For columns with numeric data types (int or float), filter values can be expressed | ||
in several ways: | ||
|
||
- A single number: ``size=100`` | ||
- A range in the form "start..end": ``size="100..1000"`` | ||
- A comparison operator followed by a number: ``size=">=1000"`` | ||
- A list of expressions (OR logic): ``size=[100, "500..1000", ">=1500"]`` | ||
|
||
Returns | ||
------- | ||
|
@@ -503,11 +511,21 @@ def filter_products(self, products, *, extension=None, **filters): | |
warnings.warn(f"Column '{colname}' not found in product table.", InputWarning) | ||
continue | ||
|
||
vals = [vals] if isinstance(vals, str) else vals | ||
col_mask = np.isin(products[colname], vals) | ||
col_data = products[colname] | ||
# If the column is an integer or float, accept numeric filters | ||
if col_data.dtype.kind in 'if': | ||
try: | ||
col_mask = utils.parse_numeric_product_filter(vals)(col_data) | ||
except ValueError: | ||
warnings.warn(f"Could not parse numeric filter '{vals}' for column '{colname}'.", InputWarning) | ||
continue | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why not allow the exception to be raised here? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Agreed that this should just raise an exception. |
||
else: # Assume string or list filter | ||
if isinstance(vals, str): | ||
vals = [vals] | ||
col_mask = np.isin(col_data, vals) | ||
|
||
filter_mask &= col_mask | ||
|
||
# Return filtered products | ||
return products[filter_mask] | ||
|
||
def download_file(self, uri, *, local_path=None, cache=True, verbose=True): | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -545,7 +545,7 @@ def get_product_list_async(self, observations): | |
|
||
def filter_products(self, products, *, mrp_only=False, extension=None, **filters): | ||
""" | ||
Takes an `~astropy.table.Table` of MAST observation data products and filters it based on given filters. | ||
Filters an `~astropy.table.Table` of data products based on given filters. | ||
|
||
Parameters | ||
---------- | ||
|
@@ -556,47 +556,67 @@ def filter_products(self, products, *, mrp_only=False, extension=None, **filters | |
extension : string or array, optional | ||
Default None. Option to filter by file extension. | ||
**filters : | ||
Filters to be applied. Valid filters are all products fields listed | ||
Column-based filters to apply to the products table. Valid filters are all products fields listed | ||
`here <https://masttest.stsci.edu/api/v0/_productsfields.html>`__. | ||
The column name is the keyword, with the argument being one or more acceptable values | ||
for that parameter. | ||
Filter behavior is AND between the filters and OR within a filter set. | ||
For example: productType="SCIENCE",extension=["fits","jpg"] | ||
|
||
Each keyword corresponds to a column name in the table, with the argument being one or more | ||
acceptable values for that column. AND logic is applied between filters, OR logic within | ||
each filter set. | ||
|
||
For example: type="science", extension=["fits", "jpg"] | ||
|
||
For columns with numeric data types (int or float), filter values can be expressed | ||
in several ways: | ||
|
||
- A single number: ``size=100`` | ||
- A range in the form "start..end": ``size="100..1000"`` | ||
- A comparison operator followed by a number: ``size=">=1000"`` | ||
- A list of expressions (OR logic): ``size=[100, "500..1000", ">=1500"]`` | ||
|
||
Returns | ||
------- | ||
response : `~astropy.table.Table` | ||
Filtered table of data products. | ||
""" | ||
|
||
filter_mask = np.full(len(products), True, dtype=bool) | ||
|
||
# Applying the special filters (mrp_only and extension) | ||
# Filter by minimum recommended products (MRP) if specified | ||
if mrp_only: | ||
filter_mask &= (products['productGroupDescription'] == "Minimum Recommended Products") | ||
|
||
# Filter by file extension, if provided | ||
if extension: | ||
if isinstance(extension, str): | ||
extension = [extension] | ||
|
||
mask = np.full(len(products), False, dtype=bool) | ||
for elt in extension: | ||
mask |= [False if isinstance(x, np.ma.core.MaskedConstant) else x.endswith(elt) | ||
for x in products["productFilename"]] | ||
filter_mask &= mask | ||
|
||
# Applying the rest of the filters | ||
extensions = [extension] if isinstance(extension, str) else extension | ||
ext_mask = np.array( | ||
[not isinstance(x, np.ma.core.MaskedConstant) and any(x.endswith(ext) for ext in extensions) | ||
for x in products["productFilename"]], | ||
dtype=bool | ||
) | ||
filter_mask &= ext_mask | ||
|
||
# Applying column-based filters | ||
for colname, vals in filters.items(): | ||
if colname not in products.colnames: | ||
warnings.warn(f"Column '{colname}' not found in product table.", InputWarning) | ||
continue | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. should this be an exception instead? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. My reasoning here was that the schema of the product table returned by our API may change in the future. Only issuing a warning would prevent a user's code from breaking unexpectedly, but I suppose the output of the function would not be what the user expects either. This is also the precedent set by the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeap, but if the API is changing then their code will have to change anyway, or the astroquery module should change and thus they have to update the astroquery version to keep the same user code. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I see where you're coming from! It probably would be better to fully alert the user if the column they have been filtering on no longer exists. The latest commit raises an error in both |
||
|
||
if isinstance(vals, str): | ||
vals = [vals] | ||
|
||
mask = np.full(len(products), False, dtype=bool) | ||
for elt in vals: | ||
mask |= (products[colname] == elt) | ||
|
||
filter_mask &= mask | ||
|
||
return products[np.where(filter_mask)] | ||
col_data = products[colname] | ||
# If the column is an integer or float, accept numeric filters | ||
if col_data.dtype.kind in 'if': | ||
try: | ||
col_mask = utils.parse_numeric_product_filter(vals)(col_data) | ||
except ValueError: | ||
warnings.warn(f"Could not parse numeric filter '{vals}' for column '{colname}'.", InputWarning) | ||
continue | ||
else: # Assume string or list filter | ||
if isinstance(vals, str): | ||
vals = [vals] | ||
col_mask = np.isin(col_data, vals) | ||
|
||
filter_mask &= col_mask | ||
|
||
return products[filter_mask] | ||
|
||
def download_file(self, uri, *, local_path=None, base_url=None, cache=True, cloud_only=False, verbose=True): | ||
""" | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -371,11 +371,59 @@ def test_missions_filter_products(patch_post): | |
assert isinstance(filtered, Table) | ||
assert all(filtered['category'] == 'CALIBRATED') | ||
|
||
# Filter by extension | ||
filtered = mast.MastMissions.filter_products(products, | ||
extension='fits') | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nitpick, but we allow linelength to run up to 120, there is really no need for the break here, and below |
||
assert len(filtered) > 0 | ||
|
||
# Filter by non-existing column | ||
with pytest.warns(InputWarning): | ||
mast.MastMissions.filter_products(products, | ||
invalid=True) | ||
|
||
# Numeric filtering | ||
# Single integer value | ||
filtered = mast.MastMissions.filter_products(products, | ||
size=11520) | ||
assert all(filtered['size'] == 11520) | ||
|
||
# Single string value | ||
filtered = mast.MastMissions.filter_products(products, | ||
size='11520') | ||
assert all(filtered['size'] == 11520) | ||
|
||
# Comparison operators | ||
filtered = mast.MastMissions.filter_products(products, | ||
size='<15000') | ||
assert all(filtered['size'] < 15000) | ||
|
||
filtered = mast.MastMissions.filter_products(products, | ||
size='>15000') | ||
assert all(filtered['size'] > 15000) | ||
|
||
filtered = mast.MastMissions.filter_products(products, | ||
size='>=14400') | ||
assert all(filtered['size'] >= 14400) | ||
|
||
filtered = mast.MastMissions.filter_products(products, | ||
size='<=14400') | ||
assert all(filtered['size'] <= 14400) | ||
|
||
# Range operator | ||
filtered = mast.MastMissions.filter_products(products, | ||
size='14400..17280') | ||
assert all((filtered['size'] >= 14400) & (filtered['size'] <= 17280)) | ||
|
||
# List of expressions | ||
filtered = mast.MastMissions.filter_products(products, | ||
size=[14400, '>20000']) | ||
assert all((filtered['size'] == 14400) | (filtered['size'] > 20000)) | ||
|
||
with pytest.warns(InputWarning, match="Could not parse numeric filter 'invalid' for column 'size'"): | ||
# Invalid filter value | ||
mast.MastMissions.filter_products(products, | ||
size='invalid') | ||
|
||
|
||
def test_missions_download_products(patch_post, tmp_path): | ||
# Check string input | ||
|
@@ -670,11 +718,36 @@ def test_observations_get_product_list(patch_post): | |
|
||
def test_observations_filter_products(patch_post): | ||
products = mast.Observations.get_product_list('2003738726') | ||
result = mast.Observations.filter_products(products, | ||
productType=["SCIENCE"], | ||
mrp_only=False) | ||
assert isinstance(result, Table) | ||
assert len(result) == 7 | ||
filtered = mast.Observations.filter_products(products, | ||
productType=["SCIENCE"], | ||
mrp_only=False) | ||
assert isinstance(filtered, Table) | ||
assert len(filtered) == 7 | ||
|
||
# Filter for minimum recommended products | ||
filtered = mast.Observations.filter_products(products, | ||
mrp_only=True) | ||
assert all(filtered['productGroupDescription'] == 'Minimum Recommended Products') | ||
|
||
# Filter by extension | ||
filtered = mast.Observations.filter_products(products, | ||
extension='fits') | ||
assert len(filtered) > 0 | ||
|
||
# Filter by non-existing column | ||
with pytest.warns(InputWarning): | ||
mast.Observations.filter_products(products, | ||
invalid=True) | ||
|
||
# Numeric filtering | ||
filtered = mast.Observations.filter_products(products, | ||
size='<50000') | ||
assert all(filtered['size'] < 50000) | ||
|
||
# Numeric filter that cannot be parsed | ||
with pytest.warns(InputWarning, match="Could not parse numeric filter 'invalid' for column 'size'"): | ||
filtered = mast.Observations.filter_products(products, | ||
size='invalid') | ||
|
||
|
||
def test_observations_download_products(patch_post, tmpdir): | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
in 'if'
? I'm not sure I get thatThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's checking whether the
kind
code of the column isi
(integer) orf
(float).There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you add that as a comment? Or even better, rephrase the line as