Skip to content

Commit a90babf

Browse files
committed
ISEFlow publication commit
1 parent cd04ddc commit a90babf

File tree

8 files changed

+523
-64
lines changed

8 files changed

+523
-64
lines changed

README.md

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,16 @@
11
# Ice Sheet Emulator (ISE) for ISMIP6 Emulation of Sea Level Rise
22

3-
**Peter Van Katwyk $^{1}$, Karianne Bergen $^{1,2}$**
3+
**Peter Van Katwyk $^{1, 2, 3}$**
44

55
$^{1}$ Department of Earth, Environmental, and Planetary Sciences, Brown University
66
$^{2}$ Data Science Initiative, Brown University
7+
$^{3}$ Institute at Brown for Environmental and Society, Brown University
78

8-
This repository contains source code for processing climate forcings from [ISMIP6 Antarctice Ice Sheet (AIS) simulations](https://app.globus.org/file-manager?origin_id=ad1a6ed8-4de0-4490-93a9-8258931766c7&origin_path=%2FAIS%2F) and creating and testing ice sheet emulators. Neural network based emulators can be trained and compared to traditional gaussian process emulators along with other functions to help aid in analysis of performance. This repository currently only supports Antarctic ice sheet emulation, but in the future GrIS and others may be included as well.
9+
This repository contains source code for processing climate forcings from [ISMIP6 simulations](https://app.globus.org/file-manager?origin_id=ad1a6ed8-4de0-4490-93a9-8258931766c7&origin_path=%2FAIS%2F) and creating and testing ice sheet emulators. Neural network based emulators, such as the most recent version, ISEFlow, can be trained and compared to traditional gaussian process emulators along with other functions to help aid in analysis of performance. This repository supports Antarctic and Greenland ice sheet emulation.
910

1011
Documentation can be found at here: <https://brown-sciml.github.io/ise/>.
1112

1213
To access code for exact replication of "A Variational LSTM Emulator of Sea Level Contribution From the Antarctic Ice Sheet", see the release [https://github.com/Brown-SciML/ise/releases/tag/v1.0.0](https://github.com/Brown-SciML/ise/releases/tag/v1.0.0).
14+
To access code for exact replication of "ISEFlow: A Flow-Based Neural Network Emulator for Improved Sea Level Projections and Uncertainty Quantification", see the release []().
1315

1416
*This repository is a work in progress that is actively being updated and improved. Feel free to contact Peter Van Katwyk, Ph.D. student @ Brown University at peter_van_katwyk@brown.edu with further questions.*

examples/ISEFlow_from_NC.py

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
from ise.models.ISEFlow import ISEFlow_AIS
2+
import numpy as np
3+
import matplotlib.pyplot as plt
4+
5+
iseflowais = ISEFlow_AIS.load(version="v1.0.0", )
6+
7+
# DATA
8+
year = np.arange(2015, 2101)
9+
pr_anomaly = np.array([-7.0884660e-07, 3.3546070e-06, 1.6510604e-06, 1.3058835e-06, -1.8460897e-06, -1.8209784e-06, -2.8631115e-07, 9.7550980e-07, 1.8837744e-06, 1.0278583e-06, -5.3213130e-07, 9.8659150e-07, 3.0667563e-06, 1.4635594e-06, 1.6326882e-06, 1.9636754e-07, 1.6438966e-06, 1.6327829e-06, 4.2836740e-06, -1.6136711e-06, -3.2394215e-07, 8.5686120e-07, 1.1699207e-06, 3.1250070e-06, -1.7328695e-06, 1.5317561e-06, -2.5603040e-07, 3.5395046e-06, 2.0336056e-06, 1.4168240e-06, 3.6651910e-06, 3.8242396e-07, 2.5236104e-06, 2.3667692e-06, 4.4192740e-06, -1.8508446e-06, 5.3830777e-06, 3.4818756e-06, 1.1071076e-06, 3.3175084e-07, 2.1374274e-06, 3.4457967e-06, 1.3095015e-06, 2.4390830e-06, 2.7032495e-06, 2.3734970e-06, 3.4120333e-06, 7.1224844e-07, 1.1166115e-06, 1.3938886e-06, 2.8442582e-06, 1.1369078e-06, 7.5225375e-07, 3.3071670e-06, 9.2647576e-07, 2.8921231e-06, 1.5258190e-06, 2.3025927e-06, 2.7141216e-06, 4.7282086e-07, 1.7521104e-06, 2.2258650e-06, 2.5074578e-06, 6.3012344e-06, 4.3367270e-06, 1.3387919e-06, 2.0412472e-06, 5.5698990e-06, -4.4983244e-07, 4.9986234e-06, 3.4896400e-06, 3.2939452e-06, 3.1120549e-06, 5.1982165e-06, 3.4851853e-06, 2.3888053e-06, 3.3045646e-06, 5.7656130e-06, 2.5516167e-06, 9.7209140e-07, 2.7104174e-06, 2.5045779e-06, 5.0267950e-06, 3.4584243e-06, 4.5173547e-06, 7.0501980e-06])
10+
evspsbl_anomaly = np.array([-1.7997656e-06, 8.4536487e-07, -3.1830848e-07, -2.6885890e-07, 2.2850169e-07, -1.0624783e-06, 1.4247221e-06, 1.1629505e-06, 7.3894020e-07, 1.1523747e-06, -3.0657730e-07, -1.2494169e-06, 1.4154864e-07, 1.6318978e-06, 6.6845054e-08, -1.5337197e-06, 1.0680320e-06, 3.3650886e-07, 7.2351054e-07, -2.8386344e-07, -5.0218010e-07, -9.0324650e-07, 1.8473976e-06, 2.7624053e-06, -3.3211627e-09, -1.2328867e-06, -7.4078486e-07, 1.8665929e-06, -8.5686236e-07, -9.2208126e-07, -4.1323662e-07, 8.0743240e-09, 3.8672812e-07, -6.4713157e-07, -8.9784186e-07, -1.5912426e-06, 1.6034195e-06, 1.5027766e-06, -1.4378597e-06, 9.1651380e-07, 1.2373920e-06, 7.4936776e-07, -9.1188990e-07, 1.5164496e-06, 9.4740574e-07, -2.3654757e-06, 1.1783083e-06, -5.9225925e-07, 2.4704977e-06, 1.0327708e-07, -3.8991467e-07, 1.6661602e-06, -6.2760050e-07, 1.3399692e-06, 1.9969345e-06, -1.1019089e-07, -6.4994750e-07, 6.8705290e-07, -1.4576833e-06, 4.2027610e-07, -3.9619770e-07, 3.3307506e-07, 5.0004815e-07, 6.5702060e-07, 2.9295900e-07, -5.8118560e-07, -1.3942489e-06, 1.8514128e-06, -7.6820600e-07, 4.3472804e-07, 1.5344255e-06, 9.9030400e-08, -1.2802110e-06, 1.2338684e-06, 1.2415984e-06, -1.0605782e-09, -8.5692153e-07, 1.3089436e-06, 1.4610205e-06, -1.7224523e-06, 8.7178023e-07, 7.1715260e-07, -2.4332334e-07, 1.9094662e-06, 1.6035150e-06, 1.2819792e-06])
11+
mrro_anomaly = np.array([ 9.14532450e-09, -1.04553575e-08, 9.64602600e-10, -2.57455270e-08,-4.20736160e-08, -3.78904020e-08, 1.00503700e-08, 2.71937670e-08, 3.14256670e-08, -1.26842705e-08, -1.93106080e-08, 4.11426240e-08, 1.26591990e-08, -3.74505480e-08, -2.38847800e-09, 5.83510600e-08,-1.32395360e-08, -2.36836480e-09, -5.31146060e-09, 2.26311640e-08, 1.94292900e-08, 1.57301560e-07, 1.76516950e-08, 3.69438370e-09, 2.78125740e-09, 2.32886230e-08, -1.13636180e-09, 4.15595200e-08, 2.44487060e-08, 3.69915620e-08, 5.37579500e-08, 6.49251250e-08,-2.62033420e-08, 2.42323500e-08, 5.30387250e-09, 6.31726300e-08, 2.96837270e-08, 1.95823020e-09, 6.91571100e-08, -2.45743660e-08,-1.79463950e-08, 1.87269790e-08, 3.07978250e-08, -1.72344790e-08, 3.80653660e-08, 5.87809370e-08, 1.97393090e-08, 5.83083980e-08,-3.69657550e-08, 5.42363560e-08, 2.58655830e-08, 4.59774000e-08,-4.13071720e-10, 2.81270150e-08, -2.89719240e-08, -5.64607030e-10, 7.78114400e-09, -1.56779370e-08, 1.09255420e-07, 6.70794550e-08, 8.00081100e-09, 2.62701310e-08, 9.69834400e-08, 1.26523400e-07, 1.17642600e-07, 2.68955600e-08, 1.15547470e-07, 2.88342490e-08, 5.08320700e-08, 1.57037930e-07, 7.51390350e-08, 4.73127630e-08, 1.09078314e-07, 1.45109870e-07, -2.93809070e-09, 7.75231600e-08, 8.61493300e-08, 5.89316080e-08, -7.94276200e-09, 1.55970100e-07, 1.00888755e-07, 7.78078760e-08, 1.84023630e-07, 1.17026595e-07, 1.09118860e-07, 5.72854500e-08])
12+
smb_anomaly = np.array([ 1.0817737e-06, 2.5196978e-06, 1.9684041e-06, 1.6004882e-06,-2.0325180e-06, -7.2060976e-07, -1.7210832e-06, -2.1463464e-07, 1.1134085e-06, -1.1183209e-07, -2.0624336e-07, 2.1948658e-06, 2.9125483e-06, -1.3088778e-07, 1.5682317e-06, 1.6717360e-06, 5.8910400e-07, 1.2986424e-06, 3.5654746e-06, -1.3524389e-06, 1.5880867e-07, 1.6028063e-06, -6.9512873e-07, 3.5890747e-07,-1.7323296e-06, 2.7413540e-06, 4.8589070e-07, 1.6313521e-06, 2.8660190e-06, 2.3019136e-06, 4.0246696e-06, 3.0942448e-07, 2.1630856e-06, 2.9896692e-06, 5.3118115e-06, -3.2277464e-07, 3.7499747e-06, 1.9771405e-06, 2.4758103e-06, -5.6018850e-07, 9.1798154e-07, 2.6777027e-06, 2.1905935e-06, 9.3986740e-07, 1.7177789e-06, 4.6801915e-06, 2.2139857e-06, 1.2461996e-06,-1.3169205e-06, 1.2363750e-06, 3.2083070e-06, -5.7522980e-07, 1.3802672e-06, 1.9390710e-06, -1.0414868e-06, 3.0028789e-06, 2.1679853e-06, 1.6312173e-06, 4.0625496e-06, -1.4534731e-08, 2.1403070e-06, 1.8665197e-06, 1.9104264e-06, 5.5176910e-06, 3.9261254e-06, 1.8930818e-06, 3.3199485e-06, 3.6896517e-06, 2.6754162e-07, 4.4068574e-06, 1.8800760e-06, 3.1476022e-06, 4.2831875e-06, 3.8192384e-06, 2.2465244e-06, 2.3123425e-06, 4.0753375e-06, 4.3977380e-06, 1.0985389e-06, 2.5385737e-06, 1.7377488e-06, 1.7096173e-06, 5.0860950e-06, 1.4319313e-06, 2.8047202e-06, 5.7109332e-06])
13+
ts_anomaly = np.array([-0.6466742 , 0.00770213, -0.12585382, -0.34453338, -1.6241165 ,-0.974687 , -0.05863803, 0.93854046, 0.56900305, -0.6451154 ,-1.9634529 , -0.40643853, 0.23045924, -0.6870346 , -0.8890411 , 0.37614653, 0.67980075, 0.6205493 , 0.5144017 , 0.5007241 , 0.9710826 , 1.4140744 , 1.3183008 , 0.2768844 , -0.9116946 , 0.6314017 , 0.56855744, 0.6563098 , -0.18145359, -0.23351248, 0.63984925, 1.2217962 , -0.14435144, 0.12519707, -0.09563913, 0.66578555, 1.6812397 , -0.17464125, 0.30024293, -0.7873655 , 0.2609705 , 0.71260154, -0.23385662, -0.04084974, -0.18540329, 1.2981458 , 1.2119863 , 0.8429037 , 0.50004035, 1.1290026 , 0.849003 , 0.48997545, 0.16585606, 0.5025321 , -0.31758532,-0.21317828, 0.25390372, -0.32966253, 0.89048064, 1.0131814 , 0.7951362 , 1.2391979 , 2.1482313 , 1.2157921 , 1.2971075 , 0.06134838, 1.1723769 , 0.8362014 , 0.6327716 , 1.9348097 , 1.9065841 , 0.98560363, 1.6695279 , 1.8097014 , 1.2889434 , 1.3257821 , 1.3567889 , 1.398379 , 0.12701766, 1.8586403 , 2.803153 , 2.7678387 , 2.795791 , 2.046689 , 1.6896557 , 1.1954719 ])
14+
ocean_thermal_forcing = np.array([3.952802, 3.952802, 3.952802, 3.952802, 3.952802, 3.952802, 3.952802, 3.952802, 3.952802, 3.952802, 3.952802, 3.952802, 3.952802, 3.952802, 3.951522, 3.951522, 3.951522, 3.951522, 3.951522, 3.951522, 3.951522, 3.951522, 3.951522, 3.951522, 3.951522, 3.951522, 3.951522, 3.951522, 3.954892, 3.954892, 3.954892, 3.954892, 3.954892, 3.954892, 3.954892, 3.954892, 3.954892, 3.954892, 3.954892, 3.954892, 3.954892, 3.954892, 3.961842, 3.961842, 3.961842, 3.961842, 3.961842, 3.961842, 3.961842, 3.961842, 3.961842, 3.961842, 3.961842, 3.961842, 3.961842, 3.961842, 3.9666321, 3.9666321, 3.9666321, 3.9666321, 3.9666321, 3.9666321, 3.9666321, 3.9666321, 3.9666321, 3.9666321, 3.9666321, 3.9666321, 3.9666321, 3.9666321, 3.9701214, 3.9701214, 3.9701214, 3.9701214, 3.9701214, 3.9701214, 3.9701214, 3.9701214, 3.9701214, 3.9701214, 3.9701214, 3.9701214, 3.9701214, 3.9701214, 3.965213, 3.965213 ])
15+
ocean_thermal_forcing = np.array([4.052609 , 4.048029 , 4.025996 , 4.0376115, 4.073593 , 4.082889 ,4.068752 , 4.0792937, 4.07266 , 4.056613 , 4.088873 , 4.0866995,4.080259 , 4.0805173, 4.0915833, 4.086466 , 4.0837874, 4.0866885,4.0908504, 4.0955696, 4.11422 , 4.1292214, 4.118178 , 4.122908 ,4.133814 , 4.14126 , 4.129435 , 4.1385765, 4.15981 , 4.174478 ,4.171833 , 4.1734524, 4.1905923, 4.202703 , 4.203176 , 4.2121515,4.208194 , 4.236447 , 4.23815 , 4.260915 , 4.256757 , 4.265933 ,4.283288 , 4.29747 , 4.2998824, 4.3114376, 4.3152666, 4.3286147,4.3247986, 4.35383 , 4.3568287, 4.361311 , 4.371272 , 4.3807735,4.389557 , 4.394062 , 4.394604 , 4.407792 , 4.4131284, 4.4424253,4.461234 , 4.4478235, 4.473922 , 4.47514 , 4.474978 , 4.503846 ,4.519425 , 4.532303 , 4.5792913, 4.581029 , 4.5686707, 4.5728297,4.578795 , 4.588186 , 4.6041985, 4.609596 , 4.6358595, 4.64103 ,4.6610756, 4.6761293, 4.6807904, 4.681794 , 4.688118 , 4.707095 ,4.709125 , 4.7083187])
16+
ocean_salinity = np.array([34.538155, 34.54216 , 34.544907, 34.54777 , 34.54195 , 34.538666,34.534573, 34.534435, 34.538467, 34.541798, 34.54083 , 34.540863,34.5443 , 34.548077, 34.551094, 34.549007, 34.54491 , 34.541965,34.545845, 34.54476 , 34.540836, 34.536514, 34.535862, 34.540535,34.54106 , 34.542145, 34.541264, 34.538086, 34.53983 , 34.544174,34.543217, 34.54192 , 34.542164, 34.546356, 34.549465, 34.546703,34.545856, 34.547813, 34.548893, 34.54803 , 34.551342, 34.54866 ,34.55094 , 34.552845, 34.555645, 34.557346, 34.55776 , 34.56115 ,34.557613, 34.554943, 34.558067, 34.561043, 34.562572, 34.562527,34.558205, 34.56071 , 34.56081 , 34.56635 , 34.566048, 34.562794,34.559826, 34.558685, 34.558525, 34.555386, 34.556175, 34.559002,34.559345, 34.563156, 34.56612 , 34.56643 , 34.561913, 34.56346 ,34.56588 , 34.55718 , 34.56051 , 34.55849 , 34.55934 , 34.562157,34.56532 , 34.561005, 34.563007, 34.56287 , 34.563423, 34.562115,34.566116, 34.566475])
17+
ocean_temp = np.array([1.4597255, 1.454916 , 1.4327273, 1.444178 , 1.4804901, 1.489976 ,1.4760449, 1.4865896, 1.4797571, 1.4635196, 1.4958361, 1.4936603,1.4870231, 1.4870658, 1.4979594, 1.492959 , 1.4905169, 1.493585 ,1.4975262, 1.5022023, 1.5211732, 1.5363928, 1.5254219, 1.5298871,1.5407618, 1.548147 , 1.5363714, 1.5456947, 1.566828 , 1.5812478,1.5786566, 1.5803497, 1.5974761, 1.6093469, 1.6096419, 1.6187748,1.6148634, 1.6430062, 1.6446307, 1.6674396, 1.6631144, 1.6724435,1.6896647, 1.7037407, 1.7059923, 1.7174506, 1.7212558, 1.73441 ,1.730796 , 1.7599787, 1.7627985, 1.7671117, 1.7769841, 1.7864847,1.795518 , 1.7998635, 1.8004041, 1.8132827, 1.8186407, 1.8481231,1.867101 , 1.85375 , 1.8798618, 1.8812587, 1.8810492, 1.9097593,1.9253169, 1.9379777, 1.9847963, 1.9864471, 1.9743943, 1.9784856,1.9843125, 1.9941779, 2.0100205, 2.0155215, 2.0417492, 2.046756 ,2.0666232, 2.0819116, 2.0864294, 2.0873764, 2.0937738, 2.1128209,2.114625 , 2.1137898])
18+
19+
# CHARS (see Table A1 Seroussi et al. 2020)
20+
initial_year = 1980
21+
numerics = 'fd'
22+
stress_balance = 'ho'
23+
resolution = 16
24+
init_method = "da"
25+
melt_in_floating_cells = "floating condition"
26+
icefront_migration = "str"
27+
ocean_forcing_type = "open"
28+
ocean_sensitivity = "low"
29+
ice_shelf_fracture = False
30+
open_melt_type = "picop"
31+
standard_melt_type = "nonlocal"
32+
33+
# PREDICT
34+
pred, uq = iseflowais.predict(
35+
year, pr_anomaly, evspsbl_anomaly, mrro_anomaly, smb_anomaly, ts_anomaly,
36+
ocean_thermal_forcing, ocean_salinity, ocean_temp,
37+
initial_year, numerics, stress_balance, resolution, init_method,
38+
melt_in_floating_cells, icefront_migration, ocean_forcing_type,
39+
ocean_sensitivity, ice_shelf_fracture, open_melt_type, standard_melt_type
40+
)
41+
42+
plt.figure(figsize=(10, 6))
43+
pred = pred.squeeze()
44+
aleatoric = uq['aleatoric'].squeeze()
45+
epistemic = uq['epistemic'].squeeze()
46+
total = aleatoric + epistemic.squeeze()
47+
plt.fill_between(year, pred - epistemic, pred + epistemic, color='blue', alpha=0.5, label='Emulator Uncertainty (2$\sigma$)')
48+
plt.fill_between(year, pred - total, pred + total, color='green', alpha=0.5, label='Data Coverage Uncertainty')
49+
plt.plot(year, pred, color='red', label='Prediction')
50+
51+
plt.xlabel('Year')
52+
plt.ylabel('Projected Sea Level Equivalent (mm SLE)')
53+
plt.title('ISEFlow-AIS Sea Level Projection')
54+
plt.legend()
55+
plt.grid(True)
56+
plt.savefig('ISEFlow_from_NC.png')
57+
plt.show()
58+
plt.close('all')
59+
60+
61+
# plt.figure(figsize=(10, 6))
62+
# for init_method in ['da', 'da*', 'da+', 'eq', 'sp', 'sp+']:
63+
# pred, uq = iseflowais.predict(
64+
# year, pr_anomaly, evspsbl_anomaly, mrro_anomaly, smb_anomaly, ts_anomaly,
65+
# ocean_thermal_forcing, ocean_salinity, ocean_temp,
66+
# initial_year, numerics, stress_balance, resolution, init_method,
67+
# melt_in_floating_cells, icefront_migration, ocean_forcing_type,
68+
# ocean_sensitivity, ice_shelf_fracture, open_melt_type, standard_melt_type
69+
# )
70+
# pred = pred.squeeze()
71+
72+
73+
# plt.plot(year, pred, label=init_method)
74+
# plt.xlabel('Year')
75+
# plt.ylabel('Projected Sea Level Equivalent (mm SLE)')
76+
# plt.title('Sea Level Equivalent Projection by Initialization Method')
77+
# plt.legend()
78+
# plt.grid(True)
79+
80+
# plt.savefig('ISEFlow_from_NC.png')
81+
# plt.show()
82+
# plt.close('all')
83+
84+
# init_method = 'da'
85+
# plt.figure(figsize=(10, 6))
86+
# for ocean_sensitivity in ['low', 'medium', 'high']:
87+
# pred, uq = iseflowais.predict(
88+
# year, pr_anomaly, evspsbl_anomaly, mrro_anomaly, smb_anomaly, ts_anomaly,
89+
# ocean_thermal_forcing, ocean_salinity, ocean_temp,
90+
# initial_year, numerics, stress_balance, resolution, init_method,
91+
# melt_in_floating_cells, icefront_migration, ocean_forcing_type,
92+
# ocean_sensitivity, ice_shelf_fracture, open_melt_type, standard_melt_type
93+
# )
94+
# pred = pred.squeeze()
95+
96+
97+
# plt.plot(year, pred, label=ocean_sensitivity)
98+
# plt.xlabel('Year')
99+
# plt.ylabel('Projected Sea Level Equivalent (mm SLE)')
100+
# plt.title('Sea Level Equivalent Projection by Ocean Sensitivity')
101+
# plt.legend()
102+
# plt.grid(True)
103+
104+
# plt.savefig('ISEFlow_from_NC_ocean sensitivity.png')
105+
# plt.show()
106+
107+
# plt.figure(figsize=(10, 6))
108+
# for open_melt_type in ['lin', 'quad', 'pico', 'picop', 'plume', 'nonlocal+slope']:
109+
# pred, uq = iseflowais.predict(
110+
# year, pr_anomaly, evspsbl_anomaly, mrro_anomaly, smb_anomaly, ts_anomaly,
111+
# ocean_thermal_forcing, ocean_salinity, ocean_temp,
112+
# initial_year, numerics, stress_balance, resolution, init_method,
113+
# melt_in_floating_cells, icefront_migration, ocean_forcing_type,
114+
# ocean_sensitivity, ice_shelf_fracture, open_melt_type, standard_melt_type
115+
# )
116+
# pred = pred.squeeze()
117+
118+
119+
# plt.plot(year, pred, label=open_melt_type)
120+
# plt.xlabel('Year')
121+
# plt.ylabel('Projected Sea Level Equivalent (mm SLE)')
122+
# plt.title('Sea Level Equivalent Projection by Open Melt Type')
123+
# plt.legend()
124+
# plt.grid(True)
125+
126+
# plt.savefig('ISEFlow_from_NC_open type.png')
127+
# plt.show()
128+
129+
130+

ise/data/feature_engineer.py

Lines changed: 56 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -324,7 +324,38 @@ def add_model_characteristics(
324324
self._including_model_characteristics = True
325325

326326
return self
327+
328+
329+
def scale_data(data, scaler_path, ):
330+
331+
dropped_columns = [
332+
"id",
333+
"cmip_model",
334+
"pathway",
335+
"exp",
336+
"ice_sheet",
337+
"Scenario",
338+
"Tier",
339+
"aogcm",
340+
"id",
341+
"exp",
342+
"model",
343+
"ivaf",
344+
]
345+
346+
dropped_columns = [x for x in data.columns if x in dropped_columns]
347+
dropped_data = data[dropped_columns]
348+
data = data.drop(
349+
columns=[x for x in data.columns if "sle" in x] + dropped_columns
350+
)
351+
cols = data.columns
327352

353+
scaler = pickle.load(open(scaler_path, "rb"))
354+
scaled = scaler.transform(data)
355+
scaled = pd.DataFrame(scaled, columns=cols,)
356+
if 'outlier' in scaled.columns:
357+
scaled = scaled.drop(columns=['outlier'])
358+
return scaled
328359

329360
def add_model_characteristics(
330361
data, model_char_path=r"./ise/utils/model_characteristics.csv", encode=True, ids_path=None
@@ -393,7 +424,7 @@ def backfill_outliers(data, percentile=99.999):
393424
return data
394425

395426

396-
def add_lag_variables(data: pd.DataFrame, lag: int) -> pd.DataFrame:
427+
def add_lag_variables(data: pd.DataFrame, lag: int, verbose=True) -> pd.DataFrame:
397428
"""
398429
Adds lag variables to the input DataFrame.
399430
@@ -406,24 +437,25 @@ def add_lag_variables(data: pd.DataFrame, lag: int) -> pd.DataFrame:
406437
"""
407438

408439
# Separate columns that won't be lagged and shouldn't be dropped
409-
cols_to_exclude = [
410-
"id",
411-
"cmip_model",
412-
"pathway",
413-
"exp",
414-
"ice_sheet",
415-
"Scenario",
416-
"Ocean forcing",
417-
"Ocean sensitivity",
418-
"Ice shelf fracture",
419-
"Tier",
420-
"aogcm",
421-
"id",
422-
"exp",
423-
"model",
424-
"ivaf",
425-
"sector",
426-
]
440+
# cols_to_exclude = [
441+
# "id",
442+
# "cmip_model",
443+
# "pathway",
444+
# "exp",
445+
# "ice_sheet",
446+
# "Scenario",
447+
# "Ocean forcing",
448+
# "Ocean sensitivity",
449+
# "Ice shelf fracture",
450+
# "Tier",
451+
# "aogcm",
452+
# "id",
453+
# "exp",
454+
# "model",
455+
# "ivaf",
456+
# "sector",
457+
# ]
458+
cols_to_exclude = [x for x in data.columns if x not in ("year", "pr_anomaly", "evspsbl_anomaly", "mrro_anomaly", "smb_anomaly", "ts_anomaly", "thermal_forcing", "salinity", "temperature")]
427459
cols_to_exclude = [x for x in cols_to_exclude if x in data.columns]
428460
temporal_indicator = "time" if "time" in data.columns else "year"
429461
non_temporal_cols = [temporal_indicator] + [
@@ -437,7 +469,11 @@ def add_lag_variables(data: pd.DataFrame, lag: int) -> pd.DataFrame:
437469
# Calculate the number of segments
438470
num_segments = len(data) // projection_length
439471

440-
for segment_idx in tqdm(range(num_segments), total=num_segments, desc="Adding lag variables"):
472+
if verbose:
473+
iterator = tqdm(range(num_segments), total=num_segments, desc="Adding lag variables")
474+
else:
475+
iterator = range(num_segments)
476+
for segment_idx in iterator:
441477
# Extract the segment
442478
segment_start = segment_idx * projection_length
443479
segment_end = (segment_idx + 1) * projection_length

0 commit comments

Comments
 (0)