|
1 | 1 | from tensorflow.keras.models import load_model
|
| 2 | +import matplotlib.pyplot as plt |
| 3 | +from matplotlib.ticker import (MultipleLocator, AutoMinorLocator) |
2 | 4 | import numpy as np
|
3 |
| - |
| 5 | +import math |
4 | 6 | import mdn
|
5 | 7 | import joblib
|
| 8 | +import os |
6 | 9 |
|
7 | 10 | OUTPUT_DIMS = 6 # 6 outputs:
|
8 | 11 | # 'H2O_radial_frac', 'Mantle_radial_frac', 'Core_radial_frac', 'Core_mass_frac', 'P_CMB', 'T_CMB',
|
9 | 12 | N_MIXES = 20 # 20 mixtures
|
10 | 13 |
|
| 14 | +# print(os.getcwd()) |
| 15 | +model_path = os.path.join(os.getcwd(), "deepexo/model") |
| 16 | + |
11 | 17 |
|
12 | 18 | class RockyPlanet:
|
13 | 19 | def __init__(self):
|
14 |
| - self.model_a = load_model(r"model/model_a.h5", custom_objects={ |
| 20 | + # print('init RockyPlanet') |
| 21 | + self.model_a = load_model(os.path.join(model_path, "model_a.h5"), custom_objects={ |
| 22 | + 'MDN': mdn.MDN(OUTPUT_DIMS, N_MIXES), |
| 23 | + "mdn_loss_func": mdn.get_mixture_loss_func(OUTPUT_DIMS, N_MIXES)}, compile=False) |
| 24 | + self.model_a_scaler = joblib.load(os.path.join(model_path, "model_a_scaler.save")) |
| 25 | + self.model_b = load_model(os.path.join(model_path, "model_b.h5"), custom_objects={ |
15 | 26 | 'MDN': mdn.MDN(OUTPUT_DIMS, N_MIXES),
|
16 | 27 | "mdn_loss_func": mdn.get_mixture_loss_func(OUTPUT_DIMS, N_MIXES)}, compile=False)
|
17 |
| - self.model_a_scaler = joblib.load(r"model/model_a_scaler.save") |
18 |
| - self.model_b = load_model(r"model/model_b.h5", custom_objects={ |
19 |
| - 'MDN': mdn.MDN(OUTPUT_DIMS, N_MIXES), "mdn_loss_func": mdn.get_mixture_loss_func(OUTPUT_DIMS, N_MIXES)}) |
20 |
| - self.model_b_scaler = joblib.load(r"model/model_b_scaler.save") |
| 28 | + self.model_b_scaler = joblib.load(os.path.join(model_path, "model_b_scaler.save")) |
21 | 29 |
|
22 | 30 | def predict(self, planet_params):
|
| 31 | + """Predictes the . |
23 | 32 | if len(planet_params) == 3:
|
24 | 33 | print("正在调用inputs={}的model A进行预测".format(planet_params))
|
25 | 34 | scaled_params = self.model_a_scaler.transform(np.array(planet_params).reshape(1, -1))
|
26 |
| - return self.model_a.predict(scaled_params)[0] |
| 35 | + return self.model_a.predict(scaled_params) |
27 | 36 | elif len(planet_params) == 4:
|
28 | 37 | print("正在调用inputs={}的model B进行预测".format(planet_params))
|
29 |
| - scaled_params = seyhlf.model_b_scaler.transform(np.array(planet_params).reshape(1, -1)) |
30 |
| - return self.model_b.predict(scaled_params)[0] |
| 38 | + scaled_params = self.model_b_scaler.transform(np.array(planet_params).reshape(1, -1)) |
| 39 | + return self.model_b.predict(scaled_params) |
31 | 40 | else:
|
32 | 41 | raise ValueError(
|
33 | 42 | "Invalid number of planet parameters. Expected 3 or 4, but got {}".format(len(planet_params)))
|
| 43 | + |
| 44 | + def plot(self, pred): |
| 45 | + |
| 46 | + (y_min1, y_max1, y_min2, y_max2, y_min3, y_max3, |
| 47 | + y_min4, y_max4, y_min5, y_max5, y_min6, y_max6) = \ |
| 48 | + 0.00015137, 0.145835, 0.127618, 0.973427, 0.00787023, 0.799449, 1.17976e-06, 0.699986, 10.7182, 1999.49, 1689.37, 5673.87 |
| 49 | + mus = np.apply_along_axis((lambda a: a[:N_MIXES * OUTPUT_DIMS]), 1, pred) |
| 50 | + sigs = np.apply_along_axis((lambda a: a[N_MIXES * OUTPUT_DIMS:2 * N_MIXES * OUTPUT_DIMS]), 1, pred) |
| 51 | + pis = np.apply_along_axis((lambda a: mdn.softmax(a[-N_MIXES:])), 1, pred) |
| 52 | + |
| 53 | + for m in range(OUTPUT_DIMS): |
| 54 | + locals()['mus' + str(m)] = [] |
| 55 | + locals()['sigs' + str(m)] = [] |
| 56 | + for n in range(20): |
| 57 | + locals()['mus' + str(m)].append(mus[0][n * OUTPUT_DIMS + m]) |
| 58 | + locals()['sigs' + str(m)].append(sigs[0][n * OUTPUT_DIMS + m]) |
| 59 | + x_max = [1, 1, 1, 1, 1, 1] |
| 60 | + x_maxlabels = [ |
| 61 | + y_min1 + (y_max1 - y_min1) * x_max[0], |
| 62 | + y_min2 + (y_max2 - y_min2) * x_max[1], |
| 63 | + y_min3 + (y_max3 - y_min3) * x_max[2], |
| 64 | + y_min4 + (y_max4 - y_min4) * x_max[3], |
| 65 | + y_min5 + (y_max5 - y_min5) * x_max[4], |
| 66 | + y_min6 + (y_max6 - y_min6) * x_max[5], |
| 67 | + ] |
| 68 | + |
| 69 | + wrf = [] # H2O_radial_frac |
| 70 | + mrf = [] # Mantle_radial_frac |
| 71 | + crf = [] # Core_radial_frac |
| 72 | + cmf = [] # Core mass frac |
| 73 | + pcmb = [] # CMB temperature |
| 74 | + tcmb = [] # CMB pressure |
| 75 | + |
| 76 | + for x1 in np.arange(0.02, 0.15, 0.04): |
| 77 | + wrf.append((x1 - y_min1) / (x_maxlabels[0] - y_min1) * x_max[0]) |
| 78 | + for x2 in np.arange(0.2, 1, 0.2): |
| 79 | + mrf.append((x2 - y_min2) / (x_maxlabels[1] - y_min2) * x_max[1]) |
| 80 | + for x3 in np.arange(0.1, 0.9, 0.2): |
| 81 | + crf.append((x3 - y_min3) / (x_maxlabels[2] - y_min3) * x_max[2]) |
| 82 | + for x4 in np.arange(0.1, 0.8, 0.2): |
| 83 | + cmf.append((x4 - y_min4) / (x_maxlabels[3] - y_min4) * x_max[3]) |
| 84 | + for x5 in np.arange(200, 2000, 400): |
| 85 | + pcmb.append((x5 - y_min5) / (x_maxlabels[4] - y_min5) * x_max[4]) |
| 86 | + for x6 in np.arange(2000, 6000, 1000): |
| 87 | + tcmb.append((x6 - y_min6) / (x_maxlabels[5] - y_min6) * x_max[5]) |
| 88 | + |
| 89 | + xticklabels = [[round(x, 2) for x in np.arange(0.02, 0.15, 0.04)], |
| 90 | + [round(x, 2) for x in np.arange(0.2, 1, 0.2)], |
| 91 | + [round(x, 2) for x in np.arange(0.1, 0.9, 0.2)], |
| 92 | + [round(x, 2) for x in np.arange(0.1, 0.8, 0.2)], |
| 93 | + [round(x, 2) for x in np.arange(200, 2000, 400)], |
| 94 | + [round(x, 2) for x in np.arange(2000, 6000, 1000)], |
| 95 | + ] |
| 96 | + xticks = [wrf, mrf, crf, cmf, pcmb, tcmb] |
| 97 | + |
| 98 | + colors = [ |
| 99 | + "steelblue", |
| 100 | + "#EA7D1A", |
| 101 | + "red", |
| 102 | + "gold", |
| 103 | + "#2ecc71", |
| 104 | + "#03a9f4" |
| 105 | + ] |
| 106 | + predict_label = [ |
| 107 | + "Water radial fraction", |
| 108 | + "Mantle radial fraction", |
| 109 | + "Core radial fraction", |
| 110 | + "Core mass fraction", |
| 111 | + r"CMB pressure ($10^2$ GPa)", |
| 112 | + r"CMB temperature ($10^3$ K)" |
| 113 | + ] |
| 114 | + |
| 115 | + y_label = np.arange(0, 1, 0.001).reshape(-1, 1) |
| 116 | + fig = plt.figure(figsize=(6, 6)) |
| 117 | + fig.subplots_adjust(hspace=0.7, wspace=0.2) |
| 118 | + for i in range(OUTPUT_DIMS): |
| 119 | + ax = fig.add_subplot(3, 2, i + 1) |
| 120 | + mus_ = np.array(locals()['mus' + str(i)]) |
| 121 | + sigs_ = np.array(locals()['sigs' + str(i)]) |
| 122 | + factors = 1 / math.sqrt(2 * math.pi) / sigs_ |
| 123 | + exponent = np.exp(-1 / 2 * np.square((y_label - mus_) / sigs_)) |
| 124 | + GMM_PDF = np.sum(pis[0] * factors * exponent, axis=1) # Summing multiple Gaussian distributions |
| 125 | + plt.plot( |
| 126 | + y_label, |
| 127 | + GMM_PDF, |
| 128 | + color=colors[i], |
| 129 | + # label=label, |
| 130 | + lw=2, |
| 131 | + zorder=10, |
| 132 | + ) |
| 133 | + ax.set_xlim(0, x_max[i]) |
| 134 | + ax.set_ylim(bottom=0) |
| 135 | + |
| 136 | + ax.xaxis.set_minor_locator(AutoMinorLocator()) |
| 137 | + ax.yaxis.set_minor_locator(AutoMinorLocator()) |
| 138 | + # print(xticks[i]) |
| 139 | + # xticklabels[i] |
| 140 | + ax.set_xticks(xticks[i]) |
| 141 | + ax.set_xticklabels(xticklabels[i]) |
| 142 | + ax.set_xlabel(predict_label[i]) |
| 143 | + |
| 144 | + return plt.show() |
0 commit comments