Skip to content

Commit 02c157d

Browse files
committed
Define helper functions in utils.py
1 parent 77fb500 commit 02c157d

File tree

1 file changed

+197
-0
lines changed

1 file changed

+197
-0
lines changed

astroquery/eso/utils.py

Lines changed: 197 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,197 @@
1+
"""
2+
utils.py: helper functions for the astropy.eso module
3+
"""
4+
from dataclasses import dataclass
5+
from typing import Dict, List, Optional, Union
6+
from astropy.table import Table
7+
8+
DEFAULT_LEAD_COLS_RAW = ['object', 'ra', 'dec', 'dp_id', 'date_obs', 'prog_id']
9+
DEFAULT_LEAD_COLS_PHASE3 = ['target_name', 's_ra', 's_dec', 'dp_id', 'date_obs', 'proposal_id']
10+
11+
12+
@dataclass
13+
class _UserParams:
14+
"""
15+
Parameters set by the user
16+
"""
17+
table_name: str
18+
column_name: str = None
19+
allowed_values: Union[List[str], str] = None
20+
cone_ra: float = None
21+
cone_dec: float = None
22+
cone_radius: float = None
23+
columns: Union[List, str] = None
24+
column_filters: Dict[str, str] = None
25+
top: int = None
26+
order_by: str = ''
27+
order_by_desc: bool = True
28+
count_only: bool = False
29+
query_str_only: bool = False
30+
print_help: bool = False
31+
authenticated: bool = False
32+
33+
34+
def _split_str_as_list_of_str(column_str: str):
35+
if column_str == '':
36+
column_list = []
37+
else:
38+
column_list = list(map(lambda x: x.strip(), column_str.split(',')))
39+
return column_list
40+
41+
42+
def _raise_if_has_deprecated_keys(filters: Optional[Dict[str, str]]) -> bool:
43+
if not filters:
44+
return
45+
46+
if any(k in filters for k in ("box", "coord1", "coord2")):
47+
raise ValueError(
48+
"box, coord1 and coord2 are deprecated; "
49+
"use cone_ra, cone_dec and cone_radius instead."
50+
)
51+
52+
if any(k in filters for k in ("etime", "stime")):
53+
raise ValueError(
54+
"'stime' and 'etime' are deprecated; "
55+
"use instead 'exp_start' together with '<', '>', 'between'. Examples:\n"
56+
"\tcolumn_filters = {'exp_start': '< 2024-01-01'}\n"
57+
"\tcolumn_filters = {'exp_start': '>= 2023-01-01'}\n"
58+
"\tcolumn_filters = {'exp_start': \"between '2023-01-01' and '2024-01-01'\"}\n"
59+
)
60+
61+
62+
def _build_where_constraints(
63+
column_name: str,
64+
allowed_values: Union[List[str], str],
65+
column_filters: Dict[str, str]) -> str:
66+
def _format_helper(av):
67+
if isinstance(av, str):
68+
av = _split_str_as_list_of_str(av)
69+
quoted_values = [f"'{v.strip()}'" for v in av]
70+
return f"{column_name} in ({', '.join(quoted_values)})"
71+
72+
column_filters = column_filters or {}
73+
where_constraints = []
74+
if allowed_values:
75+
where_constraints.append(_format_helper(allowed_values))
76+
77+
where_constraints += [
78+
f"{k} {_adql_sanitize_op_val(v)}" for k, v in column_filters.items()
79+
]
80+
return where_constraints
81+
82+
83+
def reorder_columns(table: Table,
84+
leading_cols: Optional[List[str]] = None):
85+
"""
86+
Reorders the columns of the pased table so that the
87+
colums given by the list leading_cols are first.
88+
If no leading cols are passed, it defaults to
89+
['object', 'ra', 'dec', 'dp_id', 'date_obs']
90+
Returns a table with the columns reordered.
91+
"""
92+
if not isinstance(table, Table):
93+
return table
94+
95+
leading_cols = leading_cols or DEFAULT_LEAD_COLS_RAW
96+
first_cols = []
97+
last_cols = table.colnames[:]
98+
for x in leading_cols:
99+
if x in last_cols:
100+
last_cols.remove(x)
101+
first_cols.append(x)
102+
last_cols = first_cols + last_cols
103+
table = table[last_cols]
104+
return table
105+
106+
107+
def _adql_sanitize_op_val(op_val):
108+
"""
109+
Expected input:
110+
"= 5", "< 3.14", "like '%John Doe%'", "in ('item1', 'item2')"
111+
or just string values like "ESO", "ALMA", "'ALMA'", "John Doe"
112+
113+
Logic:
114+
returns "<operator> <value>" if operator is provided.
115+
Defaults to "= <value>" otherwise.
116+
"""
117+
supported_operators = ["<=", ">=", "!=", "=", ">", "<",
118+
"not like ", "not in ", "not between ",
119+
"like ", "between ", "in "] # order matters
120+
121+
if not isinstance(op_val, str):
122+
return f"= {op_val}"
123+
124+
op_val = op_val.strip()
125+
for s in supported_operators:
126+
if op_val.lower().startswith(s):
127+
operator, value = s, op_val[len(s):].strip()
128+
return f"{operator} {value}"
129+
130+
# Default case: no operator. Assign "="
131+
value = op_val if (op_val.startswith("'") and op_val.endswith("'")) else f"'{op_val}'"
132+
return f"= {value}"
133+
134+
135+
def raise_if_coords_not_valid(cone_ra: Optional[float] = None,
136+
cone_dec: Optional[float] = None,
137+
cone_radius: Optional[float] = None) -> bool:
138+
"""
139+
ra, dec, radius must be either present all three
140+
or absent all three. Moreover, they must be float
141+
"""
142+
are_all_none = (cone_ra is None) and (cone_dec is None) and (cone_radius is None)
143+
are_all_float = isinstance(cone_ra, (float, int)) and \
144+
isinstance(cone_dec, (float, int)) and \
145+
isinstance(cone_radius, (float, int))
146+
is_a_valid_combination = are_all_none or are_all_float
147+
if not is_a_valid_combination:
148+
raise ValueError(
149+
"Either all three (cone_ra, cone_dec, cone_radius) must be present or none.\n"
150+
"Values provided:\n"
151+
f"\tcone_ra = {cone_ra}, cone_dec = {cone_dec}, cone_radius = {cone_radius}"
152+
)
153+
154+
155+
def _py2adql(user_params: _UserParams) -> str:
156+
"""
157+
Return the adql string corresponding to the parameters passed
158+
See adql examples at https://archive.eso.org/tap_obs/examples
159+
"""
160+
up = user_params
161+
query_string = None
162+
columns = up.columns or []
163+
164+
# We assume the coordinates passed are valid
165+
where_circle = []
166+
if up.cone_radius is not None:
167+
where_circle += [
168+
'intersects(s_region, circle(\'ICRS\', '
169+
f'{up.cone_ra}, {up.cone_dec}, {up.cone_radius}))=1']
170+
171+
wc = _build_where_constraints(up.column_name,
172+
up.allowed_values,
173+
up.column_filters) + where_circle
174+
175+
if isinstance(columns, str):
176+
columns = _split_str_as_list_of_str(columns)
177+
if columns is None or len(columns) < 1:
178+
columns = ['*']
179+
if up.count_only:
180+
columns = ['count(*)']
181+
182+
# Build the query
183+
query_string = ', '.join(columns) + ' from ' + up.table_name
184+
if len(wc) > 0:
185+
where_string = ' where ' + ' and '.join(wc)
186+
query_string += where_string
187+
188+
if len(up.order_by) > 0 and not up.count_only:
189+
order_string = ' order by ' + up.order_by + (' desc ' if up.order_by_desc else ' asc ')
190+
query_string += order_string
191+
192+
if up.top is not None:
193+
query_string = f"select top {up.top} " + query_string
194+
else:
195+
query_string = "select " + query_string
196+
197+
return query_string.strip()

0 commit comments

Comments
 (0)