|
| 1 | +# Copyright 2025 AstroLab Software |
| 2 | +# Author: Julien Peloton |
| 3 | +# |
| 4 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | +# you may not use this file except in compliance with the License. |
| 6 | +# You may obtain a copy of the License at |
| 7 | +# |
| 8 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | +# |
| 10 | +# Unless required by applicable law or agreed to in writing, software |
| 11 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | +# See the License for the specific language governing permissions and |
| 14 | +# limitations under the License. |
| 15 | +"""Get ephemerides at scales""" |
| 16 | + |
| 17 | +import pandas as pd |
| 18 | + |
| 19 | +from pyspark.sql.functions import pandas_udf, PandasUDFType |
| 20 | +from pyspark.sql.types import MapType, StringType, FloatType, ArrayType |
| 21 | + |
| 22 | +from fink_utils.sso.utils import query_miriade_ephemcc |
| 23 | +from fink_utils.sso.utils import query_miriade |
| 24 | + |
| 25 | +from fink_utils.tester import spark_unit_tests |
| 26 | + |
| 27 | + |
# Ephemeris fields extracted from each Miriade response, in output order.
# Some names contain a '.' (e.g. "Elong."), which is illegal in Spark column
# references (it denotes struct access) -- see `sanitize_name` below.
# NOTE(review): ":" in "SDSS:g"/"SDSS:r" is kept as-is; presumably accepted
# downstream -- confirm against the Spark column-naming rules in use.
COLUMNS = ["Dobs", "Dhelio", "SDSS:g", "SDSS:r", "Phase", "Elong."]
| 29 | + |
| 30 | + |
def sanitize_name(col):
    """Strip every '.' character from a column name.

    Spark column references treat '.' as struct-field access, so names such
    as "Elong." must be cleaned (-> "Elong") before being used as plain
    column names. Note that ALL dots are removed, not only a trailing one.

    Parameters
    ----------
    col: str
        Raw column name, possibly containing '.' characters.

    Returns
    -------
    str
        The name with all '.' characters removed.
    """
    return "".join(ch for ch in col if ch != ".")
| 34 | + |
| 35 | + |
def expand_columns(df, col_to_expand="ephem"):
    """Expand a MapType column into individual columns

    Notes
    -----
    Transforms a dataframe with columns
    ["toto", "container.col1", "container.col2"] into a dataframe with
    columns ["toto", "col1", "col2"]. Only the keys listed in the
    module-level `COLUMNS` are extracted, their names are passed through
    `sanitize_name`, and the `col_to_expand` column itself is dropped.
    If `col_to_expand` is absent, the input is returned unchanged (with a
    printed warning).

    Parameters
    ----------
    df: Spark DataFrame
        Spark DataFrame with the map column
    col_to_expand: str
        Name of the column to expand

    Returns
    -------
    out: Spark DataFrame
        The expanded input DataFrame

    Examples
    --------
    >>> pdf = pd.DataFrame({"a": [{"Dobs": 1, "Elong": 2}, {"Dobs": 10, "Elong": 20}]})
    >>> df = spark.createDataFrame(pdf)
    >>> assert "a" in df.columns, df.columns
    >>> assert "Dobs" not in df.columns, df.columns
    >>> df = expand_columns(df, col_to_expand="a")
    >>> assert "Dobs" in df.columns, df.columns
    >>> assert "a" not in df.columns, df.columns
    """
    # Nothing to do if the map column is missing -- warn and return as-is.
    if col_to_expand not in df.columns:
        print(
            "{} not found in the DataFrame columns. Have you computed ephemerides?".format(
                col_to_expand
            )
        )
        return df

    for field in COLUMNS:
        clean = sanitize_name(field)
        # Map access via dotted reference: <col_to_expand>.<key>
        df = df.withColumn(clean, df["{}.{}".format(col_to_expand, clean)])
    return df.drop(col_to_expand)
