Skip to content

Commit 867ef33

Browse files
committed
Changes
1 parent 4726527 commit 867ef33

File tree

9 files changed

+71
-25
lines changed

9 files changed

+71
-25
lines changed

DNN_pred_k2_20220924.h5

-235 KB
Binary file not shown.

MDN_two_planets_prediction.ipynb

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
"import matplotlib.pyplot as plt\n",
1616
"import matplotlib.ticker as tck\n",
1717
"from matplotlib.ticker import (MultipleLocator, AutoMinorLocator)\n",
18-
"import seaborn as sns\n",
18+
"# import seaborn as sns\n",
1919
"\n",
2020
"import pandas as pd\n",
2121
"import numpy as np\n",
@@ -26,7 +26,7 @@
2626
"import mdn\n",
2727
"import joblib\n",
2828
"\n",
29-
"from sklearn.model_selection import train_test_split\n",
29+
"# from sklearn.model_selection import train_test_split\n",
3030
"\n",
3131
"import matplotlib\n",
3232
"matplotlib.rcParams['svg.fonttype'] = 'none'\n",
@@ -53,7 +53,24 @@
5353
},
5454
"id": "OHSSb5g_iWsK"
5555
},
56-
"outputs": [],
56+
"outputs": [
57+
{
58+
"ename": "ModuleNotFoundError",
59+
"evalue": "No module named 'sklearn'",
60+
"output_type": "error",
61+
"traceback": [
62+
"\u001B[1;31m---------------------------------------------------------------------------\u001B[0m",
63+
"\u001B[1;31mModuleNotFoundError\u001B[0m Traceback (most recent call last)",
64+
"\u001B[1;32m~\\AppData\\Local\\Temp\\ipykernel_19692\\2450612377.py\u001B[0m in \u001B[0;36m<module>\u001B[1;34m\u001B[0m\n\u001B[0;32m 22\u001B[0m \u001B[0mcompile\u001B[0m\u001B[1;33m=\u001B[0m\u001B[1;32mFalse\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 23\u001B[0m )\n\u001B[1;32m---> 24\u001B[1;33m \u001B[0minput_scaler\u001B[0m \u001B[1;33m=\u001B[0m \u001B[0mjoblib\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mload\u001B[0m\u001B[1;33m(\u001B[0m\u001B[1;34m\"MDN_MRFe(Mg+Si)k2_Xscaler_20220917.save\"\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0m\u001B[0;32m 25\u001B[0m \u001B[0moutput_scaler\u001B[0m \u001B[1;33m=\u001B[0m \u001B[0mjoblib\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mload\u001B[0m\u001B[1;33m(\u001B[0m\u001B[1;34m\"MDN_MRFe(Mg+Si)k2_yscaler_20220917.save\"\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n",
65+
"\u001B[1;32m~\\anaconda3\\envs\\deepexo\\lib\\site-packages\\joblib\\numpy_pickle.py\u001B[0m in \u001B[0;36mload\u001B[1;34m(filename, mmap_mode)\u001B[0m\n\u001B[0;32m 603\u001B[0m \u001B[1;32mreturn\u001B[0m \u001B[0mload_compatibility\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mfobj\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 604\u001B[0m \u001B[1;33m\u001B[0m\u001B[0m\n\u001B[1;32m--> 605\u001B[1;33m \u001B[0mobj\u001B[0m \u001B[1;33m=\u001B[0m \u001B[0m_unpickle\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mfobj\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0mfilename\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0mmmap_mode\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0m\u001B[0;32m 606\u001B[0m \u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 607\u001B[0m \u001B[1;32mreturn\u001B[0m \u001B[0mobj\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n",
66+
"\u001B[1;32m~\\anaconda3\\envs\\deepexo\\lib\\site-packages\\joblib\\numpy_pickle.py\u001B[0m in \u001B[0;36m_unpickle\u001B[1;34m(fobj, filename, mmap_mode)\u001B[0m\n\u001B[0;32m 527\u001B[0m \u001B[0mobj\u001B[0m \u001B[1;33m=\u001B[0m \u001B[1;32mNone\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 528\u001B[0m \u001B[1;32mtry\u001B[0m\u001B[1;33m:\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[1;32m--> 529\u001B[1;33m \u001B[0mobj\u001B[0m \u001B[1;33m=\u001B[0m \u001B[0munpickler\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mload\u001B[0m\u001B[1;33m(\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0m\u001B[0;32m 530\u001B[0m \u001B[1;32mif\u001B[0m \u001B[0munpickler\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mcompat_mode\u001B[0m\u001B[1;33m:\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 531\u001B[0m warnings.warn(\"The file '%s' has been generated with a \"\n",
67+
"\u001B[1;32m~\\anaconda3\\envs\\deepexo\\lib\\pickle.py\u001B[0m in \u001B[0;36mload\u001B[1;34m(self)\u001B[0m\n\u001B[0;32m 1086\u001B[0m \u001B[1;32mraise\u001B[0m \u001B[0mEOFError\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 1087\u001B[0m \u001B[1;32massert\u001B[0m \u001B[0misinstance\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mkey\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0mbytes_types\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[1;32m-> 1088\u001B[1;33m \u001B[0mdispatch\u001B[0m\u001B[1;33m[\u001B[0m\u001B[0mkey\u001B[0m\u001B[1;33m[\u001B[0m\u001B[1;36m0\u001B[0m\u001B[1;33m]\u001B[0m\u001B[1;33m]\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mself\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0m\u001B[0;32m 1089\u001B[0m \u001B[1;32mexcept\u001B[0m \u001B[0m_Stop\u001B[0m \u001B[1;32mas\u001B[0m \u001B[0mstopinst\u001B[0m\u001B[1;33m:\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 1090\u001B[0m \u001B[1;32mreturn\u001B[0m \u001B[0mstopinst\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mvalue\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n",
68+
"\u001B[1;32m~\\anaconda3\\envs\\deepexo\\lib\\pickle.py\u001B[0m in \u001B[0;36mload_global\u001B[1;34m(self)\u001B[0m\n\u001B[0;32m 1374\u001B[0m \u001B[0mmodule\u001B[0m \u001B[1;33m=\u001B[0m \u001B[0mself\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mreadline\u001B[0m\u001B[1;33m(\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m[\u001B[0m\u001B[1;33m:\u001B[0m\u001B[1;33m-\u001B[0m\u001B[1;36m1\u001B[0m\u001B[1;33m]\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mdecode\u001B[0m\u001B[1;33m(\u001B[0m\u001B[1;34m\"utf-8\"\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 1375\u001B[0m \u001B[0mname\u001B[0m \u001B[1;33m=\u001B[0m \u001B[0mself\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mreadline\u001B[0m\u001B[1;33m(\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m[\u001B[0m\u001B[1;33m:\u001B[0m\u001B[1;33m-\u001B[0m\u001B[1;36m1\u001B[0m\u001B[1;33m]\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mdecode\u001B[0m\u001B[1;33m(\u001B[0m\u001B[1;34m\"utf-8\"\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[1;32m-> 1376\u001B[1;33m \u001B[0mklass\u001B[0m \u001B[1;33m=\u001B[0m \u001B[0mself\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mfind_class\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mmodule\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0mname\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0m\u001B[0;32m 1377\u001B[0m \u001B[0mself\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mappend\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mklass\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 1378\u001B[0m \u001B[0mdispatch\u001B[0m\u001B[1;33m[\u001B[0m\u001B[0mGLOBAL\u001B[0m\u001B[1;33m[\u001B[0m\u001B[1;36m0\u001B[0m\u001B[1;33m]\u001B[0m\u001B[1;33m]\u001B[0m \u001B[1;33m=\u001B[0m \u001B[0mload_global\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n",
69+
"\u001B[1;32m~\\anaconda3\\envs\\deepexo\\lib\\pickle.py\u001B[0m in \u001B[0;36mfind_class\u001B[1;34m(self, module, name)\u001B[0m\n\u001B[0;32m 1424\u001B[0m \u001B[1;32melif\u001B[0m \u001B[0mmodule\u001B[0m \u001B[1;32min\u001B[0m \u001B[0m_compat_pickle\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mIMPORT_MAPPING\u001B[0m\u001B[1;33m:\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 1425\u001B[0m \u001B[0mmodule\u001B[0m \u001B[1;33m=\u001B[0m \u001B[0m_compat_pickle\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mIMPORT_MAPPING\u001B[0m\u001B[1;33m[\u001B[0m\u001B[0mmodule\u001B[0m\u001B[1;33m]\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[1;32m-> 1426\u001B[1;33m \u001B[0m__import__\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mmodule\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0mlevel\u001B[0m\u001B[1;33m=\u001B[0m\u001B[1;36m0\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0m\u001B[0;32m 1427\u001B[0m \u001B[1;32mif\u001B[0m \u001B[0mself\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mproto\u001B[0m \u001B[1;33m>=\u001B[0m \u001B[1;36m4\u001B[0m\u001B[1;33m:\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 1428\u001B[0m \u001B[1;32mreturn\u001B[0m \u001B[0m_getattribute\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0msys\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mmodules\u001B[0m\u001B[1;33m[\u001B[0m\u001B[0mmodule\u001B[0m\u001B[1;33m]\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0mname\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m[\u001B[0m\u001B[1;36m0\u001B[0m\u001B[1;33m]\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n",
70+
"\u001B[1;31mModuleNotFoundError\u001B[0m: No module named 'sklearn'"
71+
]
72+
}
73+
],
5774
"source": [
5875
"input_parameters = [\n",
5976
" 'Mass', \n",
@@ -814,7 +831,7 @@
814831
},
815832
"hide_input": false,
816833
"kernelspec": {
817-
"display_name": "Python 3",
834+
"display_name": "Python 3 (ipykernel)",
818835
"language": "python",
819836
"name": "python3"
820837
},

