Skip to content

Commit 8a7e62f

Browse files
Jorge Fernandez HernandezJorge Fernandez Hernandez
authored andcommitted
GAIASWRQ-25 The methods cross_match and cross_match_basic accept radius as a Quantiy
1 parent 5f7adff commit 8a7e62f

File tree

2 files changed

+63
-14
lines changed

2 files changed

+63
-14
lines changed

astroquery/gaia/core.py

Lines changed: 32 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -790,7 +790,7 @@ def __getQuantityInput(self, value, msg):
790790
if value is None:
791791
raise ValueError(f"Missing required argument: {msg}")
792792
if not (isinstance(value, str) or isinstance(value, units.Quantity)):
793-
raise ValueError(f"{msg} must be either a string or astropy.coordinates")
793+
raise ValueError(f"{msg} must be either a string or astropy.coordinates: {type(value)}")
794794

795795
if isinstance(value, str):
796796
return Quantity(value)
@@ -883,8 +883,9 @@ def cross_match_basic(self, *, table_a_full_qualified_name, table_a_column_ra, t
883883
the ‘dec’ column in the table table_b_full_qualified_name
884884
results_name : str, optional, default None
885885
custom name defined by the user for the job that is going to be created
886-
radius : float (arc. seconds), optional, default 1.0
887-
radius (valid range: 0.1-10.0)
886+
radius : float (arc. seconds), str or astropy.coordinate, optional, default 1.0
887+
radius (valid range: 0.1-10.0). For an astropy.coordinate any angular unit is valid, but its value in arc
888+
sec must be contained within the valid range.
888889
background : bool, optional, default 'False'
889890
when the job is executed in asynchronous mode, this flag specifies
890891
whether the execution will wait until results are available
@@ -896,8 +897,12 @@ def cross_match_basic(self, *, table_a_full_qualified_name, table_a_column_ra, t
896897
A Job object
897898
"""
898899

899-
if radius < 0.1 or radius > 10.0:
900-
raise ValueError(f"Invalid radius value. Found {radius}, valid range is: 0.1 to 10.0")
900+
radius_quantity = self.__get_radius_as_quantity_arcsec(radius)
901+
902+
radius_arc_sec = radius_quantity.value
903+
904+
if radius_arc_sec < 0.1 or radius_arc_sec > 10.0:
905+
raise ValueError(f"Invalid radius value. Found {radius_quantity}, valid range is: 0.1 to 10.0")
901906

902907
schema_a = self.__get_schema_name(table_a_full_qualified_name)
903908
if not schema_a:
@@ -928,7 +933,7 @@ def cross_match_basic(self, *, table_a_full_qualified_name, table_a_column_ra, t
928933
f"b.{table_b_column_dec}) AS separation, b.* "
929934
f"FROM {table_a_full_qualified_name} AS a JOIN {table_b_full_qualified_name} AS b "
930935
f"ON DISTANCE(a.{table_a_column_ra}, a.{table_a_column_dec}, b.{table_b_column_ra}, b.{table_b_column_dec})"
931-
f" < {radius} / 3600.")
936+
f" < {radius_quantity.to(u.deg).value}")
932937

933938
return self.launch_job_async(query=query,
934939
name=results_name,
@@ -940,6 +945,16 @@ def cross_match_basic(self, *, table_a_full_qualified_name, table_a_column_ra, t
940945
upload_resource=None,
941946
upload_table_name=None)
942947

948+
def __get_radius_as_quantity_arcsec(self, radius):
949+
"""
950+
transform the input radius into an astropy.Quantity in arc seconds
951+
"""
952+
if not isinstance(radius, units.Quantity):
953+
radius_quantity = Quantity(value=radius, unit=u.arcsec)
954+
else:
955+
radius_quantity = radius.to(u.arcsec)
956+
return radius_quantity
957+
943958
def __update_ra_dec_columns(self, full_qualified_table_name, column_ra, column_dec, table_metadata, verbose):
944959
"""
945960
Update table metadata for the ‘ra’ and the ‘dec’ columns in the input table
@@ -1007,8 +1022,9 @@ def cross_match(self, *, full_qualified_table_name_a,
10071022
a full qualified table name (i.e. schema name and table name)
10081023
results_table_name : str, mandatory
10091024
a table name without schema. The schema is set to the user one
1010-
radius : float (arc. seconds), optional, default 1.0
1011-
radius (valid range: 0.1-10.0)
1025+
radius : float (arc. seconds), str or astropy.coordinate, optional, default 1.0
1026+
radius (valid range: 0.1-10.0). For an astropy.coordinate any angular unit is valid, but its value in arc
1027+
sec must be contained within the valid range.
10121028
background : bool, optional, default 'False'
10131029
when the job is executed in asynchronous mode, this flag specifies
10141030
whether the execution will wait until results are available
@@ -1019,8 +1035,13 @@ def cross_match(self, *, full_qualified_table_name_a,
10191035
-------
10201036
A Job object
10211037
"""
1022-
if radius < 0.1 or radius > 10.0:
1023-
raise ValueError(f"Invalid radius value. Found {radius}, valid range is: 0.1 to 10.0")
1038+
1039+
radius_quantity = self.__get_radius_as_quantity_arcsec(radius)
1040+
1041+
radius_arc_sec = radius_quantity.value
1042+
1043+
if radius_arc_sec < 0.1 or radius_arc_sec > 10.0:
1044+
raise ValueError(f"Invalid radius value. Found {radius_quantity}, valid range is: 0.1 to 10.0")
10241045

10251046
schema_a = self.__get_schema_name(full_qualified_table_name_a)
10261047

@@ -1033,7 +1054,7 @@ def cross_match(self, *, full_qualified_table_name_a,
10331054
if taputils.get_schema_name(results_table_name) is not None:
10341055
raise ValueError("Please, do not specify schema for 'results_table_name'")
10351056

1036-
query = f"SELECT crossmatch_positional('{schema_a}','{table_a}','{schema_b}','{table_b}',{radius}, " \
1057+
query = f"SELECT crossmatch_positional('{schema_a}','{table_a}','{schema_b}','{table_b}',{radius_arc_sec}, " \
10371058
f"'{results_table_name}') FROM dual;"
10381059

10391060
name = str(results_table_name)

astroquery/gaia/tests/test_gaiatap.py

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import pytest
2626
from astropy.coordinates.sky_coordinate import SkyCoord
2727
from astropy.table import Column, Table
28+
from astropy.units import Quantity
2829
from astropy.utils.data import get_pkg_data_filename
2930
from astropy.utils.exceptions import AstropyDeprecationWarning
3031
from requests import HTTPError
@@ -1361,7 +1362,7 @@ def test_cross_match_invalid_mandatory_kwarg(cross_match_kwargs, kwarg, invalid_
13611362
def test_cross_match_invalid_radius(cross_match_kwargs, radius):
13621363
with pytest.raises(
13631364
ValueError,
1364-
match=rf"^Invalid radius value. Found {radius}, valid range is: 0.1 to 10.0$",
1365+
match=rf"^Invalid radius value. Found {radius} arcsec, valid range is: 0.1 to 10.0$",
13651366
):
13661367
GAIA_QUERIER.cross_match(**cross_match_kwargs, radius=radius)
13671368

@@ -1450,6 +1451,20 @@ def update_user_table(self, table_name, list_of_changes, verbose):
14501451
assert job.get_phase() == "EXECUTING" if background else "COMPLETED"
14511452
assert job.failed is False
14521453

1454+
radius_quantity = Quantity(value=1.0, unit=u.arcsec)
1455+
job = mock_querier_async.cross_match_basic(table_a_full_qualified_name="user_hola.tableA", table_a_column_ra="ra",
1456+
table_a_column_dec="dec", radius=radius_quantity, background=background)
1457+
assert job.async_ is True
1458+
assert job.get_phase() == "EXECUTING" if background else "COMPLETED"
1459+
assert job.failed is False
1460+
1461+
radius_quantity = Quantity(value=1.0/3600.0, unit=u.deg)
1462+
job = mock_querier_async.cross_match_basic(table_a_full_qualified_name="user_hola.tableA", table_a_column_ra="ra",
1463+
table_a_column_dec="dec", radius=radius_quantity, background=background)
1464+
assert job.async_ is True
1465+
assert job.get_phase() == "EXECUTING" if background else "COMPLETED"
1466+
assert job.failed is False
1467+
14531468

14541469
@pytest.mark.parametrize("background", [False, True])
14551470
def test_cross_match_basic_wrong_column(monkeypatch, background, mock_querier_async):
@@ -1517,18 +1532,31 @@ def update_user_table(self, table_name, list_of_changes, verbose):
15171532
GAIA_QUERIER.cross_match_basic(table_a_full_qualified_name="schema.table_name", table_a_column_ra="ra",
15181533
table_a_column_dec="dec", table_b_full_qualified_name=".table_name")
15191534

1520-
error_message = "Invalid radius value. Found 50.0, valid range is: 0.1 to 10.0"
1535+
error_message = "Invalid radius value. Found 50.0 arcsec, valid range is: 0.1 to 10.0"
15211536
with pytest.raises(ValueError, match=error_message):
15221537
GAIA_QUERIER.cross_match_basic(table_a_full_qualified_name="schema.table_name", table_a_column_ra="ra",
15231538
table_a_column_dec="dec", table_b_full_qualified_name="schema.table_name",
15241539
radius=50.0)
15251540

1526-
error_message = "Invalid radius value. Found 0.01, valid range is: 0.1 to 10.0"
1541+
error_message = "Invalid radius value. Found 0.01 arcsec, valid range is: 0.1 to 10.0"
15271542
with pytest.raises(ValueError, match=error_message):
15281543
GAIA_QUERIER.cross_match_basic(table_a_full_qualified_name="schema.table_name", table_a_column_ra="ra",
15291544
table_a_column_dec="dec", table_b_full_qualified_name="schema.table_name",
15301545
radius=0.01)
15311546

1547+
radius_quantity = Quantity(value=0.01, unit=u.arcsec)
1548+
with pytest.raises(ValueError, match=error_message):
1549+
GAIA_QUERIER.cross_match_basic(table_a_full_qualified_name="schema.table_name", table_a_column_ra="ra",
1550+
table_a_column_dec="dec", table_b_full_qualified_name="schema.table_name",
1551+
radius=radius_quantity)
1552+
1553+
radius_quantity = Quantity(value=1.0, unit=u.deg)
1554+
error_message = "Invalid radius value. Found 3600.0 arcsec, valid range is: 0.1 to 10.0"
1555+
with pytest.raises(ValueError, match=error_message):
1556+
GAIA_QUERIER.cross_match_basic(table_a_full_qualified_name="schema.table_name", table_a_column_ra="ra",
1557+
table_a_column_dec="dec", table_b_full_qualified_name="schema.table_name",
1558+
radius=radius_quantity)
1559+
15321560

15331561
@patch.object(TapPlus, 'login')
15341562
def test_login(mock_login):

0 commit comments

Comments
 (0)