Skip to content

Commit 9b13f19

Browse files
committed
refactor logic into utils, raise error on failed integer parsing, minor style fixes
1 parent 74cee5e commit 9b13f19

File tree

4 files changed

+69
-79
lines changed

4 files changed

+69
-79
lines changed

astroquery/mast/missions.py

Lines changed: 4 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from astroquery.utils import commons, async_to_sync
2323
from astroquery.utils.class_or_instance import class_or_instance
2424
from astropy.utils.console import ProgressBarOrSpinner
25-
from astroquery.exceptions import InvalidQueryError, MaxResultsWarning, InputWarning, NoResultsWarning
25+
from astroquery.exceptions import InvalidQueryError, MaxResultsWarning, NoResultsWarning
2626

2727
from astroquery.mast import utils
2828
from astroquery.mast.core import MastQueryWithLogin
@@ -505,26 +505,9 @@ def filter_products(self, products, *, extension=None, **filters):
505505
)
506506
filter_mask &= ext_mask
507507

508-
# Applying column-based filters
509-
for colname, vals in filters.items():
510-
if colname not in products.colnames:
511-
warnings.warn(f"Column '{colname}' not found in product table.", InputWarning)
512-
continue
513-
514-
col_data = products[colname]
515-
# If the column is an integer or float, accept numeric filters
516-
if col_data.dtype.kind in 'if':
517-
try:
518-
col_mask = utils.parse_numeric_product_filter(vals)(col_data)
519-
except ValueError:
520-
warnings.warn(f"Could not parse numeric filter '{vals}' for column '{colname}'.", InputWarning)
521-
continue
522-
else: # Assume string or list filter
523-
if isinstance(vals, str):
524-
vals = [vals]
525-
col_mask = np.isin(col_data, vals)
526-
527-
filter_mask &= col_mask
508+
# Apply column-based filters
509+
col_mask = utils.apply_column_filters(products, filters)
510+
filter_mask &= col_mask
528511

529512
return products[filter_mask]
530513

astroquery/mast/observations.py

Lines changed: 3 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -595,26 +595,9 @@ def filter_products(self, products, *, mrp_only=False, extension=None, **filters
595595
)
596596
filter_mask &= ext_mask
597597

598-
# Applying column-based filters
599-
for colname, vals in filters.items():
600-
if colname not in products.colnames:
601-
warnings.warn(f"Column '{colname}' not found in product table.", InputWarning)
602-
continue
603-
604-
col_data = products[colname]
605-
# If the column is an integer or float, accept numeric filters
606-
if col_data.dtype.kind in 'if':
607-
try:
608-
col_mask = utils.parse_numeric_product_filter(vals)(col_data)
609-
except ValueError:
610-
warnings.warn(f"Could not parse numeric filter '{vals}' for column '{colname}'.", InputWarning)
611-
continue
612-
else: # Assume string or list filter
613-
if isinstance(vals, str):
614-
vals = [vals]
615-
col_mask = np.isin(col_data, vals)
616-
617-
filter_mask &= col_mask
598+
# Apply column-based filters
599+
col_mask = utils.apply_column_filters(products, filters)
600+
filter_mask &= col_mask
618601

619602
return products[filter_mask]
620603

astroquery/mast/tests/test_mast.py

Lines changed: 20 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -366,63 +366,51 @@ def test_missions_get_unique_product_list(patch_post, caplog):
366366
def test_missions_filter_products(patch_post):
367367
# Filter products list by column
368368
products = mast.MastMissions.get_product_list('Z14Z0104T')
369-
filtered = mast.MastMissions.filter_products(products,
370-
category='CALIBRATED')
369+
filtered = mast.MastMissions.filter_products(products, category='CALIBRATED')
371370
assert isinstance(filtered, Table)
372371
assert all(filtered['category'] == 'CALIBRATED')
373372

374373
# Filter by extension
375-
filtered = mast.MastMissions.filter_products(products,
376-
extension='fits')
374+
filtered = mast.MastMissions.filter_products(products, extension='fits')
377375
assert len(filtered) > 0
378376

379377
# Filter by non-existing column
380378
with pytest.warns(InputWarning):
381-
mast.MastMissions.filter_products(products,
382-
invalid=True)
379+
mast.MastMissions.filter_products(products, invalid=True)
383380

384381
# Numeric filtering
385382
# Single integer value
386-
filtered = mast.MastMissions.filter_products(products,
387-
size=11520)
383+
filtered = mast.MastMissions.filter_products(products, size=11520)
388384
assert all(filtered['size'] == 11520)
389385

390386
# Single string value
391-
filtered = mast.MastMissions.filter_products(products,
392-
size='11520')
387+
filtered = mast.MastMissions.filter_products(products, size='11520')
393388
assert all(filtered['size'] == 11520)
394389

395390
# Comparison operators
396-
filtered = mast.MastMissions.filter_products(products,
397-
size='<15000')
391+
filtered = mast.MastMissions.filter_products(products, size='<15000')
398392
assert all(filtered['size'] < 15000)
399393

400-
filtered = mast.MastMissions.filter_products(products,
401-
size='>15000')
394+
filtered = mast.MastMissions.filter_products(products, size='>15000')
402395
assert all(filtered['size'] > 15000)
403396

