Skip to content

Commit 78c57de

Browse files
committed
Ready to fit models! ✨
1 parent 71ecc4a commit 78c57de

File tree

2 files changed

+93
-7
lines changed

2 files changed

+93
-7
lines changed

Irradiances_ratios/E_ratio_script.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,7 @@
1111

1212
import numpy as np
1313
import pandas as pd
14-
15-
from datetime import datetime
14+
# import matplotlib.pyplot as plt
1615

1716
# Matrix of values to test
1817
# Atmosphere characterization required params
@@ -27,11 +26,21 @@
2726
}
2827

2928
# what do we want to plot E_λ<λ₀/E against? (None = default behaviour)
30-
plot_keys = "datetime"
29+
plot_keys = None
30+
31+
# model function
32+
model_inputs = ["relative_airmass", "aerosol_turbidity_500nm"]
33+
p0 = None # [0.2, 0.1]
34+
35+
36+
def model(xdata, c0, c1): # use this func as model template
37+
r_am, aod500 = xdata
38+
return c0 * r_am + c1 * aod500
39+
3140

3241
bench = MR_E_ratio(
3342
datetimes=pd.date_range(
34-
"2023-11-27T00", "2023-11-28T00", freq=pd.Timedelta(minutes=1)
43+
"2023-11-27T00", "2023-11-28T00", freq=pd.Timedelta(minutes=15)
3544
)
3645
)
3746

@@ -40,14 +49,20 @@
4049
bench.cutoff_lambda = LAMBDA0["monosi"] # == polysi
4150
bench.simulate_from_product(**spectrl2_generator_input)
4251
bench.plot_results(plot_keys=plot_keys)
52+
optim_result = bench.optimization_from_model(
53+
model=model, model_inputs=model_inputs, p0=p0
54+
)
4355
bench.times_summary()
4456

57+
# TODO: PLOT RESULTS & MODEL PREDICTION
58+
4559
# %%
4660
# Test with asi cutoff wavelength
4761
bench.reset_simulation_state()
4862
bench.cutoff_lambda = LAMBDA0["asi"]
4963
bench.simulate_from_product(**spectrl2_generator_input)
5064
bench.plot_results(plot_keys=plot_keys)
65+
bench.optimization_from_model(model=model, model_inputs=model_inputs, p0=p0)
5166
bench.times_summary()
5267

5368
# %%

Irradiances_ratios/spectrl2_E_ratio_bench.py

Lines changed: 74 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,13 @@
1515
import numpy as np
1616
import pandas as pd
1717
import matplotlib.pyplot as plt
18+
from scipy.optimize import curve_fit
1819

1920
from itertools import product
2021
from functools import partial
2122
from datetime import datetime
2223
from time import time
24+
from typing import Callable
2325

2426

2527
class MR_E_ratio:
@@ -118,11 +120,11 @@ def simulation_prerun(self):
118120
self.solpos["azimuth"],
119121
)
120122
self.time_params = {
121-
"apparent_zenith": self.solpos["apparent_zenith"],
122-
"aoi": self.aoi,
123+
"apparent_zenith": self.solpos["apparent_zenith"].to_numpy(),
124+
"aoi": self.aoi.to_numpy(),
123125
"relative_airmass": self.locus.get_airmass(solar_position=self.solpos)[
124126
"airmass_relative"
125-
],
127+
].to_numpy(),
126128
"dayofyear": np.fromiter(
127129
map(day_of_year, self.datetimes), dtype=np.float64
128130
),
@@ -306,3 +308,72 @@ def plot_results(
306308
plt.close()
307309

308310
self.processing_time["plot_results"] = time() - start_time
311+
312+
def optimization_from_model(
313+
self, model: Callable = None, model_inputs: tuple = None, **kwargs
314+
):
315+
"""
316+
Optimize a model to fit generated data.
317+
318+
Parameters
319+
----------
320+
model : Callable
321+
Function with the model to be optimised.
322+
model_inputs : str or iterable of str
323+
Order and parameters of ``model``. Must be any of:
324+
* ``datetime``
325+
* ``apparent_zenith``, ``aoi``, ``relative_airmass`` or ``dayofyear``
326+
* any parameter name provided to ``simulate_from_product``
327+
**kwargs :
328+
Redirected to ``scipy.optimize.curve_fit``.
329+
330+
Returns
331+
-------
332+
``scipy.optimize.curve_fit``'s return values
333+
"""
334+
start_time = time() # Initialize start time of block
335+
336+
if isinstance(model_inputs, str):
337+
model_inputs = (model_inputs,)
338+
339+
# number of inputs from user: n-left-most columns
340+
n_inputs = len(self.input_keys)
341+
# Fitting data
342+
ydata = self.results.iloc[:, n_inputs:].to_numpy().flatten()
343+
# Prepare input vector
344+
xdata = [] # length of each value must be
345+
dates_len = len(self.datetimes)
346+
try:
347+
for var_name in model_inputs:
348+
# broadcast all inputs to match ydata
349+
if var_name in self.input_keys:
350+
xdata.append(self.results[var_name].to_numpy().repeat(dates_len))
351+
elif var_name in self.time_params.keys():
352+
xdata.append(
353+
np.tile(self.time_params[var_name], self.results.shape[0])
354+
)
355+
elif var_name in {"datetime"}:
356+
xdata.append(
357+
np.tile(self.datetimes.to_numpy(), self.results.shape[0])
358+
)
359+
else:
360+
raise ValueError(f"'{var_name}' is not a valid parameter name!")
361+
362+
except TypeError:
363+
raise TypeError(
364+
"Provide a valid model input names vector. Must be iterable"
365+
+ " of strings, and that input will be provided to 'model'"
366+
+ f" in the same order.\nYou provided {model_inputs}"
367+
)
368+
369+
## This is kept here for debug purposes: check valid representation as 1D arrays
370+
# fig, axs = plt.subplots(len(model_inputs))
371+
# for i, name in enumerate(model_inputs):
372+
# axs[i].set_title(name)
373+
# axs[i].scatter(xdata[i], ydata)
374+
# plt.show()
375+
376+
curve_fit_results = curve_fit(model, xdata, ydata, nan_policy="omit", **kwargs)
377+
378+
self.processing_time["optimization_from_model"] = time() - start_time
379+
return curve_fit_results

0 commit comments

Comments
 (0)