Skip to content

Commit d6dc5b2

Browse files
committed
Changes
1 parent 867ef33 commit d6dc5b2

File tree

7 files changed

+252
-37
lines changed

7 files changed

+252
-37
lines changed

MDN_two_planets_prediction.ipynb

Lines changed: 113 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -53,24 +53,7 @@
5353
},
5454
"id": "OHSSb5g_iWsK"
5555
},
56-
"outputs": [
57-
{
58-
"ename": "ModuleNotFoundError",
59-
"evalue": "No module named 'sklearn'",
60-
"output_type": "error",
61-
"traceback": [
62-
"\u001B[1;31m---------------------------------------------------------------------------\u001B[0m",
63-
"\u001B[1;31mModuleNotFoundError\u001B[0m Traceback (most recent call last)",
64-
"\u001B[1;32m~\\AppData\\Local\\Temp\\ipykernel_19692\\2450612377.py\u001B[0m in \u001B[0;36m<module>\u001B[1;34m\u001B[0m\n\u001B[0;32m 22\u001B[0m \u001B[0mcompile\u001B[0m\u001B[1;33m=\u001B[0m\u001B[1;32mFalse\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 23\u001B[0m )\n\u001B[1;32m---> 24\u001B[1;33m \u001B[0minput_scaler\u001B[0m \u001B[1;33m=\u001B[0m \u001B[0mjoblib\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mload\u001B[0m\u001B[1;33m(\u001B[0m\u001B[1;34m\"MDN_MRFe(Mg+Si)k2_Xscaler_20220917.save\"\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0m\u001B[0;32m 25\u001B[0m \u001B[0moutput_scaler\u001B[0m \u001B[1;33m=\u001B[0m \u001B[0mjoblib\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mload\u001B[0m\u001B[1;33m(\u001B[0m\u001B[1;34m\"MDN_MRFe(Mg+Si)k2_yscaler_20220917.save\"\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n",
65-
"\u001B[1;32m~\\anaconda3\\envs\\deepexo\\lib\\site-packages\\joblib\\numpy_pickle.py\u001B[0m in \u001B[0;36mload\u001B[1;34m(filename, mmap_mode)\u001B[0m\n\u001B[0;32m 603\u001B[0m \u001B[1;32mreturn\u001B[0m \u001B[0mload_compatibility\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mfobj\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 604\u001B[0m \u001B[1;33m\u001B[0m\u001B[0m\n\u001B[1;32m--> 605\u001B[1;33m \u001B[0mobj\u001B[0m \u001B[1;33m=\u001B[0m \u001B[0m_unpickle\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mfobj\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0mfilename\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0mmmap_mode\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0m\u001B[0;32m 606\u001B[0m \u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 607\u001B[0m \u001B[1;32mreturn\u001B[0m \u001B[0mobj\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n",
66-
"\u001B[1;32m~\\anaconda3\\envs\\deepexo\\lib\\site-packages\\joblib\\numpy_pickle.py\u001B[0m in \u001B[0;36m_unpickle\u001B[1;34m(fobj, filename, mmap_mode)\u001B[0m\n\u001B[0;32m 527\u001B[0m \u001B[0mobj\u001B[0m \u001B[1;33m=\u001B[0m \u001B[1;32mNone\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 528\u001B[0m \u001B[1;32mtry\u001B[0m\u001B[1;33m:\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[1;32m--> 529\u001B[1;33m \u001B[0mobj\u001B[0m \u001B[1;33m=\u001B[0m \u001B[0munpickler\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mload\u001B[0m\u001B[1;33m(\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0m\u001B[0;32m 530\u001B[0m \u001B[1;32mif\u001B[0m \u001B[0munpickler\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mcompat_mode\u001B[0m\u001B[1;33m:\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 531\u001B[0m warnings.warn(\"The file '%s' has been generated with a \"\n",
67-
"\u001B[1;32m~\\anaconda3\\envs\\deepexo\\lib\\pickle.py\u001B[0m in \u001B[0;36mload\u001B[1;34m(self)\u001B[0m\n\u001B[0;32m 1086\u001B[0m \u001B[1;32mraise\u001B[0m \u001B[0mEOFError\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 1087\u001B[0m \u001B[1;32massert\u001B[0m \u001B[0misinstance\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mkey\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0mbytes_types\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[1;32m-> 1088\u001B[1;33m \u001B[0mdispatch\u001B[0m\u001B[1;33m[\u001B[0m\u001B[0mkey\u001B[0m\u001B[1;33m[\u001B[0m\u001B[1;36m0\u001B[0m\u001B[1;33m]\u001B[0m\u001B[1;33m]\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mself\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0m\u001B[0;32m 1089\u001B[0m \u001B[1;32mexcept\u001B[0m \u001B[0m_Stop\u001B[0m \u001B[1;32mas\u001B[0m \u001B[0mstopinst\u001B[0m\u001B[1;33m:\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 1090\u001B[0m \u001B[1;32mreturn\u001B[0m \u001B[0mstopinst\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mvalue\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n",
68-
"\u001B[1;32m~\\anaconda3\\envs\\deepexo\\lib\\pickle.py\u001B[0m in \u001B[0;36mload_global\u001B[1;34m(self)\u001B[0m\n\u001B[0;32m 1374\u001B[0m \u001B[0mmodule\u001B[0m \u001B[1;33m=\u001B[0m \u001B[0mself\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mreadline\u001B[0m\u001B[1;33m(\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m[\u001B[0m\u001B[1;33m:\u001B[0m\u001B[1;33m-\u001B[0m\u001B[1;36m1\u001B[0m\u001B[1;33m]\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mdecode\u001B[0m\u001B[1;33m(\u001B[0m\u001B[1;34m\"utf-8\"\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 1375\u001B[0m \u001B[0mname\u001B[0m \u001B[1;33m=\u001B[0m \u001B[0mself\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mreadline\u001B[0m\u001B[1;33m(\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m[\u001B[0m\u001B[1;33m:\u001B[0m\u001B[1;33m-\u001B[0m\u001B[1;36m1\u001B[0m\u001B[1;33m]\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mdecode\u001B[0m\u001B[1;33m(\u001B[0m\u001B[1;34m\"utf-8\"\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[1;32m-> 1376\u001B[1;33m \u001B[0mklass\u001B[0m \u001B[1;33m=\u001B[0m \u001B[0mself\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mfind_class\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mmodule\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0mname\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0m\u001B[0;32m 1377\u001B[0m \u001B[0mself\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mappend\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mklass\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 1378\u001B[0m \u001B[0mdispatch\u001B[0m\u001B[1;33m[\u001B[0m\u001B[0mGLOBAL\u001B[0m\u001B[1;33m[\u001B[0m\u001B[1;36m0\u001B[0m\u001B[1;33m]\u001B[0m\u001B[1;33m]\u001B[0m \u001B[1;33m=\u001B[0m \u001B[0mload_global\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n",
69-
"\u001B[1;32m~\\anaconda3\\envs\\deepexo\\lib\\pickle.py\u001B[0m in \u001B[0;36mfind_class\u001B[1;34m(self, module, name)\u001B[0m\n\u001B[0;32m 1424\u001B[0m \u001B[1;32melif\u001B[0m \u001B[0mmodule\u001B[0m \u001B[1;32min\u001B[0m \u001B[0m_compat_pickle\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mIMPORT_MAPPING\u001B[0m\u001B[1;33m:\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 1425\u001B[0m \u001B[0mmodule\u001B[0m \u001B[1;33m=\u001B[0m \u001B[0m_compat_pickle\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mIMPORT_MAPPING\u001B[0m\u001B[1;33m[\u001B[0m\u001B[0mmodule\u001B[0m\u001B[1;33m]\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[1;32m-> 1426\u001B[1;33m \u001B[0m__import__\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mmodule\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0mlevel\u001B[0m\u001B[1;33m=\u001B[0m\u001B[1;36m0\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0m\u001B[0;32m 1427\u001B[0m \u001B[1;32mif\u001B[0m \u001B[0mself\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mproto\u001B[0m \u001B[1;33m>=\u001B[0m \u001B[1;36m4\u001B[0m\u001B[1;33m:\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 1428\u001B[0m \u001B[1;32mreturn\u001B[0m \u001B[0m_getattribute\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0msys\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mmodules\u001B[0m\u001B[1;33m[\u001B[0m\u001B[0mmodule\u001B[0m\u001B[1;33m]\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0mname\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m[\u001B[0m\u001B[1;36m0\u001B[0m\u001B[1;33m]\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n",
70-
"\u001B[1;31mModuleNotFoundError\u001B[0m: No module named 'sklearn'"
71-
]
72-
}
73-
],
56+
"outputs": [],
7457
"source": [
7558
"input_parameters = [\n",
7659
" 'Mass', \n",
@@ -141,7 +124,27 @@
141124
},
142125
{
143126
"cell_type": "code",
144-
"execution_count": 4,
127+
"execution_count": 8,
128+
"metadata": {},
129+
"outputs": [
130+
{
131+
"data": {
132+
"text/plain": [
133+
"(1, 4)"
134+
]
135+
},
136+
"execution_count": 8,
137+
"metadata": {},
138+
"output_type": "execute_result"
139+
}
140+
],
141+
"source": [
142+
"scaled_input.shape"
143+
]
144+
},
145+
{
146+
"cell_type": "code",
147+
"execution_count": 5,
145148
"metadata": {
146149
"ExecuteTime": {
147150
"end_time": "2022-10-17T02:52:10.770472Z",
@@ -173,6 +176,96 @@
173176
"# print(pred.shape)"
174177
]
175178
},
179+
{
180+
"cell_type": "code",
181+
"execution_count": 9,
182+
"metadata": {},
183+
"outputs": [
184+
{
185+
"data": {
186+
"text/plain": [
187+
"array([[1.77 ],\n",
188+
" [1.228],\n",
189+
" [0.685],\n",
190+
" [0.819]])"
191+
]
192+
},
193+
"execution_count": 9,
194+
"metadata": {},
195+
"output_type": "execute_result"
196+
}
197+
],
198+
"source": [
199+
"input_array"
200+
]
201+
},
202+
{
203+
"cell_type": "code",
204+
"execution_count": 10,
205+
"metadata": {},
206+
"outputs": [],
207+
"source": [
208+
"a = [2,3,4,5]\n",
209+
"b = np.array([a]).T"
210+
]
211+
},
212+
{
213+
"cell_type": "code",
214+
"execution_count": 11,
215+
"metadata": {},
216+
"outputs": [
217+
{
218+
"data": {
219+
"text/plain": [
220+
"array([[2],\n",
221+
" [3],\n",
222+
" [4],\n",
223+
" [5]])"
224+
]
225+
},
226+
"execution_count": 11,
227+
"metadata": {},
228+
"output_type": "execute_result"
229+
}
230+
],
231+
"source": [
232+
"b"
233+
]
234+
},
235+
{
236+
"cell_type": "code",
237+
"execution_count": 7,
238+
"metadata": {},
239+
"outputs": [
240+
{
241+
"data": {
242+
"text/plain": [
243+
"(1, 260)"
244+
]
245+
},
246+
"execution_count": 7,
247+
"metadata": {},
248+
"output_type": "execute_result"
249+
}
250+
],
251+
"source": [
252+
"pred.shape\n"
253+
]
254+
},
255+
{
256+
"cell_type": "code",
257+
"execution_count": null,
258+
"metadata": {},
259+
"outputs": [],
260+
"source": []
261+
},
262+
{
263+
"cell_type": "code",
264+
"execution_count": null,
265+
"metadata": {},
266+
"outputs": [],
267+
"source": []
268+
},
176269
{
177270
"cell_type": "code",
178271
"execution_count": 5,
@@ -831,7 +924,7 @@
831924
},
832925
"hide_input": false,
833926
"kernelspec": {
834-
"display_name": "Python 3 (ipykernel)",
927+
"display_name": "Python 3",
835928
"language": "python",
836929
"name": "python3"
837930
},

README.md

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@ We provide two machine learning models for uses: **Model A** trained on `[M, R,
1212

1313
**Model A** has a better predictive accuracy, but its application is limited by some difficulties in measuring the tidal Love number `k2` of rocky exoplanets. **Model B** significantly breaks the density-composition degeneracy and accurately predicts the interior properties of rocky exoplanets. Along with the development of space-based observation technologies, orbital or shape observations could be possible to determine the Love number `k2` of rocky exoplanets and hence the machine learning models B and C would be applied more broadly.
1414

15-
1615
## Quick Start
1716
### Step 1:
1817
[Fork and clone](https://help.github.com/articles/fork-a-repo) a copy of the `Rocky_Exoplanets_v2` repository to your local machine.

deepexo/model/model_b.h5

-6 KB
Binary file not shown.

deepexo/model/model_b_scaler.save

0 Bytes
Binary file not shown.

deepexo/rockyplanet.py

Lines changed: 120 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,33 +1,144 @@
11
from tensorflow.keras.models import load_model
2+
import matplotlib.pyplot as plt
3+
from matplotlib.ticker import (MultipleLocator, AutoMinorLocator)
24
import numpy as np
3-
5+
import math
46
import mdn
57
import joblib
8+
import os
69

710
OUTPUT_DIMS = 6 # 6 outputs:
811
# 'H2O_radial_frac', 'Mantle_radial_frac', 'Core_radial_frac', 'Core_mass_frac', 'P_CMB', 'T_CMB',
912
N_MIXES = 20 # 20 mixtures
1013

14+
# print(os.getcwd())
15+
model_path = os.path.join(os.getcwd(), "deepexo/model")
16+
1117

1218
class RockyPlanet:
1319
def __init__(self):
14-
self.model_a = load_model(r"model/model_a.h5", custom_objects={
20+
# print('init RockyPlanet')
21+
self.model_a = load_model(os.path.join(model_path, "model_a.h5"), custom_objects={
22+
'MDN': mdn.MDN(OUTPUT_DIMS, N_MIXES),
23+
"mdn_loss_func": mdn.get_mixture_loss_func(OUTPUT_DIMS, N_MIXES)}, compile=False)
24+
self.model_a_scaler = joblib.load(os.path.join(model_path, "model_a_scaler.save"))
25+
self.model_b = load_model(os.path.join(model_path, "model_b.h5"), custom_objects={
1526
'MDN': mdn.MDN(OUTPUT_DIMS, N_MIXES),
1627
"mdn_loss_func": mdn.get_mixture_loss_func(OUTPUT_DIMS, N_MIXES)}, compile=False)
17-
self.model_a_scaler = joblib.load(r"model/model_a_scaler.save")
18-
self.model_b = load_model(r"model/model_b.h5", custom_objects={
19-
'MDN': mdn.MDN(OUTPUT_DIMS, N_MIXES), "mdn_loss_func": mdn.get_mixture_loss_func(OUTPUT_DIMS, N_MIXES)})
20-
self.model_b_scaler = joblib.load(r"model/model_b_scaler.save")
28+
self.model_b_scaler = joblib.load(os.path.join(model_path, "model_b_scaler.save"))
2129

2230
def predict(self, planet_params):
31+
"""Predictes the .
2332
if len(planet_params) == 3:
2433
print("正在调用inputs={}的model A进行预测".format(planet_params))
2534
scaled_params = self.model_a_scaler.transform(np.array(planet_params).reshape(1, -1))
26-
return self.model_a.predict(scaled_params)[0]
35+
return self.model_a.predict(scaled_params)
2736
elif len(planet_params) == 4:
2837
print("正在调用inputs={}的model B进行预测".format(planet_params))
29-
scaled_params = seyhlf.model_b_scaler.transform(np.array(planet_params).reshape(1, -1))
30-
return self.model_b.predict(scaled_params)[0]
38+
scaled_params = self.model_b_scaler.transform(np.array(planet_params).reshape(1, -1))
39+
return self.model_b.predict(scaled_params)
3140
else:
3241
raise ValueError(
3342
"Invalid number of planet parameters. Expected 3 or 4, but got {}".format(len(planet_params)))
43+
44+
def plot(self, pred):
45+
46+
(y_min1, y_max1, y_min2, y_max2, y_min3, y_max3,
47+
y_min4, y_max4, y_min5, y_max5, y_min6, y_max6) = \
48+
0.00015137, 0.145835, 0.127618, 0.973427, 0.00787023, 0.799449, 1.17976e-06, 0.699986, 10.7182, 1999.49, 1689.37, 5673.87
49+
mus = np.apply_along_axis((lambda a: a[:N_MIXES * OUTPUT_DIMS]), 1, pred)
50+
sigs = np.apply_along_axis((lambda a: a[N_MIXES * OUTPUT_DIMS:2 * N_MIXES * OUTPUT_DIMS]), 1, pred)
51+
pis = np.apply_along_axis((lambda a: mdn.softmax(a[-N_MIXES:])), 1, pred)
52+
53+
for m in range(OUTPUT_DIMS):
54+
locals()['mus' + str(m)] = []
55+
locals()['sigs' + str(m)] = []
56+
for n in range(20):
57+
locals()['mus' + str(m)].append(mus[0][n * OUTPUT_DIMS + m])
58+
locals()['sigs' + str(m)].append(sigs[0][n * OUTPUT_DIMS + m])
59+
x_max = [1, 1, 1, 1, 1, 1]
60+
x_maxlabels = [
61+
y_min1 + (y_max1 - y_min1) * x_max[0],
62+
y_min2 + (y_max2 - y_min2) * x_max[1],
63+
y_min3 + (y_max3 - y_min3) * x_max[2],
64+
y_min4 + (y_max4 - y_min4) * x_max[3],
65+
y_min5 + (y_max5 - y_min5) * x_max[4],
66+
y_min6 + (y_max6 - y_min6) * x_max[5],
67+
]
68+
69+
wrf = [] # H2O_radial_frac
70+
mrf = [] # Mantle_radial_frac
71+
crf = [] # Core_radial_frac
72+
cmf = [] # Core mass frac
73+
pcmb = [] # CMB temperature
74+
tcmb = [] # CMB pressure
75+
76+
for x1 in np.arange(0.02, 0.15, 0.04):
77+
wrf.append((x1 - y_min1) / (x_maxlabels[0] - y_min1) * x_max[0])
78+
for x2 in np.arange(0.2, 1, 0.2):
79+
mrf.append((x2 - y_min2) / (x_maxlabels[1] - y_min2) * x_max[1])
80+
for x3 in np.arange(0.1, 0.9, 0.2):
81+
crf.append((x3 - y_min3) / (x_maxlabels[2] - y_min3) * x_max[2])
82+
for x4 in np.arange(0.1, 0.8, 0.2):
83+
cmf.append((x4 - y_min4) / (x_maxlabels[3] - y_min4) * x_max[3])
84+
for x5 in np.arange(200, 2000, 400):
85+
pcmb.append((x5 - y_min5) / (x_maxlabels[4] - y_min5) * x_max[4])
86+
for x6 in np.arange(2000, 6000, 1000):
87+
tcmb.append((x6 - y_min6) / (x_maxlabels[5] - y_min6) * x_max[5])
88+
89+
xticklabels = [[round(x, 2) for x in np.arange(0.02, 0.15, 0.04)],
90+
[round(x, 2) for x in np.arange(0.2, 1, 0.2)],
91+
[round(x, 2) for x in np.arange(0.1, 0.9, 0.2)],
92+
[round(x, 2) for x in np.arange(0.1, 0.8, 0.2)],
93+
[round(x, 2) for x in np.arange(200, 2000, 400)],
94+
[round(x, 2) for x in np.arange(2000, 6000, 1000)],
95+
]
96+
xticks = [wrf, mrf, crf, cmf, pcmb, tcmb]
97+
98+
colors = [
99+
"steelblue",
100+
"#EA7D1A",
101+
"red",
102+
"gold",
103+
"#2ecc71",
104+
"#03a9f4"
105+
]
106+
predict_label = [
107+
"Water radial fraction",
108+
"Mantle radial fraction",
109+
"Core radial fraction",
110+
"Core mass fraction",
111+
r"CMB pressure ($10^2$ GPa)",
112+
r"CMB temperature ($10^3$ K)"
113+
]
114+
115+
y_label = np.arange(0, 1, 0.001).reshape(-1, 1)
116+
fig = plt.figure(figsize=(6, 6))
117+
fig.subplots_adjust(hspace=0.7, wspace=0.2)
118+
for i in range(OUTPUT_DIMS):
119+
ax = fig.add_subplot(3, 2, i + 1)
120+
mus_ = np.array(locals()['mus' + str(i)])
121+
sigs_ = np.array(locals()['sigs' + str(i)])
122+
factors = 1 / math.sqrt(2 * math.pi) / sigs_
123+
exponent = np.exp(-1 / 2 * np.square((y_label - mus_) / sigs_))
124+
GMM_PDF = np.sum(pis[0] * factors * exponent, axis=1) # Summing multiple Gaussian distributions
125+
plt.plot(
126+
y_label,
127+
GMM_PDF,
128+
color=colors[i],
129+
# label=label,
130+
lw=2,
131+
zorder=10,
132+
)
133+
ax.set_xlim(0, x_max[i])
134+
ax.set_ylim(bottom=0)
135+
136+
ax.xaxis.set_minor_locator(AutoMinorLocator())
137+
ax.yaxis.set_minor_locator(AutoMinorLocator())
138+
# print(xticks[i])
139+
# xticklabels[i]
140+
ax.set_xticks(xticks[i])
141+
ax.set_xticklabels(xticklabels[i])
142+
ax.set_xlabel(predict_label[i])
143+
144+
return plt.show()

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ hdf5==1.10.4
22
h5py==2.10.0
33
joblib==0.14.1
44
tensorflow==2.3.0
5-
scikit-learn==0.24.0
5+
scikit-learn==0.22.1
66
tensorflow_probability==0.11.0
77
tensorflow-estimator==2.3.0
88
matplotlib==3.1.3

run.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,22 @@
11
from deepexo.rockyplanet import RockyPlanet
22

3-
M = 1.77, # mass in Earth masses
4-
R = 1.228, # radius in Earth radii
5-
cFeMg = 0.685, # bulk Fe/(Mg + Si) (molar ratio)
6-
k2 = 0.819, # tide Love number
3+
M = 1.77 # mass in Earth masses
4+
R = 1.228 # radius in Earth radii
5+
cFeMg = 0.685 # bulk Fe/(Mg + Si) (molar ratio)
6+
k2 = 0.819 # tide Love number
77

8-
planet_params = [M, R, cFeMg, k2]
8+
planet_params = [
9+
M,
10+
R,
11+
cFeMg,
12+
k2,
13+
]
914
rp = RockyPlanet()
10-
pred = rp.predict(planet_params)
15+
pred = rp.predict(planet_params)
16+
rp.plot(pred)
17+
# model_a = rp.load_model("model/model_a.h5")
18+
# model_b = rp.load_model("model/model_b.h5")
19+
# model_a_scaler = rp.load_scaler("model/model_a_scaler.save")
20+
# model_b_scaler = rp.load_scaler("model/model_b_scaler.save")
21+
22+
# pred = rp.predict(planet_params, model_a, model_a_scaler)

0 commit comments

Comments
 (0)