404-
filtered = mast.MastMissions.filter_products(products,
405-
size='>=14400')
397+
filtered = mast.MastMissions.filter_products(products, size='>=14400')
406398
assert all(filtered['size'] >= 14400)
407399

408-
filtered = mast.MastMissions.filter_products(products,
409-
size='<=14400')
400+
filtered = mast.MastMissions.filter_products(products, size='<=14400')
410401
assert all(filtered['size'] <= 14400)
411402

412403
# Range operator
413-
filtered = mast.MastMissions.filter_products(products,
414-
size='14400..17280')
404+
filtered = mast.MastMissions.filter_products(products, size='14400..17280')
415405
assert all((filtered['size'] >= 14400) & (filtered['size'] <= 17280))
416406

417407
# List of expressions
418-
filtered = mast.MastMissions.filter_products(products,
419-
size=[14400, '>20000'])
408+
filtered = mast.MastMissions.filter_products(products, size=[14400, '>20000'])
420409
assert all((filtered['size'] == 14400) | (filtered['size'] > 20000))
421410

422-
with pytest.warns(InputWarning, match="Could not parse numeric filter 'invalid' for column 'size'"):
411+
with pytest.raises(InvalidQueryError, match="Could not parse numeric filter 'invalid' for column 'size'"):
423412
# Invalid filter value
424-
mast.MastMissions.filter_products(products,
425-
size='invalid')
413+
mast.MastMissions.filter_products(products, size='invalid')
426414

427415

428416
def test_missions_download_products(patch_post, tmp_path):
@@ -725,29 +713,24 @@ def test_observations_filter_products(patch_post):
725713
assert len(filtered) == 7
726714

727715
# Filter for minimum recommended products
728-
filtered = mast.Observations.filter_products(products,
729-
mrp_only=True)
716+
filtered = mast.Observations.filter_products(products, mrp_only=True)
730717
assert all(filtered['productGroupDescription'] == 'Minimum Recommended Products')
731718

732719
# Filter by extension
733-
filtered = mast.Observations.filter_products(products,
734-
extension='fits')
720+
filtered = mast.Observations.filter_products(products, extension='fits')
735721
assert len(filtered) > 0
736722

737723
# Filter by non-existing column
738724
with pytest.warns(InputWarning):
739-
mast.Observations.filter_products(products,
740-
invalid=True)
725+
mast.Observations.filter_products(products, invalid=True)
741726

742727
# Numeric filtering
743-
filtered = mast.Observations.filter_products(products,
744-
size='<50000')
728+
filtered = mast.Observations.filter_products(products, size='<50000')
745729
assert all(filtered['size'] < 50000)
746730

747731
# Numeric filter that cannot be parsed
748-
with pytest.warns(InputWarning, match="Could not parse numeric filter 'invalid' for column 'size'"):
749-
filtered = mast.Observations.filter_products(products,
750-
size='invalid')
732+
with pytest.raises(InvalidQueryError, match="Could not parse numeric filter 'invalid' for column 'size'"):
733+
filtered = mast.Observations.filter_products(products, size='invalid')
751734

752735

753736
def test_observations_download_products(patch_post, tmpdir):
@@ -775,8 +758,7 @@ def test_observations_download_products(patch_post, tmpdir):
775758

776759
# passing row product
777760
products = mast.Observations.get_product_list('2003738726')
778-
result1 = mast.Observations.download_products(products[0],
779-
download_dir=str(tmpdir))
761+
result1 = mast.Observations.download_products(products[0], download_dir=str(tmpdir))
780762
assert isinstance(result1, Table)
781763

782764

astroquery/mast/utils.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -393,3 +393,45 @@ def single_condition(cond):
393393
return lambda col: np.logical_or.reduce([cond(col) for cond in conditions])
394394
else:
395395
return single_condition(val)
396+
397+
398+
def apply_column_filters(products, filters):
399+
"""
400+
Applies column-based filters to a product table.
401+
402+
Parameters
403+
----------
404+
products : `~astropy.table.Table`
405+
The product table to filter.
406+
filters : dict
407+
A dictionary where keys are column names and values are the filter values.
408+
409+
Returns
410+
-------
411+
col_mask : `numpy.ndarray`
412+
A boolean mask indicating which rows of the product table satisfy the filter conditions.
413+
"""
414+
col_mask = np.ones(len(products), dtype=bool) # Start with all True mask
415+
416+
# Applying column-based filters
417+
for colname, vals in filters.items():
418+
if colname not in products.colnames:
419+
warnings.warn(f"Column '{colname}' not found in product table.", InputWarning)
420+
continue
421+
422+
col_data = products[colname]
423+
# If the column is an integer or float, accept numeric filters
424+
if col_data.dtype.kind in ['i', 'f']: # 'i' for integer, 'f' for float
425+
try:
426+
this_mask = parse_numeric_product_filter(vals)(col_data)
427+
except ValueError:
428+
raise InvalidQueryError(f"Could not parse numeric filter '{vals}' for column '{colname}'.")
429+
else: # Assume string or list filter
430+
if isinstance(vals, str):
431+
vals = [vals]
432+
this_mask = np.isin(col_data, vals)
433+
434+
# Combine the current column mask with the overall mask
435+
col_mask &= this_mask
436+
437+
return col_mask

0 commit comments

Comments
 (0)