Skip to content

Commit 9b97a98

Browse files
committed
add detailed comments
1 parent 2969ac7 commit 9b97a98

File tree

1 file changed

+13
-18
lines changed

1 file changed

+13
-18
lines changed

deepexo/rockyplanet.py

Lines changed: 13 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -16,25 +16,26 @@
1616

1717

1818
class RockyPlanet:
19+
"""A class for characterizing the interior structure of rocky exoplanets."""
1920
def __init__(self):
2021
# print('init RockyPlanet')
21-
self.model_a = load_model(os.path.join(model_path, "model_a1.h5"), custom_objects={
22+
self.model_a = load_model(os.path.join(model_path, "model_a.h5"), custom_objects={
2223
'MDN': mdn.MDN(OUTPUT_DIMS, N_MIXES),
2324
"mdn_loss_func": mdn.get_mixture_loss_func(OUTPUT_DIMS, N_MIXES)}, compile=False)
24-
self.model_a_scaler = joblib.load(os.path.join(model_path, "model_a1_scaler.save"))
25+
self.model_a_scaler = joblib.load(os.path.join(model_path, "model_a_scaler.save"))
2526
self.model_b = load_model(os.path.join(model_path, "model_b.h5"), custom_objects={
2627
'MDN': mdn.MDN(OUTPUT_DIMS, N_MIXES),
2728
"mdn_loss_func": mdn.get_mixture_loss_func(OUTPUT_DIMS, N_MIXES)}, compile=False)
2829
self.model_b_scaler = joblib.load(os.path.join(model_path, "model_b_scaler.save"))
2930

3031
def predict(self, planet_params):
31-
"""Predictes the Water radial fraction, Mantle radial fraction, Core radial fraction, Core mass fraction,
32+
"""Predicts the Water radial fraction, Mantle radial fraction, Core radial fraction, Core mass fraction,
3233
CMB pressure and CMB temperature for the given planetary parameters in terms of planetary mass M [M_Earth],
3334
radius [R_Earth], bulk Fe/(Mg + Si) (molar ratio), and tide Love number k2.
3435
3536
Args:
3637
planet_params (list): A list of planetary parameters in the order of [M, R, k2] or [M, R, cFeMg, k2].
37-
Retures:
38+
Returns:
3839
pred: contains parameters for distributions, not actual points on the graph.
3940
"""
4041
if len(planet_params) == 3:
@@ -57,7 +58,7 @@ def plot(self, pred, save=False, filename="pred.png"):
5758
save (bool, optional): Defaults to False. If True, saves the plot to a file.
5859
filename (str, optional): Defaults to "". The filename to save the plot to.
5960
Returns:
60-
None
61+
None or saves the plot to a file.
6162
"""
6263
(y_min1, y_max1, y_min2, y_max2, y_min3, y_max3,
6364
y_min4, y_max4, y_min5, y_max5, y_min6, y_max6) = \
@@ -73,7 +74,7 @@ def plot(self, pred, save=False, filename="pred.png"):
7374
locals()['mus' + str(m)].append(mus[0][n * OUTPUT_DIMS + m])
7475
locals()['sigs' + str(m)].append(sigs[0][n * OUTPUT_DIMS + m])
7576
x_max = [1, 1, 1, 1, 1, 1]
76-
x_maxlabels = [
77+
x_max_labels = [
7778
y_min1 + (y_max1 - y_min1) * x_max[0],
7879
y_min2 + (y_max2 - y_min2) * x_max[1],
7980
y_min3 + (y_max3 - y_min3) * x_max[2],
@@ -90,17 +91,17 @@ def plot(self, pred, save=False, filename="pred.png"):
9091
tcmb = [] # CMB pressure
9192

9293
for x1 in np.arange(0.02, 0.15, 0.04):
93-
wrf.append((x1 - y_min1) / (x_maxlabels[0] - y_min1) * x_max[0])
94+
wrf.append((x1 - y_min1) / (x_max_labels[0] - y_min1) * x_max[0])
9495
for x2 in np.arange(0.2, 1, 0.2):
95-
mrf.append((x2 - y_min2) / (x_maxlabels[1] - y_min2) * x_max[1])
96+
mrf.append((x2 - y_min2) / (x_max_labels[1] - y_min2) * x_max[1])
9697
for x3 in np.arange(0.1, 0.9, 0.2):
97-
crf.append((x3 - y_min3) / (x_maxlabels[2] - y_min3) * x_max[2])
98+
crf.append((x3 - y_min3) / (x_max_labels[2] - y_min3) * x_max[2])
9899
for x4 in np.arange(0.1, 0.8, 0.2):
99-
cmf.append((x4 - y_min4) / (x_maxlabels[3] - y_min4) * x_max[3])
100+
cmf.append((x4 - y_min4) / (x_max_labels[3] - y_min4) * x_max[3])
100101
for x5 in np.arange(200, 2000, 400):
101-
pcmb.append((x5 - y_min5) / (x_maxlabels[4] - y_min5) * x_max[4])
102+
pcmb.append((x5 - y_min5) / (x_max_labels[4] - y_min5) * x_max[4])
102103
for x6 in np.arange(2000, 6000, 1000):
103-
tcmb.append((x6 - y_min6) / (x_maxlabels[5] - y_min6) * x_max[5])
104+
tcmb.append((x6 - y_min6) / (x_max_labels[5] - y_min6) * x_max[5])
104105

105106
xticklabels = [[round(x, 2) for x in np.arange(0.02, 0.15, 0.04)],
106107
[round(x, 2) for x in np.arange(0.2, 1, 0.2)],
@@ -110,7 +111,6 @@ def plot(self, pred, save=False, filename="pred.png"):
110111
[round(x, 2) for x in np.arange(2000, 6000, 1000)],
111112
]
112113
xticks = [wrf, mrf, crf, cmf, pcmb, tcmb]
113-
114114
colors = [
115115
"steelblue",
116116
"#EA7D1A",
@@ -142,22 +142,17 @@ def plot(self, pred, save=False, filename="pred.png"):
142142
y_label,
143143
GMM_PDF,
144144
color=colors[i],
145-
# label=label,
146145
lw=2,
147146
zorder=10,
148147
)
149148
ax.set_xlim(0, x_max[i])
150149
ax.set_ylim(bottom=0)
151-
152150
ax.xaxis.set_minor_locator(AutoMinorLocator())
153151
ax.yaxis.set_minor_locator(AutoMinorLocator())
154-
# print(xticks[i])
155-
# xticklabels[i]
156152
ax.set_xticks(xticks[i])
157153
ax.set_xticklabels(xticklabels[i])
158154
ax.set_xlabel(predict_label[i])
159155
ax.set_ylabel("Probability density")
160-
161156
if save:
162157
return plt.savefig(filename, dpi=300)
163158
else:

0 commit comments

Comments
 (0)