| 81 | + |
| 82 | + |
@pandas_udf(MapType(StringType(), ArrayType(FloatType())), PandasUDFType.SCALAR)
def extract_ztf_ephemerides_from_miriade(ssnamenr, cjd, uid, method):
    """Extract ephemerides for ZTF from Miriade

    Parameters
    ----------
    ssnamenr: pd.Series of str
        ZTF ssnamenr
    cjd: pd.Series of list of floats
        List of JD values
    uid: pd.Series of int
        Unique ID for each object
    method: pd.Series of str
        Method to compute ephemerides: `ephemcc` or `rest`.
        Use only the former on the Spark Cluster (local installation of ephemcc),
        otherwise use `rest` to call the ssodnet web service.

    Returns
    -------
    out: pd.Series of dictionaries of lists
        One dict per object, mapping each sanitized `COLUMNS` name to the
        list of per-epoch values. Empty dict when the query returned no data.

    Examples
    --------
    >>> import pyspark.sql.functions as F

    Basic ephemerides computation
    >>> path = "fink_utils/test_data/agg_benoit_julien_2024"
    >>> df_prev = spark.read.format("parquet").load(path)

    >>> df_prev_ephem = df_prev.withColumn(
    ...     "ephem",
    ...     extract_ztf_ephemerides_from_miriade(
    ...         "ssnamenr",
    ...         "cjd",
    ...         F.expr("uuid()"),
    ...         F.lit("rest")))

    >>> df_prev_ephem = expand_columns(df_prev_ephem)
    >>> out = df_prev_ephem.select(["cjd", "Dobs"]).collect()
    >>> assert len(out[0]["cjd"]) == len(out[0]["Dobs"])
    >>> assert len(out[1]["cjd"]) == len(out[1]["Dobs"])

    Aggregation of ephemerides
    >>> from fink_utils.sso.ssoft import aggregate_ztf_sso_data
    >>> path = "fink_utils/test_data/benoit_julien_2025/science"
    >>> df_new = aggregate_ztf_sso_data(year=2025, month=1, prefix_path=path)

    >>> df_new_ephem = df_new.withColumn(
    ...     "ephem",
    ...     extract_ztf_ephemerides_from_miriade(
    ...         "ssnamenr",
    ...         "cjd",
    ...         F.expr("uuid()"),
    ...         F.lit("rest")))
    >>> df_new_ephem = expand_columns(df_new_ephem)
    >>> out = df_new_ephem.select(["cjd", "SDSS:g"]).collect()
    >>> assert len(out[0]["cjd"]) == len(out[0]["SDSS:g"])

    Checking rolling add
    >>> from fink_utils.sso.ssoft import join_aggregated_sso_data
    >>> df_join = join_aggregated_sso_data(df_prev, df_new, on="ssnamenr")
    >>> df_join_ephem = df_join.withColumn(
    ...     "ephem",
    ...     extract_ztf_ephemerides_from_miriade(
    ...         "ssnamenr",
    ...         "cjd",
    ...         F.expr("uuid()"),
    ...         F.lit("rest")))
    >>> df_join_ephem = expand_columns(df_join_ephem)

    >>> df_join_ephem_bis = join_aggregated_sso_data(df_prev_ephem, df_new_ephem, on="ssnamenr")
    >>> out_1 = df_join_ephem.select(["Elong"]).collect()
    >>> out_2 = df_join_ephem_bis.select(["Elong"]).collect()
    >>> assert out_1 == out_2, (out_1, out_2)
    """
    # The method is constant per batch (passed as F.lit), so read it once.
    method_ = method.to_numpy()[0]

    # Hoist Series -> ndarray conversions out of the per-object loop:
    # the original converted `cjd` and `uid` on every iteration.
    cjd_arr = cjd.to_numpy()
    uid_arr = uid.to_numpy()

    # Hardcoded! Paths of the local ephemcc installation on the Spark
    # cluster. Built once; identical for every object in the batch.
    parameters = {
        "outdir": "/tmp/ramdisk/spins",
        "runner_path": "/tmp/fink_run_ephemcc4.sh",
        "userconf": "/tmp/.eproc-4.3",
        "iofile": "/tmp/default-ephemcc-observation.xml",
    }

    out = []
    for index, ssname in enumerate(ssnamenr.to_numpy()):
        if method_ == "ephemcc":
            # Local ephemcc run (Spark cluster only)
            ephems = query_miriade_ephemcc(
                ssname,
                cjd_arr[index],
                observer="I41",
                rplane="1",
                tcoor=5,
                shift=15.0,
                parameters=parameters,
                uid=uid_arr[index],
                return_json=True,
            )
        else:
            # REST call to the ssodnet web service
            ephems = query_miriade(
                ssname,
                cjd_arr[index],
                observer="I41",
                rplane="1",
                tcoor=5,
                shift=15.0,
                timeout=30,
                return_json=True,
            )
        if ephems.get("data", None) is not None:
            # Remove any "." in name
            out.append({
                sanitize_name(k): [dic[k] for dic in ephems["data"]] for k in COLUMNS
            })
        else:
            # Query failed or returned no data: keep an empty map so the
            # output Series stays row-aligned with the input.
            out.append({})

    return pd.Series(out)
| 201 | + |
| 202 | + |
if __name__ == "__main__":
    """Execute the unit test suite"""

    # Run the Spark test suite -- presumably executes the doctests defined
    # in this module's globals (see fink_utils.tester.spark_unit_tests).
    spark_unit_tests(globals())