Skip to content

Commit 1c8e74a

Browse files
Rolling ephemerides computation (#165)
* PEP8
* Utilities for ephemerides computation
* Update doc
* Add utility to get first date of next month
* Expand test suite
* PEP8
1 parent 1476941 commit 1c8e74a

File tree

3 files changed

+264
-5
lines changed

3 files changed

+264
-5
lines changed

fink_utils/sso/ephem.py

Lines changed: 207 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,207 @@
1+
# Copyright 2025 AstroLab Software
2+
# Author: Julien Peloton
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
"""Get ephemerides at scales"""
16+
17+
import pandas as pd
18+
19+
from pyspark.sql.functions import pandas_udf, PandasUDFType
20+
from pyspark.sql.types import MapType, StringType, FloatType, ArrayType
21+
22+
from fink_utils.sso.utils import query_miriade_ephemcc
23+
from fink_utils.sso.utils import query_miriade
24+
25+
from fink_utils.tester import spark_unit_tests
26+
27+
28+
COLUMNS = ["Dobs", "Dhelio", "SDSS:g", "SDSS:r", "Phase", "Elong."]
29+
30+
31+
def sanitize_name(col):
    """Remove any '.' from a column name

    Note: every occurrence of '.' is removed, not only a trailing one
    (e.g. "Elong." -> "Elong"). Spark column names cannot contain '.'
    without quoting, hence the sanitization.

    Parameters
    ----------
    col: str
        Column name, possibly containing '.' characters

    Returns
    -------
    out: str
        Column name with all '.' characters removed
    """
    return col.replace(".", "")
34+
35+
36+
def expand_columns(df, col_to_expand="ephem"):
    """Expand a MapType column into individual columns

    Notes
    -----
    A DataFrame with columns ["toto", "container.col1", "container.col2"]
    becomes a DataFrame with columns ["toto", "col1", "col2"]: one new
    column per entry of `COLUMNS` (names sanitized with `sanitize_name`),
    and `col_to_expand` itself is dropped.
    If `col_to_expand` is absent, the DataFrame is returned unchanged
    after printing a warning.

    Parameters
    ----------
    df: Spark DataFrame
        Spark DataFrame with the map column
    col_to_expand: str
        Name of the column to expand

    Returns
    -------
    out: Spark DataFrame
        The expanded input DataFrame

    Examples
    --------
    >>> pdf = pd.DataFrame({"a": [{"Dobs": 1, "Elong": 2}, {"Dobs": 10, "Elong": 20}]})
    >>> df = spark.createDataFrame(pdf)
    >>> assert "a" in df.columns, df.columns
    >>> assert "Dobs" not in df.columns, df.columns
    >>> df = expand_columns(df, col_to_expand="a")
    >>> assert "Dobs" in df.columns, df.columns
    >>> assert "a" not in df.columns, df.columns
    """
    if col_to_expand not in df.columns:
        msg = "{} not found in the DataFrame columns. Have you computed ephemerides?".format(
            col_to_expand
        )
        print(msg)
        return df

    for key in COLUMNS:
        # Sanitize once, and use the same name for both the map lookup
        # and the new top-level column.
        clean = sanitize_name(key)
        df = df.withColumn(clean, df["{}.{}".format(col_to_expand, clean)])

    return df.drop(col_to_expand)
81+
82+
83+
@pandas_udf(MapType(StringType(), ArrayType(FloatType())), PandasUDFType.SCALAR)
def extract_ztf_ephemerides_from_miriade(ssnamenr, cjd, uid, method):
    """Extract ephemerides for ZTF from Miriade

    Parameters
    ----------
    ssnamenr: pd.Series of str
        ZTF ssnamenr
    cjd: pd.Series of list of floats
        List of JD values
    uid: pd.Series of int
        Unique ID for each object
    method: pd.Series of str
        Method to compute ephemerides: `ephemcc` or `rest`.
        Use only the former on the Spark Cluster (local installation of ephemcc),
        otherwise use `rest` to call the ssodnet web service.

    Returns
    -------
    out: pd.Series of dictionaries of lists

    Examples
    --------
    >>> import pyspark.sql.functions as F

    Basic ephemerides computation
    >>> path = "fink_utils/test_data/agg_benoit_julien_2024"
    >>> df_prev = spark.read.format("parquet").load(path)

    >>> df_prev_ephem = df_prev.withColumn(
    ...     "ephem",
    ...     extract_ztf_ephemerides_from_miriade(
    ...         "ssnamenr",
    ...         "cjd",
    ...         F.expr("uuid()"),
    ...         F.lit("rest")))

    >>> df_prev_ephem = expand_columns(df_prev_ephem)
    >>> out = df_prev_ephem.select(["cjd", "Dobs"]).collect()
    >>> assert len(out[0]["cjd"]) == len(out[0]["Dobs"])
    >>> assert len(out[1]["cjd"]) == len(out[1]["Dobs"])

    Aggregation of ephemerides
    >>> from fink_utils.sso.ssoft import aggregate_ztf_sso_data
    >>> path = "fink_utils/test_data/benoit_julien_2025/science"
    >>> df_new = aggregate_ztf_sso_data(year=2025, month=1, prefix_path=path)

    >>> df_new_ephem = df_new.withColumn(
    ...     "ephem",
    ...     extract_ztf_ephemerides_from_miriade(
    ...         "ssnamenr",
    ...         "cjd",
    ...         F.expr("uuid()"),
    ...         F.lit("rest")))
    >>> df_new_ephem = expand_columns(df_new_ephem)
    >>> out = df_new_ephem.select(["cjd", "SDSS:g"]).collect()
    >>> assert len(out[0]["cjd"]) == len(out[0]["SDSS:g"])

    Checking rolling add
    >>> from fink_utils.sso.ssoft import join_aggregated_sso_data
    >>> df_join = join_aggregated_sso_data(df_prev, df_new, on="ssnamenr")
    >>> df_join_ephem = df_join.withColumn(
    ...     "ephem",
    ...     extract_ztf_ephemerides_from_miriade(
    ...         "ssnamenr",
    ...         "cjd",
    ...         F.expr("uuid()"),
    ...         F.lit("rest")))
    >>> df_join_ephem = expand_columns(df_join_ephem)

    >>> df_join_ephem_bis = join_aggregated_sso_data(df_prev_ephem, df_new_ephem, on="ssnamenr")
    >>> out_1 = df_join_ephem.select(["Elong"]).collect()
    >>> out_2 = df_join_ephem_bis.select(["Elong"]).collect()
    >>> assert out_1 == out_2, (out_1, out_2)
    """
    # The method is constant over the batch: read it once from the first row.
    use_ephemcc = method.to_numpy()[0] == "ephemcc"

    # Hoist the conversions out of the loop.
    jd_values = cjd.to_numpy()
    uid_values = uid.to_numpy()

    series_out = []
    for idx, ssname in enumerate(ssnamenr.to_numpy()):
        if use_ephemcc:
            # Hardcoded! Paths of the local ephemcc installation on the cluster.
            runner_conf = {
                "outdir": "/tmp/ramdisk/spins",
                "runner_path": "/tmp/fink_run_ephemcc4.sh",
                "userconf": "/tmp/.eproc-4.3",
                "iofile": "/tmp/default-ephemcc-observation.xml",
            }
            payload = query_miriade_ephemcc(
                ssname,
                jd_values[idx],
                observer="I41",
                rplane="1",
                tcoor=5,
                shift=15.0,
                parameters=runner_conf,
                uid=uid_values[idx],
                return_json=True,
            )
        else:
            payload = query_miriade(
                ssname,
                jd_values[idx],
                observer="I41",
                rplane="1",
                tcoor=5,
                shift=15.0,
                timeout=30,
                return_json=True,
            )

        if payload.get("data", None) is not None:
            # Remove any "." in name
            series_out.append({
                sanitize_name(field): [row[field] for row in payload["data"]]
                for field in COLUMNS
            })
        else:
            # Not sure about that
            series_out.append({})

    return pd.Series(series_out)
201+
202+
203+
if __name__ == "__main__":
    """Execute the unit test suite"""

    # Run the Spark test suite — presumably collects and runs the module's
    # doctests against a Spark session (see fink_utils.tester)
    spark_unit_tests(globals())

fink_utils/sso/ssoft.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -545,6 +545,34 @@ def retrieve_last_date_of_previous_month(mydate):
545545
return last_month
546546

547547

548+
def retrieve_first_date_of_next_month(mydate):
    """Given a date, retrieve the first date of the next month

    Parameters
    ----------
    mydate: datetime
        Input date

    Returns
    -------
    out: datetime
        First date of the month following `mydate`. The year rolls
        over when `mydate` is in December.

    Examples
    --------
    >>> mydate = datetime.date(year=2025, month=4, day=5)
    >>> out = retrieve_first_date_of_next_month(mydate)
    >>> assert out.strftime("%m") == "05"
    >>> assert out.day == 1

    >>> mydate = datetime.date(year=2025, month=12, day=14)
    >>> out = retrieve_first_date_of_next_month(mydate)
    >>> assert out.month == 1
    >>> assert out.year == 2026
    """
    # From the 1st of the current month, jumping 32 days always lands in the
    # next month (no month is longer than 31 days); snap back to day 1.
    return (mydate.replace(day=1) + datetime.timedelta(days=32)).replace(day=1)
574+
575+
548576
if __name__ == "__main__":
549577
"""Execute the unit test suite"""
550578

fink_utils/sso/utils.py

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,14 @@
3131

3232

3333
def query_miriade(
34-
ident, jd, observer="I41", rplane="1", tcoor=5, shift=15.0, timeout=30
34+
ident,
35+
jd,
36+
observer="I41",
37+
rplane="1",
38+
tcoor=5,
39+
shift=15.0,
40+
timeout=30,
41+
return_json=False,
3542
):
3643
"""Gets asteroid or comet ephemerides from IMCCE Miriade for a suite of JD for a single SSO
3744
@@ -60,6 +67,9 @@ def query_miriade(
6067
Default is 15 seconds which is half of the exposure time for ZTF.
6168
timeout: int
6269
Timeout in seconds. Default is 30.
70+
return_json: bool
71+
If True, return the JSON payload. Otherwise, returns
72+
a pandas DataFrame. Default is False.
6373
6474
Returns
6575
-------
@@ -101,10 +111,15 @@ def query_miriade(
101111
try:
102112
r = requests.post(url, params=params, files=files, timeout=timeout)
103113
except requests.exceptions.ReadTimeout:
114+
if return_json:
115+
return {}
104116
return pd.DataFrame()
105117

106118
j = r.json()
107119

120+
if return_json:
121+
return j
122+
108123
# Read JSON response
109124
try:
110125
ephem = pd.DataFrame.from_dict(j["data"])
@@ -114,7 +129,7 @@ def query_miriade(
114129
return ephem
115130

116131

117-
def query_miriade_epehemcc(
132+
def query_miriade_ephemcc(
118133
ident,
119134
jd,
120135
observer="I41",
@@ -123,6 +138,7 @@ def query_miriade_epehemcc(
123138
shift=15.0,
124139
parameters=None,
125140
uid=None,
141+
return_json=False,
126142
):
127143
"""Gets asteroid or comet ephemerides from IMCCE Miriade for a suite of JD for a single SSO
128144
@@ -154,6 +170,9 @@ def query_miriade_epehemcc(
154170
uid: int, optional
155171
If specified, ID used to write files on disk. Must be unique for each object.
156172
Default is None, i.e. randomly sampled from U(0, 1e7)
173+
return_json: bool
174+
If True, return the JSON payload. Otherwise, returns
175+
a pandas DataFrame. Default is False.
157176
158177
Returns
159178
-------
@@ -197,12 +216,17 @@ def query_miriade_epehemcc(
197216

198217
# clean date file
199218
os.remove(date_path)
200-
219+
if return_json:
220+
return {}
201221
return pd.DataFrame()
202222

203223
# read the data from disk and return
204224
with open(ephem_path, "r") as f:
205225
data = json.load(f)
226+
227+
if return_json:
228+
return data
229+
206230
ephem = pd.DataFrame(data["data"], columns=data["datacol"].keys())
207231

208232
# clean tmp files
@@ -283,7 +307,7 @@ def get_miriade_data(
283307
timeout=timeout,
284308
)
285309
elif method == "ephemcc":
286-
eph = query_miriade_epehemcc(
310+
eph = query_miriade_ephemcc(
287311
str(ssnamenr),
288312
pdf_sub["i:jd"],
289313
observer=observer,
@@ -317,7 +341,7 @@ def get_miriade_data(
317341
timeout=timeout,
318342
)
319343
elif method == "ephemcc":
320-
eph_ec = query_miriade_epehemcc(
344+
eph_ec = query_miriade_ephemcc(
321345
str(ssnamenr),
322346
pdf_sub["i:jd"],
323347
observer=observer,

0 commit comments

Comments
 (0)