|
3 | 3 | from matplotlib.ticker import (MultipleLocator, AutoMinorLocator)
|
4 | 4 | import numpy as np
|
5 | 5 | import math
|
| 6 | +import functools |
6 | 7 | import mdn
|
7 | 8 | import joblib
|
8 | 9 | import os
|
|
14 | 15 | # print(os.getcwd())
|
15 | 16 | model_path = os.path.join(os.getcwd(), "deepexo/model")
|
16 | 17 |
|
| 18 | +def lazy_property(fn): |
| 19 | + attr_name = "_lazy_" + fn.__name__ |
17 | 20 |
|
| 21 | + @property |
| 22 | + @functools.wraps(fn) |
| 23 | + def _lazy_property(self): |
| 24 | + if not hasattr(self, attr_name): |
| 25 | + setattr(self, attr_name, fn(self)) |
| 26 | + return getattr(self, attr_name) |
| 27 | + |
| 28 | + return _lazy_property |
18 | 29 | class RockyPlanet:
|
19 | 30 | """A class for characterizing the interior structure of rocky exoplanets."""
|
20 | 31 | def __init__(self):
|
21 |
| - # print('init RockyPlanet') |
22 |
| - self.model_a = load_model(os.path.join(model_path, "model_a.h5"), custom_objects={ |
| 32 | + pass |
| 33 | + |
| 34 | + @lazy_property |
| 35 | + def model_a(self): |
| 36 | + return load_model(os.path.join(model_path, "model_a.h5"), custom_objects={ |
23 | 37 | 'MDN': mdn.MDN(OUTPUT_DIMS, N_MIXES),
|
24 | 38 | "mdn_loss_func": mdn.get_mixture_loss_func(OUTPUT_DIMS, N_MIXES)}, compile=False)
|
25 |
| - self.model_a_scaler = joblib.load(os.path.join(model_path, "model_a_scaler.save")) |
26 |
| - self.model_b = load_model(os.path.join(model_path, "model_b.h5"), custom_objects={ |
| 39 | + |
| 40 | + @lazy_property |
| 41 | + def model_a_scaler(self): |
| 42 | + return joblib.load(os.path.join(model_path, "model_a_scaler.save")) |
| 43 | + |
| 44 | + @lazy_property |
| 45 | + def model_b(self): |
| 46 | + return load_model(os.path.join(model_path, "model_b.h5"), custom_objects={ |
27 | 47 | 'MDN': mdn.MDN(OUTPUT_DIMS, N_MIXES),
|
28 | 48 | "mdn_loss_func": mdn.get_mixture_loss_func(OUTPUT_DIMS, N_MIXES)}, compile=False)
|
29 |
| - self.model_b_scaler = joblib.load(os.path.join(model_path, "model_b_scaler.save")) |
30 | 49 |
|
| 50 | + @lazy_property |
| 51 | + def model_b_scaler(self): |
| 52 | + return joblib.load(os.path.join(model_path, "model_b_scaler.save")) |
31 | 53 | def predict(self, planet_params: object) -> object:
|
32 | 54 | """Predicts the Water radial fraction, Mantle radial fraction, Core radial fraction, Core mass fraction,
|
33 | 55 | CMB pressure and CMB temperature for the given planetary parameters in terms of planetary mass M [M_Earth],
|
@@ -60,6 +82,8 @@ def plot(self, pred: object, save: object = False, filename: object = "pred.png"
|
60 | 82 | Returns:
|
61 | 83 | None or saves the plot to a file.
|
62 | 84 | """
|
| 85 | + print("###############################################") |
| 86 | + print("Plotting...") |
63 | 87 | (y_min1, y_max1, y_min2, y_max2, y_min3, y_max3,
|
64 | 88 | y_min4, y_max4, y_min5, y_max5, y_min6, y_max6) = \
|
65 | 89 | 0.00015137, 0.145835, 0.127618, 0.973427, 0.00787023, 0.799449, 1.17976e-06, 0.699986, 10.7182, 1999.49, 1689.37, 5673.87
|
@@ -154,6 +178,8 @@ def plot(self, pred: object, save: object = False, filename: object = "pred.png"
|
154 | 178 | ax.set_xlabel(predict_label[i])
|
155 | 179 | ax.set_ylabel("Probability density")
|
156 | 180 | if save:
|
| 181 | + print("Saving figure to {}".format(filename)) |
157 | 182 | return plt.savefig(filename, dpi=300)
|
158 | 183 | else:
|
| 184 | + print("Showing figure") |
159 | 185 | return plt.show()
|
0 commit comments