README.md

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ We provide two machine learning models for uses: **Model A** trained on `[M, R,
2121
Download [`Anaconda`](https://www.anaconda.com/products/individual#Downloads) and install it on your machine.
2222
Create a `conda` environment called `Rocky_Exoplanets` and install all the necessary dependencies:
2323

24-
$ conda create -n Rocky_Exoplanets pip python=3.7.6 keras-mdn-layer jupyter
24+
$ conda create -n Rocky_Exoplanets pip python=3.7.6 jupyter
2525

2626
### Step 3:
2727
Activate the `Rocky_Exoplanets` environment:
@@ -46,12 +46,15 @@ Open `Jupyter Notebook` and load the file `MDN_two_planets_prediction.ipynb`:
4646
At this point you are ready to start investigating the interiors of rocky exoplanets!
4747

4848
## Usage
49+
4950
```python
50-
from deepexo import RockyPlanet
51+
from deepexo import rockyplanet
52+
5153
# Kepler-78b
52-
M = 1.77 # mass in Earth masses
53-
R = 1.228 # radius in Earth radii
54-
cFeMg = 0.685 # bulk Fe/(Mg + Si) (molar ratio)
54+
M = 1.77 # mass in Earth masses
55+
R = 1.228 # radius in Earth radii
56+
cFeMg = 0.685 # bulk Fe/(Mg + Si) (molar ratio)
57+
5558

