16
16
17
17
18
18
class RockyPlanet :
19
+ """A class for characterizing the interior structure of rocky exoplanets."""
19
20
def __init__ (self ):
20
21
# print('init RockyPlanet')
21
- self .model_a = load_model (os .path .join (model_path , "model_a1 .h5" ), custom_objects = {
22
+ self .model_a = load_model (os .path .join (model_path , "model_a .h5" ), custom_objects = {
22
23
'MDN' : mdn .MDN (OUTPUT_DIMS , N_MIXES ),
23
24
"mdn_loss_func" : mdn .get_mixture_loss_func (OUTPUT_DIMS , N_MIXES )}, compile = False )
24
- self .model_a_scaler = joblib .load (os .path .join (model_path , "model_a1_scaler .save" ))
25
+ self .model_a_scaler = joblib .load (os .path .join (model_path , "model_a_scaler .save" ))
25
26
self .model_b = load_model (os .path .join (model_path , "model_b.h5" ), custom_objects = {
26
27
'MDN' : mdn .MDN (OUTPUT_DIMS , N_MIXES ),
27
28
"mdn_loss_func" : mdn .get_mixture_loss_func (OUTPUT_DIMS , N_MIXES )}, compile = False )
28
29
self .model_b_scaler = joblib .load (os .path .join (model_path , "model_b_scaler.save" ))
29
30
30
31
def predict (self , planet_params ):
31
- """Predictes the Water radial fraction, Mantle radial fraction, Core radial fraction, Core mass fraction,
32
+ """Predicts the Water radial fraction, Mantle radial fraction, Core radial fraction, Core mass fraction,
32
33
CMB pressure and CMB temperature for the given planetary parameters in terms of planetary mass M [M_Earth],
33
34
radius [R_Earth], bulk Fe/(Mg + Si) (molar ratio), and tide Love number k2.
34
35
35
36
Args:
36
37
planet_params (list): A list of planetary parameters in the order of [M, R, k2] or [M, R, cFeMg, k2].
37
- Retures :
38
+ Returns :
38
39
pred: contains parameters for distributions, not actual points on the graph.
39
40
"""
40
41
if len (planet_params ) == 3 :
@@ -57,7 +58,7 @@ def plot(self, pred, save=False, filename="pred.png"):
57
58
save (bool, optional): Defaults to False. If True, saves the plot to a file.
58
59
filename (str, optional): Defaults to "". The filename to save the plot to.
59
60
Returns:
60
- None
61
+ None or saves the plot to a file.
61
62
"""
62
63
(y_min1 , y_max1 , y_min2 , y_max2 , y_min3 , y_max3 ,
63
64
y_min4 , y_max4 , y_min5 , y_max5 , y_min6 , y_max6 ) = \
@@ -73,7 +74,7 @@ def plot(self, pred, save=False, filename="pred.png"):
73
74
locals ()['mus' + str (m )].append (mus [0 ][n * OUTPUT_DIMS + m ])
74
75
locals ()['sigs' + str (m )].append (sigs [0 ][n * OUTPUT_DIMS + m ])
75
76
x_max = [1 , 1 , 1 , 1 , 1 , 1 ]
76
- x_maxlabels = [
77
+ x_max_labels = [
77
78
y_min1 + (y_max1 - y_min1 ) * x_max [0 ],
78
79
y_min2 + (y_max2 - y_min2 ) * x_max [1 ],
79
80
y_min3 + (y_max3 - y_min3 ) * x_max [2 ],
@@ -90,17 +91,17 @@ def plot(self, pred, save=False, filename="pred.png"):
90
91
tcmb = [] # CMB pressure
91
92
92
93
for x1 in np .arange (0.02 , 0.15 , 0.04 ):
93
- wrf .append ((x1 - y_min1 ) / (x_maxlabels [0 ] - y_min1 ) * x_max [0 ])
94
+ wrf .append ((x1 - y_min1 ) / (x_max_labels [0 ] - y_min1 ) * x_max [0 ])
94
95
for x2 in np .arange (0.2 , 1 , 0.2 ):
95
- mrf .append ((x2 - y_min2 ) / (x_maxlabels [1 ] - y_min2 ) * x_max [1 ])
96
+ mrf .append ((x2 - y_min2 ) / (x_max_labels [1 ] - y_min2 ) * x_max [1 ])
96
97
for x3 in np .arange (0.1 , 0.9 , 0.2 ):
97
- crf .append ((x3 - y_min3 ) / (x_maxlabels [2 ] - y_min3 ) * x_max [2 ])
98
+ crf .append ((x3 - y_min3 ) / (x_max_labels [2 ] - y_min3 ) * x_max [2 ])
98
99
for x4 in np .arange (0.1 , 0.8 , 0.2 ):
99
- cmf .append ((x4 - y_min4 ) / (x_maxlabels [3 ] - y_min4 ) * x_max [3 ])
100
+ cmf .append ((x4 - y_min4 ) / (x_max_labels [3 ] - y_min4 ) * x_max [3 ])
100
101
for x5 in np .arange (200 , 2000 , 400 ):
101
- pcmb .append ((x5 - y_min5 ) / (x_maxlabels [4 ] - y_min5 ) * x_max [4 ])
102
+ pcmb .append ((x5 - y_min5 ) / (x_max_labels [4 ] - y_min5 ) * x_max [4 ])
102
103
for x6 in np .arange (2000 , 6000 , 1000 ):
103
- tcmb .append ((x6 - y_min6 ) / (x_maxlabels [5 ] - y_min6 ) * x_max [5 ])
104
+ tcmb .append ((x6 - y_min6 ) / (x_max_labels [5 ] - y_min6 ) * x_max [5 ])
104
105
105
106
xticklabels = [[round (x , 2 ) for x in np .arange (0.02 , 0.15 , 0.04 )],
106
107
[round (x , 2 ) for x in np .arange (0.2 , 1 , 0.2 )],
@@ -110,7 +111,6 @@ def plot(self, pred, save=False, filename="pred.png"):
110
111
[round (x , 2 ) for x in np .arange (2000 , 6000 , 1000 )],
111
112
]
112
113
xticks = [wrf , mrf , crf , cmf , pcmb , tcmb ]
113
-
114
114
colors = [
115
115
"steelblue" ,
116
116
"#EA7D1A" ,
@@ -142,22 +142,17 @@ def plot(self, pred, save=False, filename="pred.png"):
142
142
y_label ,
143
143
GMM_PDF ,
144
144
color = colors [i ],
145
- # label=label,
146
145
lw = 2 ,
147
146
zorder = 10 ,
148
147
)
149
148
ax .set_xlim (0 , x_max [i ])
150
149
ax .set_ylim (bottom = 0 )
151
-
152
150
ax .xaxis .set_minor_locator (AutoMinorLocator ())
153
151
ax .yaxis .set_minor_locator (AutoMinorLocator ())
154
- # print(xticks[i])
155
- # xticklabels[i]
156
152
ax .set_xticks (xticks [i ])
157
153
ax .set_xticklabels (xticklabels [i ])
158
154
ax .set_xlabel (predict_label [i ])
159
155
ax .set_ylabel ("Probability density" )
160
-
161
156
if save :
162
157
return plt .savefig (filename , dpi = 300 )
163
158
else :
0 commit comments