|
| 1 | +""" |
| 2 | +utils.py: helper functions for the astropy.eso module |
| 3 | +""" |
| 4 | +from dataclasses import dataclass |
| 5 | +from typing import Dict, List, Optional, Union |
| 6 | +from astropy.table import Table |
| 7 | + |
| 8 | +DEFAULT_LEAD_COLS_RAW = ['object', 'ra', 'dec', 'dp_id', 'date_obs', 'prog_id'] |
| 9 | +DEFAULT_LEAD_COLS_PHASE3 = ['target_name', 's_ra', 's_dec', 'dp_id', 'date_obs', 'proposal_id'] |
| 10 | + |
| 11 | + |
| 12 | +@dataclass |
| 13 | +class _UserParams: |
| 14 | + """ |
| 15 | + Parameters set by the user |
| 16 | + """ |
| 17 | + table_name: str |
| 18 | + column_name: str = None |
| 19 | + allowed_values: Union[List[str], str] = None |
| 20 | + cone_ra: float = None |
| 21 | + cone_dec: float = None |
| 22 | + cone_radius: float = None |
| 23 | + columns: Union[List, str] = None |
| 24 | + column_filters: Dict[str, str] = None |
| 25 | + top: int = None |
| 26 | + order_by: str = '' |
| 27 | + order_by_desc: bool = True |
| 28 | + count_only: bool = False |
| 29 | + query_str_only: bool = False |
| 30 | + print_help: bool = False |
| 31 | + authenticated: bool = False |
| 32 | + |
| 33 | + |
| 34 | +def _split_str_as_list_of_str(column_str: str): |
| 35 | + if column_str == '': |
| 36 | + column_list = [] |
| 37 | + else: |
| 38 | + column_list = list(map(lambda x: x.strip(), column_str.split(','))) |
| 39 | + return column_list |
| 40 | + |
| 41 | + |
| 42 | +def _raise_if_has_deprecated_keys(filters: Optional[Dict[str, str]]) -> bool: |
| 43 | + if not filters: |
| 44 | + return |
| 45 | + |
| 46 | + if any(k in filters for k in ("box", "coord1", "coord2")): |
| 47 | + raise ValueError( |
| 48 | + "box, coord1 and coord2 are deprecated; " |
| 49 | + "use cone_ra, cone_dec and cone_radius instead." |
| 50 | + ) |
| 51 | + |
| 52 | + if any(k in filters for k in ("etime", "stime")): |
| 53 | + raise ValueError( |
| 54 | + "'stime' and 'etime' are deprecated; " |
| 55 | + "use instead 'exp_start' together with '<', '>', 'between'. Examples:\n" |
| 56 | + "\tcolumn_filters = {'exp_start': '< 2024-01-01'}\n" |
| 57 | + "\tcolumn_filters = {'exp_start': '>= 2023-01-01'}\n" |
| 58 | + "\tcolumn_filters = {'exp_start': \"between '2023-01-01' and '2024-01-01'\"}\n" |
| 59 | + ) |
| 60 | + |
| 61 | + |
| 62 | +def _build_where_constraints( |
| 63 | + column_name: str, |
| 64 | + allowed_values: Union[List[str], str], |
| 65 | + column_filters: Dict[str, str]) -> str: |
| 66 | + def _format_helper(av): |
| 67 | + if isinstance(av, str): |
| 68 | + av = _split_str_as_list_of_str(av) |
| 69 | + quoted_values = [f"'{v.strip()}'" for v in av] |
| 70 | + return f"{column_name} in ({', '.join(quoted_values)})" |
| 71 | + |
| 72 | + column_filters = column_filters or {} |
| 73 | + where_constraints = [] |
| 74 | + if allowed_values: |
| 75 | + where_constraints.append(_format_helper(allowed_values)) |
| 76 | + |
| 77 | + where_constraints += [ |
| 78 | + f"{k} {_adql_sanitize_op_val(v)}" for k, v in column_filters.items() |
| 79 | + ] |
| 80 | + return where_constraints |
| 81 | + |
| 82 | + |
| 83 | +def reorder_columns(table: Table, |
| 84 | + leading_cols: Optional[List[str]] = None): |
| 85 | + """ |
| 86 | + Reorders the columns of the pased table so that the |
| 87 | + colums given by the list leading_cols are first. |
| 88 | + If no leading cols are passed, it defaults to |
| 89 | + ['object', 'ra', 'dec', 'dp_id', 'date_obs'] |
| 90 | + Returns a table with the columns reordered. |
| 91 | + """ |
| 92 | + if not isinstance(table, Table): |
| 93 | + return table |
| 94 | + |
| 95 | + leading_cols = leading_cols or DEFAULT_LEAD_COLS_RAW |
| 96 | + first_cols = [] |
| 97 | + last_cols = table.colnames[:] |
| 98 | + for x in leading_cols: |
| 99 | + if x in last_cols: |
| 100 | + last_cols.remove(x) |
| 101 | + first_cols.append(x) |
| 102 | + last_cols = first_cols + last_cols |
| 103 | + table = table[last_cols] |
| 104 | + return table |
| 105 | + |
| 106 | + |
| 107 | +def _adql_sanitize_op_val(op_val): |
| 108 | + """ |
| 109 | + Expected input: |
| 110 | + "= 5", "< 3.14", "like '%John Doe%'", "in ('item1', 'item2')" |
| 111 | + or just string values like "ESO", "ALMA", "'ALMA'", "John Doe" |
| 112 | +
|
| 113 | + Logic: |
| 114 | + returns "<operator> <value>" if operator is provided. |
| 115 | + Defaults to "= <value>" otherwise. |
| 116 | + """ |
| 117 | + supported_operators = ["<=", ">=", "!=", "=", ">", "<", |
| 118 | + "not like ", "not in ", "not between ", |
| 119 | + "like ", "between ", "in "] # order matters |
| 120 | + |
| 121 | + if not isinstance(op_val, str): |
| 122 | + return f"= {op_val}" |
| 123 | + |
| 124 | + op_val = op_val.strip() |
| 125 | + for s in supported_operators: |
| 126 | + if op_val.lower().startswith(s): |
| 127 | + operator, value = s, op_val[len(s):].strip() |
| 128 | + return f"{operator} {value}" |
| 129 | + |
| 130 | + # Default case: no operator. Assign "=" |
| 131 | + value = op_val if (op_val.startswith("'") and op_val.endswith("'")) else f"'{op_val}'" |
| 132 | + return f"= {value}" |
| 133 | + |
| 134 | + |
| 135 | +def raise_if_coords_not_valid(cone_ra: Optional[float] = None, |
| 136 | + cone_dec: Optional[float] = None, |
| 137 | + cone_radius: Optional[float] = None) -> bool: |
| 138 | + """ |
| 139 | + ra, dec, radius must be either present all three |
| 140 | + or absent all three. Moreover, they must be float |
| 141 | + """ |
| 142 | + are_all_none = (cone_ra is None) and (cone_dec is None) and (cone_radius is None) |
| 143 | + are_all_float = isinstance(cone_ra, (float, int)) and \ |
| 144 | + isinstance(cone_dec, (float, int)) and \ |
| 145 | + isinstance(cone_radius, (float, int)) |
| 146 | + is_a_valid_combination = are_all_none or are_all_float |
| 147 | + if not is_a_valid_combination: |
| 148 | + raise ValueError( |
| 149 | + "Either all three (cone_ra, cone_dec, cone_radius) must be present or none.\n" |
| 150 | + "Values provided:\n" |
| 151 | + f"\tcone_ra = {cone_ra}, cone_dec = {cone_dec}, cone_radius = {cone_radius}" |
| 152 | + ) |
| 153 | + |
| 154 | + |
| 155 | +def _py2adql(user_params: _UserParams) -> str: |
| 156 | + """ |
| 157 | + Return the adql string corresponding to the parameters passed |
| 158 | + See adql examples at https://archive.eso.org/tap_obs/examples |
| 159 | + """ |
| 160 | + up = user_params |
| 161 | + query_string = None |
| 162 | + columns = up.columns or [] |
| 163 | + |
| 164 | + # We assume the coordinates passed are valid |
| 165 | + where_circle = [] |
| 166 | + if up.cone_radius is not None: |
| 167 | + where_circle += [ |
| 168 | + 'intersects(s_region, circle(\'ICRS\', ' |
| 169 | + f'{up.cone_ra}, {up.cone_dec}, {up.cone_radius}))=1'] |
| 170 | + |
| 171 | + wc = _build_where_constraints(up.column_name, |
| 172 | + up.allowed_values, |
| 173 | + up.column_filters) + where_circle |
| 174 | + |
| 175 | + if isinstance(columns, str): |
| 176 | + columns = _split_str_as_list_of_str(columns) |
| 177 | + if columns is None or len(columns) < 1: |
| 178 | + columns = ['*'] |
| 179 | + if up.count_only: |
| 180 | + columns = ['count(*)'] |
| 181 | + |
| 182 | + # Build the query |
| 183 | + query_string = ', '.join(columns) + ' from ' + up.table_name |
| 184 | + if len(wc) > 0: |
| 185 | + where_string = ' where ' + ' and '.join(wc) |
| 186 | + query_string += where_string |
| 187 | + |
| 188 | + if len(up.order_by) > 0 and not up.count_only: |
| 189 | + order_string = ' order by ' + up.order_by + (' desc ' if up.order_by_desc else ' asc ') |
| 190 | + query_string += order_string |
| 191 | + |
| 192 | + if up.top is not None: |
| 193 | + query_string = f"select top {up.top} " + query_string |
| 194 | + else: |
| 195 | + query_string = "select " + query_string |
| 196 | + |
| 197 | + return query_string.strip() |
0 commit comments