5659
```
5760
## References

deepexo/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +0,0 @@
1-
from deepexo.rockyplanet import PlanetSynth

deepexo/model/model_b.h5

7.63 MB
Binary file not shown.
File renamed without changes.

deepexo/rockyplanet.py

Lines changed: 26 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,18 +3,31 @@
33

44
import mdn
55
import joblib
6-
from numpy.typing import ArrayLike
7-
class RockyPlanet:
8-
def predict(self, planet_params: ArrayLike):
9-
"""
10-
Args:
11-
Array of floats of the planet parameters
12-
in the following order: [M, R, k2].
13-
The following input ranges are supported:
14-
M: 0.1 < M [M_earth] < 30
15-
R: 0.1 < R [R_earth] < 30
16-
k2: 0 < k2 < 0.5
17-
Returns:
186

19-
"""
7+
OUTPUT_DIMS = 6 # 6 outputs:
8+
# 'H2O_radial_frac', 'Mantle_radial_frac', 'Core_radial_frac', 'Core_mass_frac', 'P_CMB', 'T_CMB',
9+
N_MIXES = 20 # 20 mixtures
10+
11+
12+
class RockyPlanet:
13+
def __init__(self):
14+
self.model_a = load_model(r"model/model_a.h5", custom_objects={
15+
'MDN': mdn.MDN(OUTPUT_DIMS, N_MIXES),
16+
"mdn_loss_func": mdn.get_mixture_loss_func(OUTPUT_DIMS, N_MIXES)}, compile=False)
17+
self.model_a_scaler = joblib.load(r"model/model_a_scaler.save")
18+
self.model_b = load_model(r"model/model_b.h5", custom_objects={
19+
'MDN': mdn.MDN(OUTPUT_DIMS, N_MIXES), "mdn_loss_func": mdn.get_mixture_loss_func(OUTPUT_DIMS, N_MIXES)})
20+
self.model_b_scaler = joblib.load(r"model/model_b_scaler.save")
2021

22+
def predict(self, planet_params):
23+
if len(planet_params) == 3:
24+
print("正在调用inputs={}的model A进行预测".format(planet_params))
25+
scaled_params = self.model_a_scaler.transform(np.array(planet_params).reshape(1, -1))
26+
return self.model_a.predict(scaled_params)[0]
27+
elif len(planet_params) == 4:
28+
print("正在调用inputs={}的model B进行预测".format(planet_params))
29+
scaled_params = seyhlf.model_b_scaler.transform(np.array(planet_params).reshape(1, -1))
30+
return self.model_b.predict(scaled_params)[0]
31+
else:
32+
raise ValueError(
33+
"Invalid number of planet parameters. Expected 3 or 4, but got {}".format(len(planet_params)))

requirements.txt

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
1+
hdf5==1.10.4
2+
h5py==2.10.0
13
joblib==0.14.1
4+
tensorflow==2.3.0
5+
scikit-learn==0.24.0
26
tensorflow_probability==0.11.0
7+
tensorflow-estimator==2.3.0
38
matplotlib==3.1.3
4-
tensorflow==2.3.0
59
numpy==1.18.0
610
pandas==1.2.3
7-
xlrd==1.2.0
11+
xlrd==1.2.0

run.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
from deepexo.rockyplanet import RockyPlanet
2+
3+
M = 1.77, # mass in Earth masses
4+
R = 1.228, # radius in Earth radii
5+
cFeMg = 0.685, # bulk Fe/(Mg + Si) (molar ratio)
6+
k2 = 0.819, # tide Love number
7+
8+
planet_params = [M, R, cFeMg, k2]
9+
rp = RockyPlanet()
10+
pred = rp.predict(planet_params)

0 commit comments

Comments
 (0)